diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h
index 360f8a5d..8a55b7de 100644
--- a/inc/common/opskernel/ge_task_info.h
+++ b/inc/common/opskernel/ge_task_info.h
@@ -52,5 +52,16 @@ struct GETaskInfo {
   std::vector<GETaskKernelHcclInfo> kernelHcclInfo;
 };
+
+struct HcomOpertion {
+  std::string hcclType;
+  void *inputPtr;
+  void *outputPtr;
+  uint64_t count;
+  int32_t dataType;
+  int32_t opType;
+  int32_t root;
+};
+
 }  // namespace ge
 #endif  // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_
diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h
index 6908fb75..e350f9e5 100644
--- a/inc/common/util/compress/compress.h
+++ b/inc/common/util/compress/compress.h
@@ -28,6 +28,7 @@ struct CompressConfig {
   size_t channel;      // channels of L2 or DDR. For load balance
   size_t fractalSize;  // size of compressing block
   bool isTight;        // whether compose compressed data tightly
+  size_t init_offset;
 };
 
 CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output,
diff --git a/inc/common/util/compress/compress_weight.h b/inc/common/util/compress/compress_weight.h
new file mode 100644
index 00000000..34ea47d1
--- /dev/null
+++ b/inc/common/util/compress/compress_weight.h
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef COMPRESS_WEIGHT_H
+#define COMPRESS_WEIGHT_H
+
+#include "compress.h"
+
+const int SHAPE_SIZE_WEIGHT = 4;
+
+struct CompressOpConfig {
+  int64_t wShape[SHAPE_SIZE_WEIGHT];
+  size_t compressTilingK;
+  size_t compressTilingN;
+  struct CompressConfig compressConfig;
+};
+
+extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer,
+                                           CompressOpConfig *const param);
+#endif  // COMPRESS_WEIGHT_H
diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h
index cd143fcc..2a145d68 100644
--- a/inc/common/util/platform_info.h
+++ b/inc/common/util/platform_info.h
@@ -27,7 +27,6 @@
 using std::string;
 using std::vector;
 
 namespace fe {
-
 class PlatformInfoManager {
  public:
   PlatformInfoManager(const PlatformInfoManager &) = delete;
@@ -39,6 +38,8 @@ class PlatformInfoManager {
   uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);
 
+  uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);
+
   void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo);
 
  private:
@@ -94,6 +95,5 @@ class PlatformInfoManager {
   map<string, PlatformInfo> platformInfoMap_;
   OptionalInfo optiCompilationInfo_;
 };
-
 }  // namespace fe
 #endif
diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h
index 1632f11c..cffb28bd 100644
--- a/inc/external/ge/ge_api_types.h
+++ b/inc/external/ge/ge_api_types.h
@@ -44,8 +44,12 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump";
 const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath";
 const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep";
 const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode";
+const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug";
+const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode";
+const char *const OPTION_EXEC_OP_DEBUG_LEVEL = "ge.exec.opDebugLevel";
 const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild";
 const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath";
+const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses";
 // profiling flag
 const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode";
 const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions";
@@ -219,6 +223,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream";
 
 // Configure input fp16 nodes
 const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
 
+// Configure the op debug level. Its value should be 0 (default), 1, or 2.
+// 0: debug disabled; 1: TBE compiler debug enabled; 2: ccec compiler debug enabled
+const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";
+
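A minimal caller-side sketch of how these new debug options might be combined when creating a session. This is illustrative only, not part of the patch; the option values shown ("1", "aicore_overflow") are assumptions about the accepted strings.

    // Hypothetical usage: enable op debug and overflow dump via the new keys.
    std::map<std::string, std::string> options;
    options[ge::OP_DEBUG_LEVEL] = "1";                             // assumed: "1" turns on TBE compiler debug
    options[ge::OPTION_EXEC_ENABLE_DUMP_DEBUG] = "1";              // assumed: "1" enables the dump-debug path
    options[ge::OPTION_EXEC_DUMP_DEBUG_MODE] = "aicore_overflow";  // assumed mode string
    ge::Session session(options);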
 // Graph run mode
 enum GraphRunMode { PREDICTION = 0, TRAIN };
diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h
index 4cd9ba91..a1245c9d 100644
--- a/inc/external/graph/types.h
+++ b/inc/external/graph/types.h
@@ -145,7 +145,8 @@ enum Format {
   FORMAT_FRACTAL_ZN_LSTM,
   FORMAT_FRACTAL_Z_G,
   FORMAT_RESERVED,
-  FORMAT_ALL
+  FORMAT_ALL,
+  FORMAT_NULL
 };
 
 // for unknown shape op type
diff --git a/inc/external/register/register.h b/inc/external/register/register.h
index a8421511..9834d8a8 100644
--- a/inc/external/register/register.h
+++ b/inc/external/register/register.h
@@ -98,6 +98,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
   OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);
 
+  OpRegistrationData &InputReorderVector(const vector<int> &input_order);
+
   domi::ImplyType GetImplyType() const;
   std::string GetOmOptype() const;
   std::set GetOriginOpTypeSet() const;
diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h
index e2023cb8..6ac00037 100644
--- a/inc/framework/common/debug/ge_log.h
+++ b/inc/framework/common/debug/ge_log.h
@@ -51,30 +51,6 @@ inline pid_t GetTid() {
   return tid;
 }
 
-#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap()
-
-#define GE_TIMESTAMP_END(stage, stage_name)                                                \
-  do {                                                                                     \
-    uint64_t endUsec_##stage = ge::GetCurrentTimestap();                                   \
-    GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name),      \
-            (endUsec_##stage - startUsec_##stage));                                        \
-  } while (0);
-
-#define GE_TIMESTAMP_CALLNUM_START(stage)                \
-  uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \
-  uint64_t call_num_of##stage = 0;                       \
-  uint64_t time_of##stage = 0
-
-#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap())
-
-#define GE_TIMESTAMP_ADD(stage)                                   \
-  time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \
-  call_num_of##stage++
-
-#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name)                                                                 \
-  GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \
-          call_num_of##stage)
-
 #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...)                                       \
   dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \
              ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__)
diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h
index 28c6585e..f07a8fa0 100644
--- a/inc/framework/common/debug/log.h
+++ b/inc/framework/common/debug/log.h
@@ -19,15 +19,12 @@
 #include
 
-#include "cce/cce_def.hpp"
+#include "runtime/rt.h"
 #include "common/string_util.h"
 #include "common/util.h"
 #include "framework/common/debug/ge_log.h"
 #include "ge/ge_api_error_codes.h"
 
-using cce::CC_STATUS_SUCCESS;
-using cce::ccStatus_t;
-
 #if !defined(__ANDROID__) && !defined(ANDROID)
 #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__)
 #else
@@ -102,17 +99,13 @@ using cce::ccStatus_t;
   } while (0);
 
 // If expr is not true, print the log and return the specified status
-#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...)                                                         \
-  do {                                                                                                     \
-    bool b = (expr);                                                                                       \
-    if (!b) {                                                                                              \
-      std::string msg;                                                                                     \
-      (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__));                                        \
-      (void)msg.append(                                                                                    \
-        ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \
-      DOMI_LOGE("%s", msg.c_str());                                                                        \
-      return _status;                                                                                      \
-    }                                                                                                      \
+#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \
+  do {                                             \
+    bool b = (expr);                               \
+    if (!b) {                                      \
+      GELOGE(_status, __VA_ARGS__);                \
+      return _status;                              \
+    }                                              \
   } while (0);
 
 // If expr is not true, print the log and return the specified status
@@ -132,7 +125,7 @@
       DOMI_LOGE(__VA_ARGS__); \
       exec_expr;              \
     }                         \
-  };
+  }
 
 // If expr is not true, print the log and execute a custom statement
 #define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \
@@ -142,7 +135,7 @@
       GELOGW(__VA_ARGS__); \
       exec_expr;           \
     }                      \
-  };
+  }
 
 // If expr is not true, print the log and execute a custom statement
 #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \
@@ -151,7 +144,7 @@
   {                        \
     bool b = (expr);       \
    if (!b) {               \
      GELOGI(__VA_ARGS__);  \
      exec_expr;            \
    }                       \
-  };
+  }
 
 // If expr is not true, print the log and execute a custom statement
 #define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \
@@ -161,7 +154,7 @@
       GELOGI(__VA_ARGS__); \
       exec_expr;           \
     }                      \
-  };
+  }
 
 // If expr is true, print logs and execute custom statements
 #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \
@@ -171,7 +164,7 @@
       DOMI_LOGE(__VA_ARGS__); \
       exec_expr;              \
     }                         \
-  };
+  }
 
 // If expr is true, print the Information log and execute a custom statement
 #define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \
@@ -180,7 +173,7 @@
   {                        \
     bool b = (expr);       \
     if (b) {               \
      GELOGI(__VA_ARGS__);  \
      exec_expr;            \
    }                       \
-  };
+  }
 
 // If expr is not SUCCESS, print the log and execute the expression + return
 #define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \
@@ -191,7 +184,7 @@
       exec_expr; \
       return;    \
     }            \
-  };
+  }
 
 // If expr is not SUCCESS, print the log and execute the expression + return _status
 #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \
@@ -202,7 +195,7 @@ using cce::ccStatus_t;
       exec_expr;      \
       return _status; \
     }                 \
-  };
+  }
 
 // If expr is not true, execute a custom statement
 #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \
@@ -211,7 +204,7 @@ using cce::ccStatus_t;
     if (!b) {    \
       exec_expr; \
     }            \
-  };
+  }
 
 // -----------------runtime related macro definitions-------------------------------
 // If expr is not RT_ERROR_NONE, print the log
@@ -231,7 +224,7 @@ using cce::ccStatus_t;
       DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \
       exec_expr;                                           \
     }                                                      \
-  };
+  }
 
 // If expr is not RT_ERROR_NONE, print the log and return
 #define GE_CHK_RT_RET(expr) \
@@ -243,23 +236,13 @@ using cce::ccStatus_t;
     }   \
   } while (0);
 
-// ------------------------cce related macro definitions----------------------------
-// If expr is not CC_STATUS_SUCCESS, print the log
-#define GE_CHK_CCE(expr)                                   \
-  do {                                                     \
-    ccStatus_t _cc_ret = (expr);                           \
-    if (_cc_ret != CC_STATUS_SUCCESS) {                    \
-      DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \
-    }                                                      \
-  } while (0);
-
 // If expr is true, execute exec_expr without printing logs
 #define GE_IF_BOOL_EXEC(expr, exec_expr) \
   {                                      \
     if (expr) {                          \
       exec_expr;                         \
     }                                    \
-  };
+  }
 
 // If make_shared is abnormal, print the log and execute the statement
 #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \
diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h
index 27ae28ee..00bfa301 100644
--- a/inc/framework/common/ge_types.h
+++ b/inc/framework/common/ge_types.h
@@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM";
 struct DataBuffer {
  public:
   void *data;       // Data address
-  uint32_t length;  // Data length
+  uint64_t length;  // Data length
   bool isDataSupportMemShare = false;
-  DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare)
+  DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare)
       : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {}
 
   DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {}
@@ -106,7 +106,7 @@ struct ShapeDescription {
 // Definition of input and output description information
 struct InputOutputDescInfo {
   std::string name;
-  uint32_t size;
+  uint64_t size;
   uint32_t data_type;
   ShapeDescription shape_info;
 };
@@ -231,6 +231,7 @@ struct Options {
 // Profiling info of task
 struct TaskDescInfo {
+  std::string model_name;
   std::string op_name;
   uint32_t block_dim;
   uint32_t task_id;
@@ -239,6 +240,7 @@
 // Profiling info of graph
 struct ComputeGraphDescInfo {
+  std::string model_name;
   std::string op_name;
   std::string op_type;
   std::vector<Format> input_format;
diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h
index 3c9de891..3671f970 100644
--- a/inc/framework/common/helper/model_helper.h
+++ b/inc/framework/common/helper/model_helper.h
@@ -44,8 +44,6 @@ class ModelHelper {
   void SetSaveMode(bool val) { is_offline_ = val; }
   bool GetSaveMode(void) const { return is_offline_; }
 
-  static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model);
-  static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr);
   Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name);
   Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name);
diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h
index e3844a61..50e41755 100644
--- a/inc/framework/common/types.h
+++ b/inc/framework/common/types.h
@@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL;
 
 // Supported public properties name
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME;  // Start time
@@ -335,6 +338,7 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell");
 REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext");
 REGISTER_OPTYPE_DECLARE(INITDATA, "InitData");
 REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape")
+REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity");
 
 // ANN dedicated operator
 REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean");
@@ -631,6 +635,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH;
 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG;
+
 // convolution node type
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION;
 // adds a convolutional node name for the hard AIPP
diff --git a/inc/framework/executor/ge_executor.h b/inc/framework/executor/ge_executor.h
index 87e30805..2b7335ef 100644
--- a/inc/framework/executor/ge_executor.h
+++ b/inc/framework/executor/ge_executor.h
@@ -21,12 +21,12 @@
 #include
 #include
 
+#include "common/dynamic_aipp.h"
 #include "common/ge_inner_error_codes.h"
 #include "common/ge_types.h"
 #include "common/types.h"
 #include "graph/tensor.h"
 #include "runtime/base.h"
-#include "common/dynamic_aipp.h"
 
 namespace ge {
 class ModelListenerAdapter;
@@ -62,7 +62,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
 
   // Get input and output descriptor
   ge::Status GetModelDescInfo(uint32_t model_id, std::vector &input_desc,
-                              std::vector &output_desc);
+                              std::vector &output_desc, bool new_model_desc = false);
 
   ///
   /// @ingroup ge
diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h
index 6e7abcb9..8e312b09 100644
--- a/inc/framework/ge_runtime/model_runner.h
+++ b/inc/framework/ge_runtime/model_runner.h
@@ -28,16 +28,21 @@
 namespace ge {
 namespace model_runner {
 class RuntimeModel;
-
+using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
 class ModelRunner {
  public:
   static ModelRunner &Instance();
 
   bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
                         std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
 
+  bool LoadModelComplete(uint32_t model_id);
+
   const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
 
+  const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;
+
+  const std::map<std::string, std::vector<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;
+
   bool UnloadModel(uint32_t model_id);
 
   bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);
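A short caller-side sketch of how the new ModelRunner accessors compose (illustrative only; model_id and the prior LoadDavinciModel call are assumed):

    // Hypothetical usage of the new accessors after loading completes.
    auto &runner = ge::model_runner::ModelRunner::Instance();
    if (runner.LoadModelComplete(model_id)) {
      const auto &task_ids = runner.GetTaskIdList(model_id);
      const auto &stream_ids = runner.GetStreamIdList(model_id);  // pairs up with task_ids
      const auto &info_map = runner.GetRuntimeInfoMap(model_id);  // op name -> RuntimeInfo tuples
    }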
diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h
index a48ed68b..68d71870 100644
--- a/inc/framework/ge_runtime/task_info.h
+++ b/inc/framework/ge_runtime/task_info.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "cce/taskdown_api.h"
@@ -52,21 +53,27 @@ class TaskInfo {
   virtual ~TaskInfo() {}
   uint32_t stream_id() const { return stream_id_; }
   TaskInfoType type() const { return type_; }
+  std::string op_name() const { return op_name_; }
+  bool dump_flag() const { return dump_flag_; }
 
  protected:
-  TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
+  TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
+      : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}
 
 private:
+  std::string op_name_;
   uint32_t stream_id_;
   TaskInfoType type_;
+  bool dump_flag_;
 };
 
 class CceTaskInfo : public TaskInfo {
  public:
-  CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
-              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
-              const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
-      : TaskInfo(stream_id, TaskInfoType::CCE),
+  CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
+              uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
+              const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
+              const std::vector<uint8_t> &args_offset, bool is_flowtable)
+      : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
         ctx_(ctx),
         stub_func_(stub_func),
         block_dim_(block_dim),
@@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo {
 
 class TbeTaskInfo : public TaskInfo {
  public:
-  TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
-              uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
-              const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
-              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
-      : TaskInfo(stream_id, TaskInfoType::TBE),
+  TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
+              const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
+              uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
+              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
         stub_func_(stub_func),
         block_dim_(block_dim),
         args_(args),
@@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo {
 
 class AicpuTaskInfo : public TaskInfo {
  public:
-  AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
-                const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
-      : TaskInfo(stream_id, TaskInfoType::AICPU),
+  AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
+                const std::string &node_def, const std::vector<void *> &input_data_addrs,
+                const std::vector<void *> &output_data_addrs, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
         so_name_(so_name),
         kernel_name_(kernel_name),
         node_def_(node_def),
@@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo {
   std::vector<void *> output_data_addrs_;
 };
 
-class LabelTaskInfo : public TaskInfo {
+class LabelSetTaskInfo : public TaskInfo {
  public:
+  LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
+  ~LabelSetTaskInfo() override {}
   uint32_t label_id() const { return label_id_; }
 
- protected:
-  LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id)
-      : TaskInfo(stream_id, type), label_id_(label_id) {}
-  virtual ~LabelTaskInfo() override {}
-
+ private:
   uint32_t label_id_;
 };
 
-class LabelSetTaskInfo : public LabelTaskInfo {
+class LabelGotoTaskInfo : public TaskInfo {
  public:
-  LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
-      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
-  ~LabelSetTaskInfo() override {}
+  LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
+  ~LabelGotoTaskInfo() override {}
+  uint32_t label_id() const { return label_id_; }
+
+ private:
+  uint32_t label_id_;
 };
 
-class LabelSwitchTaskInfo : public LabelTaskInfo {
+class LabelSwitchTaskInfo : public TaskInfo {
  public:
-  LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
-      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
+  LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
+                      const std::vector<uint32_t> &label_list, void *cond)
+      : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
+        label_size_(label_size),
+        label_list_(label_list),
+        cond_(cond) {}
   ~LabelSwitchTaskInfo() override {}
-};
+  uint32_t label_size() { return label_size_; };
+  const std::vector<uint32_t> &label_list() { return label_list_; };
+  void *cond() { return cond_; };
 
-class LabelGotoTaskInfo : public LabelTaskInfo {
- public:
-  LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
-      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
-  ~LabelGotoTaskInfo() override {}
+ private:
+  uint32_t label_size_;
+  std::vector<uint32_t> label_list_;
+  void *cond_;
 };
 
 class EventTaskInfo : public TaskInfo {
@@ -215,8 +231,8 @@
   uint32_t event_id() const { return event_id_; }
 
 protected:
-  EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
-      : TaskInfo(stream_id, type), event_id_(event_id) {}
+  EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
+      : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
   virtual ~EventTaskInfo() override {}
 
   uint32_t event_id_;
@@ -224,39 +240,41 @@
 
 class EventRecordTaskInfo : public EventTaskInfo {
  public:
-  EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
-      : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
+  EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
+      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
   ~EventRecordTaskInfo() override {}
 };
 
 class EventWaitTaskInfo : public EventTaskInfo {
  public:
-  EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
-      : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
+  EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
+      : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
   ~EventWaitTaskInfo() override {}
 };
 
 class FusionStartTaskInfo : public TaskInfo {
  public:
-  explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
+  explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
   ~FusionStartTaskInfo() override {}
 };
 
 class FusionEndTaskInfo : public TaskInfo {
  public:
-  explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
+  explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
   ~FusionEndTaskInfo() override {}
 };
 
 class HcclTaskInfo : public TaskInfo {
  public:
-  HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
-               void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
+  HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
+               void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
                const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
-               int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model,
-               std::function<bool(void *)> hcom_unbind_model,
-               std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
-      : TaskInfo(stream_id, TaskInfoType::HCCL),
+               int64_t op_type, int64_t data_type, const std::string &group,
+               std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
+               std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
        hccl_type_(hccl_type),
        input_data_addr_(input_data_addr),
        output_data_addr_(output_data_addr),
@@ -269,6 +287,7 @@
        root_id_(root_id),
        op_type_(op_type),
        data_type_(data_type),
+       group_(group),
        hcom_bind_model_(hcom_bind_model),
        hcom_unbind_model_(hcom_unbind_model),
        hcom_distribute_task_(hcom_distribute_task) {}
@@ -286,6 +305,7 @@
   int64_t root_id() const { return root_id_; }
   int64_t op_type() const { return op_type_; }
   int64_t data_type() const { return data_type_; }
+  const std::string &group() const { return group_; }
   std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
   std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
   std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
@@ -305,6 +325,7 @@
   int64_t root_id_;
   int64_t op_type_;
   int64_t data_type_;
+  std::string group_;
   std::function<bool(void *, void *)> hcom_bind_model_;
   std::function<bool(void *)> hcom_unbind_model_;
   std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
@@ -312,8 +333,11 @@
 
 class ProfilerTraceTaskInfo : public TaskInfo {
  public:
-  ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
-      : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
+  ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
+      : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
+        log_id_(log_id),
+        notify_(notify),
+        flat_(flat) {}
   ~ProfilerTraceTaskInfo() override {}
 
   uint64_t log_id() const { return log_id_; }
@@ -328,8 +352,9 @@
 
 class MemcpyAsyncTaskInfo : public TaskInfo {
  public:
-  MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
-      : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
+  MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
+                      uint64_t count, uint32_t kind, bool dump_flag)
+      : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
        dst_(dst),
        dst_max_(dst_max),
       src_(src),
@@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {
 
 class StreamSwitchTaskInfo : public TaskInfo {
  public:
-  StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
-                       int64_t data_type)
-      : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
+  StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
+                       void *value_addr, int64_t cond, int64_t data_type)
+      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
@@ -379,8 +404,8 @@
 class StreamActiveTaskInfo : public TaskInfo {
  public:
-  StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
-      : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
+  StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
+      : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
   ~StreamActiveTaskInfo() override {}
 
   uint32_t active_stream_id() const { return active_stream_id_; }
diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h
index f0707c67..d3f472e9 100644
--- a/inc/framework/generator/ge_generator.h
+++ b/inc/framework/generator/ge_generator.h
@@ -27,6 +27,7 @@
 #include "graph/ge_tensor.h"
 #include "graph/graph.h"
 #include "graph/op_desc.h"
+#include "graph/detail/attributes_holder.h"
 
 namespace ge {
 class GeGenerator {
diff --git a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h
index 07d78490..45a8896d 100644
--- a/inc/framework/omg/omg.h
+++ b/inc/framework/omg/omg.h
@@ -98,13 +98,14 @@ Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);
 
 Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);
 
-Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
-                     std::vector<std::string> &output_nodes_name);
+Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info);
+
+void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
+                                std::vector<std::string> &output_nodes_name);
 
 void UpdateOmgCtxWithParserCtx();
 
 void UpdateParserCtxWithOmgCtx();
-
 }  // namespace ge
 
 namespace domi {
diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h
index 8e5bc484..70d59c2f 100644
--- a/inc/framework/omg/omg_inner_types.h
+++ b/inc/framework/omg/omg_inner_types.h
@@ -94,6 +94,8 @@ struct OmgContext {
   std::vector<std::pair<std::string, int32_t>> user_out_nodes;
   // net out nodes (where user_out_nodes or leaf nodes)
   std::vector<std::string> net_out_nodes;
+  // top names of net out nodes (only Caffe has top)
+  std::vector<std::string> out_top_names;
   // path for the aicpu custom operator so_file
   std::vector<std::string> aicpu_op_run_paths;
   // ddk version
diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h
index 4f865f12..1cb65a6c 100644
--- a/inc/graph/compute_graph.h
+++ b/inc/graph/compute_graph.h
@@ -74,6 +74,9 @@
   size_t GetAllNodesSize() const;
   Vistor<NodePtr> GetAllNodes() const;
+  // is_unknown_shape: false, same as GetAllNodes
+  // is_unknown_shape: true, same as GetDirectNode
+  Vistor<NodePtr> GetNodes(bool is_unknown_shape) const;
   size_t GetDirectNodesSize() const;
   Vistor<NodePtr> GetDirectNode() const;
   Vistor<NodePtr> GetInputNodes() const;
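A minimal sketch of how the new shape-aware accessor is meant to be called (illustrative; graph is an assumed ComputeGraphPtr from elsewhere):

    // Pick the traversal that matches the graph's shape status.
    bool unknown = graph->GetGraphUnknownFlag();
    for (const auto &node : graph->GetNodes(unknown)) {
      // unknown == true  -> direct nodes only (same as GetDirectNode)
      // unknown == false -> all nodes, including subgraph nodes (same as GetAllNodes)
      GELOGD("visit node %s", node->GetName().c_str());
    }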
@@ -174,6 +177,10 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
   void SetInputSize(uint32_t size) { input_size_ = size; }
   uint32_t GetInputSize() const { return input_size_; }
 
+  // false: known shape; true: unknown shape
+  bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; }
+  void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; }
+
   ///
   /// Set is need train iteration.
   /// If set true, it means this graph need to be run iteration some
@@ -282,7 +289,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public AttrHolder {
   std::map op_name_map_;
   uint64_t session_id_ = 0;
   ge::Format data_format_ = ge::FORMAT_ND;
+  // unknown graph indicator, default is false, means known shape
+  bool is_unknown_shape_graph_ = false;
 };
 }  // namespace ge
-
 #endif  // INC_GRAPH_COMPUTE_GRAPH_H_
diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h
index ea5544d1..ff015be1 100644
--- a/inc/graph/debug/ge_attr_define.h
+++ b/inc/graph/debug/ge_attr_define.h
@@ -139,6 +139,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS;
 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS;
+
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME;
@@ -776,6 +778,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE;
 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION;
+
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION;
+
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET;
@@ -994,7 +1000,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE;
 
-// used for l1 fusion and other fusion in future
+// used for lX fusion
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY;
@@ -1008,9 +1014,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE;
+
+// op overflow dump
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE;
 
 // functional ops attr
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH;
@@ -1056,6 +1070,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR
 // for gradient group
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG;
+
+// dynamic shape attrs
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX;
+
+// for fusion op plugin
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE;
 }  // namespace ge
 #endif  // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h
index bb26dec5..a82ecca8 100644
--- a/inc/graph/detail/attributes_holder.h
+++ b/inc/graph/detail/attributes_holder.h
@@ -149,5 +149,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
   AnyMap extAttrs_;
 };
 }  // namespace ge
-
 #endif  // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_
diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h
index b1ccd5b9..af6b35bc 100644
--- a/inc/graph/ge_context.h
+++ b/inc/graph/ge_context.h
@@ -28,6 +28,7 @@ class GEContext {
   uint32_t DeviceId();
   uint64_t TraceId();
   void Init();
+  void SetSessionId(uint64_t session_id);
   void SetCtxDeviceId(uint32_t device_id);
 
 private:
diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h
index 29a315d6..834dca0b 100644
--- a/inc/graph/ge_tensor.h
+++ b/inc/graph/ge_tensor.h
@@ -25,6 +25,7 @@
 #include "graph/buffer.h"
 #include "graph/ge_error_codes.h"
 #include "graph/types.h"
+
 namespace ge {
 class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape {
  public:
@@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder {
   DataType GetDataType() const;
   void SetDataType(DataType dt);
 
-  void SetOriginDataType(DataType originDataType);
   DataType GetOriginDataType() const;
+  void SetOriginDataType(DataType originDataType);
+
+  std::vector<uint32_t> GetRefPortIndex() const;
+  void SetRefPortByIndex(const std::vector<uint32_t> &index);
 
   GeTensorDesc Clone() const;
   GeTensorDesc &operator=(const GeTensorDesc &desc);
@@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor {
   GeTensorDesc &DescReference() const;
 };
 }  // namespace ge
-
 #endif  // INC_GRAPH_GE_TENSOR_H_
diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h
index 3f7d65a9..16529512 100644
--- a/inc/graph/model_serialize.h
+++ b/inc/graph/model_serialize.h
@@ -49,5 +49,4 @@ class ModelSerialize {
   friend class GraphDebugImp;
 };
 }  // namespace ge
-
 #endif  // INC_GRAPH_MODEL_SERIALIZE_H_
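The new ref-port accessors are thin wrappers over a list-int attribute on the tensor desc (see the ge_tensor.cc hunk further below). A minimal sketch, with the port indices chosen arbitrarily for illustration:

    // Mark this tensor as referencing input ports 0 and 2, then read it back.
    ge::GeTensorDesc desc;
    desc.SetRefPortByIndex({0, 2});
    auto ref_ports = desc.GetRefPortIndex();  // yields {0, 2}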
diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h
index faca2d99..1bba7340 100644
--- a/inc/graph/op_desc.h
+++ b/inc/graph/op_desc.h
@@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
 
   GeTensorDescPtr MutableInputDesc(uint32_t index) const;
 
+  GeTensorDescPtr MutableInputDesc(const string &name) const;
+
   Vistor<GeTensorDesc> GetAllInputsDesc() const;
 
   Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const;
@@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
 
   GeTensorDescPtr MutableOutputDesc(uint32_t index) const;
 
+  GeTensorDescPtr MutableOutputDesc(const string &name) const;
+
   uint32_t GetAllOutputsDescSize() const;
 
   Vistor<GeTensorDesc> GetAllOutputsDesc() const;
diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h
index 6c344435..61c713c1 100644
--- a/inc/graph/utils/graph_utils.h
+++ b/inc/graph/utils/graph_utils.h
@@ -130,7 +130,7 @@ struct NodeIndexIO {
   IOType io_type_ = kOut;
   std::string value_;
 
-  std::string ToString() const { return value_; }
+  const std::string &ToString() const { return value_; }
 };
 
 class GraphUtils {
@@ -188,8 +188,8 @@
   /// @param [in] output_index
   /// @return graphStatus
   ///
-  static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector &dsts,
-                                      const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
+  static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts,
+                                     const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
 
   static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);
@@ -303,6 +303,14 @@
   ///
   static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);
 
+  ///
+  /// Copy all in-data edges from `src_node` to `dst_node`
+  /// @param src_node
+  /// @param dst_node
+  /// @return
+  ///
+  static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);
+
   static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
 
   static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec);
@@ -728,5 +736,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder {
   std::vector exist_nodes_;
 };
 }  // namespace ge
-
 #endif  // INC_GRAPH_UTILS_GRAPH_UTILS_H_
diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h
index 6e0e655d..4307e008 100644
--- a/inc/graph/utils/node_utils.h
+++ b/inc/graph/utils/node_utils.h
@@ -100,6 +100,13 @@ class NodeUtils {
   static NodePtr GetParentInput(const NodePtr &node);
 
   ///
+  /// @brief Check is varying_input for while node
+  /// @param [in] node: Data node for subgraph
+  /// @return bool
+  ///
+  static bool IsWhileVaryingInput(const ge::NodePtr &node);
+
+  ///
   /// @brief Get subgraph input is constant.
   /// @param [in] node
   /// @param [out] string
@@ -114,6 +121,24 @@
   ///
   static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);
 
+  ///
+  /// @brief Get subgraph input data node by index.
+  /// @param [in] node
+  /// @return Node
+  ///
+  static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index);
+
+  ///
+  /// @brief Get subgraph output nodes.
+  /// @param [in] node
+  /// @return Node
+  ///
+  static vector<NodePtr> GetSubgraphOutputNodes(const Node &node);
+
+  static NodePtr GetInDataNodeByIndex(const Node &node, int index);
+
+  static vector<NodePtr> GetOutDataNodesByIndex(const Node &node, int index);
+
 private:
  static std::map> map_send_info_;
  static std::map> map_recv_info_;
diff --git a/inc/graph/utils/tensor_adapter.h b/inc/graph/utils/tensor_adapter.h
index f9993606..a7355553 100644
--- a/inc/graph/utils/tensor_adapter.h
+++ b/inc/graph/utils/tensor_adapter.h
@@ -20,6 +20,7 @@
 #include
 #include "graph/ge_tensor.h"
 #include "graph/tensor.h"
+
 namespace ge {
 using GeTensorPtr = std::shared_ptr;
 using ConstGeTensorPtr = std::shared_ptr;
diff --git a/inc/graph/utils/tensor_utils.h b/inc/graph/utils/tensor_utils.h
index 2fa398db..caa80dcf 100644
--- a/inc/graph/utils/tensor_utils.h
+++ b/inc/graph/utils/tensor_utils.h
@@ -21,6 +21,7 @@
 #include "graph/def_types.h"
 #include "graph/ge_error_codes.h"
 #include "graph/ge_tensor.h"
+
 namespace ge {
 class TensorUtils {
  public:
diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt
index 43f5b597..f041e4b6 100755
--- a/src/common/graph/CMakeLists.txt
+++ b/src/common/graph/CMakeLists.txt
@@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE
     ${PROTOBUF_LIBRARY}
     ${c_sec}
     ${slog}
+    ${error_manager}
     rt
     dl)
diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc
index b73cf939..8a0c9f06 100644
--- a/src/common/graph/compute_graph.cc
+++ b/src/common/graph/compute_graph.cc
@@ -106,6 +106,15 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<NodePtr> &all_nodes) const {
   return Vistor<NodePtr>(shared_from_this(), all_nodes);
 }
 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes(
+    bool is_unknown_shape) const {
+  if (is_unknown_shape) {
+    return GetDirectNode();
+  } else {
+    return GetAllNodes();
+  }
+}
+
 size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }
 
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const {
@@ -497,6 +506,10 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) {
   if (name != subgraph->GetName()) {
     GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str());
   }
+  if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) {
+    GE_LOGE("The subgraph %s already exists", name.c_str());
+    return GRAPH_PARAM_INVALID;
+  }
   sub_graph_.push_back(subgraph);
   names_to_subgraph_[name] = subgraph;
   return GRAPH_SUCCESS;
diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h
index da36f72c..f11ef31e 100644
--- a/src/common/graph/debug/ge_op_types.h
+++ b/src/common/graph/debug/ge_op_types.h
@@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims");
 GE_REGISTER_OPTYPE(SWITCH, "Switch");
 GE_REGISTER_OPTYPE(MERGE, "Merge");
 GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge");
+GE_REGISTER_OPTYPE(ENTER, "Enter");
+GE_REGISTER_OPTYPE(REFENTER, "RefEnter");
 GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration");
 GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration");
 GE_REGISTER_OPTYPE(CONSTANT, "Const");
+GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder");
 GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp");
 GE_REGISTER_OPTYPE(GETNEXT, "GetNext");
 GE_REGISTER_OPTYPE(INITDATA, "InitData");
+GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity");
 
 GE_REGISTER_OPTYPE(ANN_DATA, "AnnData");
 GE_REGISTER_OPTYPE(CONSTANTOP, "Constant");
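A sketch of how the new NodeUtils/GraphUtils helpers compose (illustrative only; the node names and the graph are assumptions, not from the patch):

    // Hypothetical: clone a node's incoming data edges onto a replacement node.
    ge::NodePtr src = graph->FindNode("conv1");       // assumed to exist
    ge::NodePtr dst = graph->FindNode("conv1_copy");  // assumed to exist
    if (ge::GraphUtils::CopyInDataEdges(src, dst) != ge::GRAPH_SUCCESS) {
      GELOGE(ge::GRAPH_FAILED, "copy in-data edges failed");
    }
    // Walk upstream through the replacement node's first input.
    ge::NodePtr producer = ge::NodeUtils::GetInDataNodeByIndex(*dst, 0);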
diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc
index 11a610ce..9cb76539 100644
--- a/src/common/graph/format_refiner.cc
+++ b/src/common/graph/format_refiner.cc
@@ -41,11 +41,9 @@ using namespace ge;
 using namespace std;
 namespace ge {
 namespace {
-static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
-static bool net_format_is_nd = true;
-static Format g_user_set_format = FORMAT_ND;
-static bool is_first_infer = true;
-static RefRelations reflection_builder;
+const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
+const string kIsGraphInferred = "_is_graph_inferred";
+RefRelations reflection_builder;
 }  // namespace
 
 graphStatus ReflectionProcess(const std::unordered_set &reflection,
@@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set &reflection,
   return GRAPH_SUCCESS;
 }
 
-graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) {
+graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) {
+  // 5 means dim num
+  if (node_ptr->GetType() != "BiasAdd") {
+    return GRAPH_SUCCESS;
+  }
+  std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}};
+  for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) {
+    auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i);
+    GE_CHECK_NOTNULL(in_desc);
+    if (in_desc->MutableShape().GetDimNum() != 5) {  // 5 means dim num
+      continue;
+    }
+    auto format = in_desc->GetOriginFormat();
+    auto key = TypeUtils::FormatToSerialString(format);
+    auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
+    in_desc->SetOriginFormat(fixed_format);
+    in_desc->SetFormat(fixed_format);
+    GELOGD("fix the %zu'th input of node[%s]. Origin format is %s, after fix it is %s", i,
+           node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
+           TypeUtils::FormatToSerialString(fixed_format).c_str());
+  }
+  for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) {
+    auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i);
+    GE_CHECK_NOTNULL(out_desc);
+    if (out_desc->MutableShape().GetDimNum() != 5) {  // 5 means dim num
+      continue;
+    }
+    auto format = out_desc->GetOriginFormat();
+    auto key = TypeUtils::FormatToSerialString(format);
+    auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key];
+    out_desc->SetOriginFormat(fixed_format);
+    out_desc->SetFormat(fixed_format);
+    GELOGD("fix the %zu'th output of node[%s]. Origin format is %s, after fix it is %s", i,
+           node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(),
+           TypeUtils::FormatToSerialString(fixed_format).c_str());
+  }
+  return GRAPH_SUCCESS;
+}
+
+graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) {
+  GE_CHECK_NOTNULL(graph);
   GE_CHECK_NOTNULL(op_desc);
-  if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) {
+  if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) {
     ConstGeTensorPtr tensor_value;
     if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) {
       GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str());
@@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std
   }
   anchor_points.clear();
   // Get all anchor point nodes and switch nodes
-  for (const auto &node_ptr : graph->GetAllNodes()) {
+  for (auto &node_ptr : graph->GetAllNodes()) {
     if (node_ptr == nullptr) {
       return GRAPH_FAILED;
     }
@@ -103,7 +141,7 @@
     if (op_desc == nullptr) {
       return GRAPH_FAILED;
     }
-    graphStatus status = RefreshConstantOutProcess(op_desc);
+    graphStatus status = RefreshConstantOutProcess(graph, op_desc);
     if (status != GRAPH_SUCCESS) {
       GELOGE(GRAPH_FAILED, "refresh constant out process failed!");
       return GRAPH_FAILED;
@@ -135,6 +173,16 @@
     if (!node_is_all_nd) {
       continue;
     }
+    // special process for BiasAdd op
+    // In TensorFlow, BiasAdd's format is always NHWC even if the arg
+    // "data_format" is set to NDHWC or NCDHW. That would break our
+    // format-infer mechanism, so we apply a special fix here.
+    status = BiasAddFormatFixProcess(node_ptr);
+    if (status != GRAPH_SUCCESS) {
+      GELOGE(GRAPH_FAILED, "fix biasAdd process failed!");
+      return GRAPH_FAILED;
+    }
+
     GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str());
     anchor_points.push_back(node_ptr);
   }
@@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor
   }
 }
 
-void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; }
-
-graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format,
+graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes,
+                                                 ge::Format data_format,
                                                  std::unordered_map &node_status) {
-  bool is_internal_format = TypeUtils::IsInternalFormat(data_format);
-  bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND);
-  if (!need_process) {
-    GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer,
+  if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) {
+    GELOGI("no need to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph),
            TypeUtils::FormatToSerialString(data_format).c_str());
     return GRAPH_SUCCESS;
   }
@@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
   std::vector<ge::NodePtr> anchor_points;
   std::vector<ge::NodePtr> data_nodes;
   // global net format
-  net_format_is_nd = true;
-  g_user_set_format = FORMAT_ND;
 
   if (graph == nullptr) {
     GELOGE(GRAPH_FAILED, "input graph is null");
@@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph)
   /// format for these data nodes.
   /// Notice: ignore 5D formats
   auto data_format = graph->GetDataFormat();
-  status = DataNodeFormatProcess(data_nodes, data_format, node_status);
-  // Set infer flag to false
-  SetInferOrigineFormatFlag(false);
+  status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status);
+
+  (void)AttrUtils::SetBool(graph, kIsGraphInferred, true);
   return status;
 }
+
+bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) {
+  bool is_graph_inferred = false;
+  return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred);
+}
 }  // namespace ge
diff --git a/src/common/graph/format_refiner.h b/src/common/graph/format_refiner.h
index fa40a034..eca93bae 100644
--- a/src/common/graph/format_refiner.h
+++ b/src/common/graph/format_refiner.h
@@ -30,10 +30,9 @@ namespace ge {
 class FormatRefiner {
  public:
   static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph);
-  static void SetInferOrigineFormatFlag(bool is_first = true);
 
 private:
-  static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc);
+  static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc);
   static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points,
                                      std::vector &data_nodes,
                                      std::unordered_map &node_status);
@@ -43,8 +42,9 @@ class FormatRefiner {
                                            std::unordered_map &node_status);
   static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node,
                                          std::unordered_map &node_status);
-  static graphStatus DataNodeFormatProcess(std::vector &data_nodes, ge::Format data_format,
-                                           std::unordered_map &node_status);
+  static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes,
+                                           ge::Format data_format, std::unordered_map &node_status);
+  static bool IsGraphInferred(const ComputeGraphPtr &graph);
 };
 }  // namespace ge
 #endif  // COMMON_GRAPH_FORMAT_REFINER_H_
diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc
index f780d525..90f1bc6a 100644
--- a/src/common/graph/ge_attr_define.cc
+++ b/src/common/graph/ge_attr_define.cc
@@ -121,6 +121,8 @@ const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp";
 const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs";
 const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs";
 
+const std::string ATTR_NAME_INPUT_DIMS = "input_dims";
+
 const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id";
 const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name";
@@ -723,6 +725,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name";
 
 const std::string ATTR_MODEL_CORE_TYPE = "core_type";
 
+const std::string ATTR_MODEL_ATC_VERSION = "atc_version";
+
+const std::string ATTR_MODEL_OPP_VERSION = "opp_version";
+
 // Public attribute
 const std::string ATTR_NAME_IMPLY_TYPE = "imply_type";
@@ -932,7 +938,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace";
 
 const std::string MODEL_ATTR_SESSION_ID = "session_id";
 
-// l1 fusion and other fusion in future
+// lx fusion
 const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id";
 const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key";
 const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
@@ -946,9 +952,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1
 const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion";
 const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split";
 const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed";
+const std::string ATTR_DATA_DUMP_REF = "_datadump_ref";
 const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion";
 const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id";
 const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion";
+const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag";
+const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr";
+const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size";
+
+// Op debug attrs
+const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag";
+const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode";
 
 // Atomic addr clean attrs
 const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index";
@@ -1013,4 +1027,11 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op";
 // used for allreduce tailing optimization
 const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group";
 const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node";
+
+// dynamic shape attr
+const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr";
+const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index";
+
+// for fusion op plugin
+const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";
 }  // namespace ge
diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc
index 8ffbba91..196b8569 100644
--- a/src/common/graph/ge_tensor.cc
+++ b/src/common/graph/ge_tensor.cc
@@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape";
 const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format";
 const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type";
 const string TENSOR_UTILS_SHAPE_RANGE = "shape_range";
+const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index";
 
 GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {}
 
@@ -567,6 +568,16 @@
   return TypeUtils::SerialStringToDataType(origin_data_type_str);
 }
 
+std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const {
+  vector<uint32_t> ref_port_index;
+  (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index);
+  return ref_port_index;
+}
+
+void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) {
+  (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index);
+}
+
 graphStatus GeTensorDesc::IsValid() const {
   auto dtype = this->GetDataType();
   auto format = this->GetFormat();
diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc
index 09d4fd56..fc30e9d6 100644
--- a/src/common/graph/graph.cc
+++ b/src/common/graph/graph.cc
@@ -210,7 +210,7 @@ class GraphImpl {
   graphStatus FindOpByName(const string &name, ge::Operator &op) const {
     auto it = op_list_.find(name);
-    GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str());
+    GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str());
     op = it->second;
     return GRAPH_SUCCESS;
   }
diff --git a/src/common/graph/graph.mk b/src/common/graph/graph.mk
index 744d1725..14e8b4b1 100644
--- a/src/common/graph/graph.mk
+++ b/src/common/graph/graph.mk
@@ -1,5 +1,5 @@
 LOCAL_PATH := $(call my-dir)
-
+include $(LOCAL_PATH)/stub/Makefile
 COMMON_LOCAL_SRC_FILES := \
     ./proto/om.proto \
     ./proto/ge_ir.proto \
@@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \
     libc_sec \
     libprotobuf \
     libslog \
+    liberror_manager \
LOCAL_LDFLAGS := -lrt -ldl @@ -85,6 +86,54 @@ LOCAL_PROPRIETARY_MODULE := true include $(BUILD_HOST_SHARED_LIBRARY) +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) #compiler for device include $(CLEAR_VARS) @@ -99,6 +148,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -111,6 +161,60 @@ LOCAL_PROPRIETARY_MODULE := true include $(BUILD_SHARED_LIBRARY) +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) # compile for ut/st include $(CLEAR_VARS) @@ -125,6 +229,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -150,6 +255,7 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -173,6 +279,7 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc index 19cb4538..4bd5769f 100644 --- a/src/common/graph/model_serialize.cc +++ b/src/common/graph/model_serialize.cc @@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ } bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, 
proto::OpDef *op_def_proto, bool is_dump) { - if (op_desc == nullptr || op_def_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input Para Invalid"); - return false; - } + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); if (op_desc->op_def_.GetProtoMsg() != nullptr) { *op_def_proto = *op_desc->op_def_.GetProtoMsg(); // Delete unnecessary attr @@ -130,16 +128,17 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { op_def_proto->add_subgraph_name(name); } - - proto::AttrDef key; - proto::AttrDef value; - for (auto &item : op_desc->output_name_idx_) { - key.mutable_list()->add_s(item.first); - value.mutable_list()->add_i(item.second); + if (!op_desc->output_name_idx_.empty()) { + proto::AttrDef key; + proto::AttrDef value; + for (auto &item : op_desc->output_name_idx_) { + key.mutable_list()->add_s(item.first); + value.mutable_list()->add_i(item.second); + } + auto op_desc_attr = op_def_proto->mutable_attr(); + op_desc_attr->insert({"_output_name_key", key}); + op_desc_attr->insert({"_output_name_value", value}); } - auto op_desc_attr = op_def_proto->mutable_attr(); - op_desc_attr->insert({"_output_name_key", key}); - op_desc_attr->insert({"_output_name_value", value}); } return true; } diff --git a/src/common/graph/node.cc b/src/common/graph/node.cc index e0939e7e..df8efd91 100644 --- a/src/common/graph/node.cc +++ b/src/common/graph/node.cc @@ -26,6 +26,7 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "utils/op_desc_utils.h" +#include "common/util/error_manager/error_manager.h" using std::string; using std::vector; @@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); if (peer_node == nullptr || r_peer_node == nullptr) { - GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", + GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. 
", this->GetName().c_str(), i, j); return false; } @@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::Get GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { return in_data_anchors_[idx]; @@ -445,7 +449,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { - GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()}); + GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); return nullptr; } else { // Return control anchor @@ -461,8 +468,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, + { + GetName().c_str(), + std::to_string(idx), + "out_anchor", + GetType().c_str(), + }); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { // Return control anchor @@ -477,8 +491,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { return out_data_anchors_[idx]; @@ -733,11 +750,15 @@ graphStatus Node::Verify() const { GELOGW("in anchor ptr is null"); continue; } - GE_CHK_BOOL_RET_STATUS( - op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || - op_->GetType() == variable_type || 
op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || - in_anchor_ptr->GetPeerAnchors().size() > 0, - GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); + bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || + op_->GetType() == const_type || op_->GetType() == variable_type || + op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; + if (!valid_anchor) { + ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"}, + {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); + GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); + return GRAPH_FAILED; + } } string frameworkop_type = "FrameworkOp"; diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index adb52162..e9436a32 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -19,6 +19,7 @@ #include "debug/ge_util.h" #include "external/graph/operator.h" #include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" #include "graph/ge_attr_value.h" #include "graph/ge_tensor.h" #include "graph/operator_factory_impl.h" @@ -470,6 +471,25 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const { return *(inputs_desc_[it->second].get()); } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); + return inputs_desc_[index]; +} + +GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { + auto input_name_idx = GetAllInputName(); + auto it = input_name_idx.find(name); + if (it == input_name_idx.end()) { + GELOGW("Failed to get [%s] input desc", name.c_str()); + return nullptr; + } + return MutableInputDesc(it->second); +} + GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { auto input_name_idx = GetAllInputName(); vector names; @@ -496,15 +516,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); - return inputs_desc_[index]; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { vector temp{}; for (const auto &it : inputs_desc_) { @@ -609,6 +620,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu return outputs_desc_[index]; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Failed to get [%s] output desc", name.c_str()); + return nullptr; + } + return MutableOutputDesc(it->second); +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t 
OpDesc::GetAllOutputsDescSize() const { return static_cast(outputs_desc_.size()); } @@ -882,15 +902,22 @@ graphStatus OpDesc::CommonVerify() const { // Checking shape of all inputs vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); for (int64_t dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", - iname.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); + return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), + iname.c_str()); } } // Check all attributes defined const auto &all_attributes = GetAllAttrs(); for (const auto &name : GetAllAttrNames()) { - GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, - "operator attribute %s is empty.", name.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + all_attributes.find(name) == all_attributes.end(), + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {GetName(), "attribute " + name, "is empty"}); + return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); } return GRAPH_SUCCESS; diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index 1ac8d41d..3a9fd698 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -36,6 +36,8 @@ #include "graph/op_desc.h" #include "graph/runtime_inference_context.h" #include "graph/usr_types.h" +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_attr_define.h" #include "utils/graph_utils.h" #include "utils/op_desc_utils.h" #include "utils/tensor_adapter.h" @@ -57,8 +59,7 @@ using std::vector; namespace ge { class OpIO { public: - explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) - : name_(name), index_(index), owner_(owner) {} + OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} ~OpIO() = default; @@ -546,56 +547,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { } graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return GRAPH_FAILED; - } - ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); - if (node_ptr) { + GE_CHECK_NOTNULL(operator_impl_); + auto node_ptr = operator_impl_->GetNode(); + if (node_ptr != nullptr) { // For inner compute graph auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is nullptr."); - return GRAPH_FAILED; - } + GE_CHECK_NOTNULL(op_desc); auto index = op_desc->GetInputIndexByName(dst_name); auto in_data_anchor = node_ptr->GetInDataAnchor(index); - if (in_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); - return GRAPH_FAILED; - } + GE_CHECK_NOTNULL(in_data_anchor); auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (out_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); - return GRAPH_FAILED; - } - std::shared_ptr peer_node_ptr = out_data_anchor->GetOwnerNode(); - if (peer_node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); - return GRAPH_FAILED; - } - ge::OperatorImplPtr operator_impl_ptr = nullptr; - operator_impl_ptr = ComGraphMakeShared(peer_node_ptr); - if 
(operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return GRAPH_FAILED; - } - Operator const_op(std::move(operator_impl_ptr)); - if (peer_node_ptr->GetOpDesc() != nullptr) { - const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); - if (op_descType == CONSTANTOP) { - return const_op.GetAttr(op::Constant::name_attr_value(), data); - } else if (op_descType == CONSTANT) { - return const_op.GetAttr(op::Const::name_attr_value(), data); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + auto const_op_impl = ComGraphMakeShared(peer_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) && + ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + auto const_op_impl = ComGraphMakeShared(parent_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); } } - // Try get from runtime inference context auto session_id = std::to_string(GetContext().SessionId()); RuntimeInferenceContext *runtime_infer_ctx = nullptr; if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { GELOGD("To get constant from runtime inference context. 
session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); if (ret == GRAPH_SUCCESS) { return GRAPH_SUCCESS; } @@ -604,6 +595,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co // For outer graph return GetInputConstDataOut(dst_name, data); } + auto op_name = operator_impl_->GetName(); + GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); return GRAPH_FAILED; } graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc index f5ebdeee..f5f5e4c9 100644 --- a/src/common/graph/option/ge_context.cc +++ b/src/common/graph/option/ge_context.cc @@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; } uint64_t GEContext::TraceId() { return trace_id_; } +void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } + void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } } // namespace ge diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc index b3cf37af..906cb5f9 100644 --- a/src/common/graph/ref_relation.cc +++ b/src/common/graph/ref_relation.cc @@ -242,6 +242,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r int sub_graph_idx = 0; for (const auto &name : sub_graph_names) { auto sub_graph = root_graph.GetSubgraph(name); + if (sub_graph == nullptr) { + GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); + continue; + } for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { auto sub_graph_node_type = sub_graph_node->GetType(); diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index edf426a5..dc1bc541 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -37,6 +37,115 @@ namespace ge { namespace { +const uint32_t kWhileBodySubGraphIdx = 1; + +graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { + GELOGD("Enter reverse brush while body subgraph process!"); + + auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); + if (sub_graph_body == nullptr) { + GELOGE(GRAPH_FAILED, "Get while body graph failed!"); + return GRAPH_FAILED; + } + + for (const auto &node_sub : sub_graph_body->GetAllNodes()) { + if (node_sub->GetInDataNodes().size() == 0) { + continue; + } + + for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { + auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); + (void)input_desc->SetUnknownDimNumShape(); + } + for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { + auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); + (void)output_desc->SetUnknownDimNumShape(); + } + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class branch op process"); + // check sub_graph shape.If not same ,do unknown shape process + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (auto 
&tensor : ref_out_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class while op process"); + if (ref_data_tensors.size() != ref_out_tensors.size()) { + GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), + ref_data_tensors.size(), ref_out_tensors.size()); + return GRAPH_FAILED; + } + for (size_t i = 0; i < ref_data_tensors.size(); i++) { + if (ref_out_tensors[i].size() != 1) { + GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); + return GRAPH_FAILED; + } + } + bool is_need_reverse_brush = false; + // check input and output + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + auto tmp_shape = ref_out_tensor.MutableShape(); + // ref_i's data and output tensor shape should be same + for (auto &tensor : ref_data_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims() != tmp_shape.GetDims()) { + ref_out_tensor.SetUnknownDimNumShape(); + is_need_reverse_brush = true; + break; + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + // reverse refresh while body shape + if (is_need_reverse_brush) { + return ReverseBrushWhileBodySubGraph(node); + } + return GRAPH_SUCCESS; +} + graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -98,6 +207,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { } return GRAPH_SUCCESS; } + +graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, + const ConstNodePtr &node, + std::vector> &ref_data_tensors) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + } + if (sub_node->GetType() == DATA) { + if (sub_node->GetOpDesc() == nullptr) { + return GRAPH_FAILED; + } + + int ref_i; + if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GELOGE(GRAPH_FAILED, 
"subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + return GRAPH_FAILED; + } + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { + GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), + ref_i, node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); + } + } + return GRAPH_SUCCESS; +} + graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -105,7 +245,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_SUCCESS; } + std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); + std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { if (name.empty()) { GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); @@ -117,13 +260,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_FAILED; } NodePtr netoutput = nullptr; - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0; --i) { - auto sub_node = sub_nodes.at(i - 1); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - break; - } + auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); + if (ret != GRAPH_SUCCESS) { + return ret; } if (netoutput == nullptr) { GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); @@ -150,19 +289,17 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { continue; } GELOGI("Parent node index of edge desc is %d", ref_i); - auto output_desc = op_desc->MutableOutputDesc(static_cast(ref_i)); - if (output_desc == nullptr) { - GE_LOGE( - "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " - "parent node %s are incompatible, outputs num %u", - ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), - node->GetAllOutDataAnchorsSize()); + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { return GRAPH_FAILED; } - op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); + ref_out_tensors[ref_i].emplace_back(*edge_desc); } } - return GRAPH_SUCCESS; + + if (node->GetType() == WHILE) { + return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); + } + return UpdateParentNodeForBranch(node, ref_out_tensors); } } // namespace void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { @@ -170,6 +307,9 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str GELOGE(GRAPH_FAILED, "node is null"); return; } + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } ge::OpDescPtr op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); std::string str; diff --git a/src/common/graph/utils/ge_ir_utils.h b/src/common/graph/utils/ge_ir_utils.h index 9b16be18..b572ab38 100644 --- a/src/common/graph/utils/ge_ir_utils.h +++ b/src/common/graph/utils/ge_ir_utils.h @@ -1,18 +1,18 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * + * Copyright 2020 Huawei Technologies Co., Ltd + * Licensed under the Apache License, Version 2.0 (the "License"); * you may not 
use this file except in compliance with the License. * You may obtain a copy of the License at - * + * http://www.apache.org/licenses/LICENSE-2.0 - * + * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - */ +*/ #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc index ca2ebcdc..a6980358 100644 --- a/src/common/graph/utils/graph_utils.cc +++ b/src/common/graph/utils/graph_utils.cc @@ -38,6 +38,7 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "debug/ge_op_types.h" +#include "external/ge/ge_api_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra /// @return graphStatus /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { +GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { GE_CHECK_NOTNULL(src); GE_CHECK_NOTNULL(insert_node); @@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption("ge.maxDumpFileNum", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { @@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText if (maxDumpFileSize == 0) { string opt = "0"; // Can not check return value - (void)GetContext().GetOption("ge.maxDumpFileSize", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); maxDumpFileSize = atol(opt.c_str()); } if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { @@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption("ge.maxDumpFileNum", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { @@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, In InNodesToOut GetFullConnectIONodes(const NodePtr &node) { InNodesToOut in_nodes_to_out; if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "Node is nullptr"); return in_nodes_to_out; } auto in_nodes_list = node->GetInNodes(); @@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt return GRAPH_SUCCESS; } +/// +/// Copy all in-data edges from `src_node` to `dst_node`. 
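/// Each in-data anchor of src_node with a peer out-anchor has that same peer
/// connected to the anchor of identical index on dst_node, so dst_node must
/// provide an in-data anchor for every index that src_node uses.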
+/// @param src_node +/// @param dst_node +/// @return +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_data_in_nodes = src_node->GetInDataNodes(); + if (src_data_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { + auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + auto ret = + GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { if (graph->AddInputNode(node) == nullptr) { @@ -1339,7 +1370,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(graph); - for (auto &node : graph->GetAllNodes()) { + for (const auto &node : graph->GetAllNodes()) { // in_data_anchor if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); @@ -1396,16 +1427,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); } - std::string type = node->GetType(); + const std::string &type = node->GetType(); if ((type == MERGE) || (type == STREAMMERGE)) { return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); } - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { - std::string symbol = cur_node_info.ToString(); + const std::string &symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors[symbol] = {cur_node_info}; anchor_to_symbol[symbol] = symbol; @@ -1432,7 +1463,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(node); - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { continue; @@ -1446,7 +1477,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, return GRAPH_FAILED; } } else { - std::string symbol = cur_node_info.ToString(); + const std::string &symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); 
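// symbol_to_anchors and anchor_to_symbol are maintained as a mirrored pair:
// one maps a symbol to every anchor that shares its memory, the other maps a
// single anchor back to its symbol, so the emplace above and the emplace
// below must always be updated together.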
anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); @@ -1506,7 +1537,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, GE_CHECK_NOTNULL(node); std::vector exist_node_infos; std::vector cur_node_infos; - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { std::string next_name; @@ -1529,10 +1560,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, size_t anchor_nums = 0; NodeIndexIO max_node_index_io(nullptr, 0, kOut); - for (auto &temp_node_info : exist_node_infos) { + for (const auto &temp_node_info : exist_node_infos) { auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); if (iter1 != anchor_to_symbol.end()) { - std::string temp_symbol = iter1->second; + const std::string &temp_symbol = iter1->second; auto iter2 = symbol_to_anchors.find(temp_symbol); if (iter2 != symbol_to_anchors.end()) { if (iter2->second.size() > anchor_nums) { @@ -1544,7 +1575,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, } std::string symbol; - for (auto &temp_node_info : exist_node_infos) { + for (const auto &temp_node_info : exist_node_infos) { if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != GRAPH_SUCCESS) || symbol.empty()) { @@ -1556,7 +1587,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, auto iter = symbol_to_anchors.find(symbol); if (iter != symbol_to_anchors.end()) { - for (auto &temp_node_info : cur_node_infos) { + for (const auto &temp_node_info : cur_node_infos) { GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); iter->second.emplace_back(temp_node_info); anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); @@ -1584,7 +1615,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -1627,8 +1658,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, std::map> &symbol_to_anchors, std::map &anchor_to_symbol, std::string &symbol) { - std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; if (symbol1 == symbol2) { symbol = symbol1; GELOGI("no need to union."); @@ -1684,7 +1715,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const return GRAPH_FAILED; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = symbol_to_anchors.find(symbol); if (iter2 == symbol_to_anchors.end()) { GE_LOGE("symbol %s not found.", symbol.c_str()); @@ -1712,7 +1743,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t // pass-through op NodePtr node = out_data_anchor->GetOwnerNode(); - std::string type = node->GetType(); + const std::string &type = node->GetType(); const std::set 
pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { reuse_in_index = output_index; @@ -1755,7 +1786,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t uint32_t reuse_input_index = 0; if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, + GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, reuse_in_index); return true; } @@ -2297,7 +2328,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & return; } - std::string name = node->GetName() + "_RetVal"; + std::string name = node->GetName() + "_RetVal_" + std::to_string(index); OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); if (ret_val_desc == nullptr) { error_code = GRAPH_FAILED; diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc index e4fb8b82..20bcacfb 100644 --- a/src/common/graph/utils/node_utils.cc +++ b/src/common/graph/utils/node_utils.cc @@ -296,15 +296,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer return GRAPH_FAILED; } for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast(output_tensor.GetShape().GetDims().size())); - output_tensor.SetOriginShape(output_tensor.GetShape()); - output_tensor.SetOriginDataType(output_tensor.GetDataType()); + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (!is_unknown_graph) { + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + } GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); - (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); + node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); @@ -316,17 +319,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer continue; } GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), - output_tensor.GetDataType(), output_tensor.GetOriginDataType()); - peer_input_desc->SetShape(output_tensor.GetShape()); - peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); - 
peer_input_desc->SetDataType(output_tensor.GetDataType()); - peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), + output_tensor->GetDataType(), output_tensor->GetOriginDataType()); + peer_input_desc->SetShape(output_tensor->GetShape()); + peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); + peer_input_desc->SetDataType(output_tensor->GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); std::vector> shape_range; - (void)output_tensor.GetShapeRange(shape_range); + (void)output_tensor->GetShapeRange(shape_range); peer_input_desc->SetShapeRange(shape_range); ge::TensorUtils::SetRealDimCnt(*peer_input_desc, - static_cast(output_tensor.GetShape().GetDims().size())); + static_cast(output_tensor->GetShape().GetDims().size())); GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); @@ -401,10 +404,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { auto desc = node.GetOpDesc(); GE_CHECK_NOTNULL(desc); - + // check self + is_unknow = OpShapeIsUnknown(desc); + if (is_unknow) { + return GRAPH_SUCCESS; + } auto sub_graph_names = desc->GetSubgraphInstanceNames(); if (sub_graph_names.empty()) { - is_unknow = OpShapeIsUnknown(desc); return GRAPH_SUCCESS; } else { auto owner_graph = node.GetOwnerComputeGraph(); @@ -556,6 +562,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { } /// +/// @brief Check is varying_input for while node +/// @param [in] node: Data node for subgraph +/// @return bool +/// +bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->GetType() != DATA) { + return false; // not input_node for subgraph + } + + const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; // root graph + } + + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + return false; // not input_node for while subgraph + } + + uint32_t index_i = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { + GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); + return false; + } + bool varying_flag = true; + for (const auto &item : node->GetOutDataNodesAndAnchors()) { + if (item.first->GetType() != NETOUTPUT) { + continue; + } + OpDescPtr op_desc = item.first->GetOpDesc(); + uint32_t index_o = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { + continue; // input for while-cond subgraph + } + if (index_i != index_o) { + continue; // varying input for while-body subgraph + } + varying_flag = false; + break; + } + return varying_flag; +} + +/// /// @brief Get subgraph input is constant. /// @param [in] node /// @param [out] string @@ -637,4 +690,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { return GRAPH_SUCCESS; } +/// +/// @brief Get subgraph input data node by index. 
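/// Collected from every subgraph instance of the node: the Data nodes whose
/// ATTR_NAME_PARENT_NODE_INDEX attribute equals the requested index.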
+/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { + vector in_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); + return in_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + int parent_index = 0; + if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { + (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); + if (parent_index == index) { + in_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + } + return in_data_node_vec; +} +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { + vector out_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); + return out_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { + out_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + return out_data_node_vec; +} + +NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { + if (node.GetInDataAnchor(index) == nullptr) { + return nullptr; + } + if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { + return nullptr; + } + return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); +} + +vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { + vector out_data_nodes; + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return out_data_nodes; + } + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + if (peer_in_anchor == nullptr) { + continue; + } + if (peer_in_anchor->GetOwnerNode() == nullptr) { + continue; + } + out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); + } + return out_data_nodes; +} } // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc index 6264ddb9..c5de264f 100644 --- a/src/common/graph/utils/op_desc_utils.cc +++ b/src/common/graph/utils/op_desc_utils.cc @@ -197,24 +197,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: continue; } auto in_node = out_anchor->GetOwnerNode(); - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - ret.push_back(in_node); - } else if (in_node->GetType() == DATA) { - const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); - GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); - - const NodePtr &parent_node = graph->GetParentNode(); - if (parent_node == nullptr) { - continue; // Root graph. 
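// The branch being removed here is superseded by the upward walk in the
// + block that follows: DATA inputs are traced to their parent-graph input
// via NodeUtils::GetParentInput (stopping at varying while-loop inputs), and
// ENTER/REFENTER nodes carrying ENTER_ATTR_CONSTANT_FLAG are traced through
// their single data input, until a CONSTANT or CONSTANTOP is reached or the
// chain ends.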
- } - - if (kWhileOpTypes.count(parent_node->GetType()) > 0) { - continue; // Subgraph of While cond or body. + while (true) { + if (in_node == nullptr) { + break; } - - NodePtr input_node = NodeUtils::GetParentInput(in_node); - if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { - ret.push_back(input_node); + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + ret.push_back(in_node); + break; + } else if (in_node->GetType() == DATA) { + if (NodeUtils::IsWhileVaryingInput(in_node)) { + break; + } + in_node = NodeUtils::GetParentInput(in_node); + } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { + bool is_constant = false; + (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); + if (!is_constant) { + break; + } + // Enter node has and only has one input + if (in_node->GetInDataNodes().size() != 1) { + GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), + in_node->GetInDataNodes().size()); + break; + } + in_node = in_node->GetInDataNodes().at(0); + } else { + break; } } } @@ -435,10 +444,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { vector ret; - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); + auto op_desc = node.GetOpDesc(); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); + // Place holder operator, try to get the weight from parent node + // when parent node is const operator + if (node.GetType() == PLACEHOLDER) { + std::string parent_op; + (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); + // This if judgment is necessary because the current subgraph optimization is multithreaded + // and the parent node of the PLD operation should be a stable type, such as const + if (parent_op == CONSTANT || parent_op == CONSTANTOP) { + NodePtr parent_node = nullptr; + parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); + if (parent_node != nullptr) { + op_desc = parent_node->GetOpDesc(); + GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); + } + } + } // Const operator, take the weight directly - if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { - auto weight = MutableWeights(node.GetOpDesc()); + if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { + auto weight = MutableWeights(op_desc); if (weight == nullptr) { GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); return ret; diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc index 674cab55..26ac8cc8 100644 --- a/src/common/graph/utils/tensor_utils.cc +++ b/src/common/graph/utils/tensor_utils.cc @@ -19,6 +19,7 @@ #include "debug/ge_log.h" #include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" #include "graph/ge_tensor.h" #include "graph/types.h" #include "graph/utils/type_utils.h" @@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector &dims, int64_ element_cnt = 1; for (int64_t dim : dims) { if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); + ErrorManager::GetInstance().ATCReportErrMessage( 
+ "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); + GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); return GRAPH_FAILED; } element_cnt *= dim; @@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_FRACTAL_Z: graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); break; - case FORMAT_NC1HWC0_C04: case FORMAT_FRACTAL_NZ: case FORMAT_FRACTAL_ZZ: case FORMAT_NDHWC: @@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_NDC1HWC0: case FORMAT_FRACTAL_Z_C04: case FORMAT_FRACTAL_ZN_LSTM: + case FORMAT_NC1HWC0_C04: graph_status = CalcElementCntByDims(dims, element_cnt); break; default: diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index e4986931..5215b141 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -147,7 +147,8 @@ static const std::map kStringToFormatMap = { {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, {"FORMAT_RESERVED", FORMAT_RESERVED}, - {"ALL", FORMAT_ALL}}; + {"ALL", FORMAT_ALL}, + {"NULL", FORMAT_NULL}}; static const std::map kDataTypeToStringMap = { {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index 894eaf1e..8d20caf2 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -94,7 +95,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/output/output.cc" "graph/manager/*.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" @@ -159,8 +159,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_node_executor.cc" "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/hccl/hccl_node_executor.cc" "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" "hybrid/node_executor/node_executor.cc" + "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" "hybrid/node_executor/task_context.cc" "init/gelib.cc" "model/ge_model.cc" @@ -204,6 +207,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -236,7 +240,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/output/output.cc" "graph/manager/*.cc" 
"graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index ae6a9892..120c144a 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -28,6 +28,7 @@ #include "graph/opsproto_manager.h" #include "graph/utils/type_utils.h" #include "graph/manager/util/rt_context_util.h" +#include "graph/common/ge_call_wrapper.h" #include "register/op_registry.h" #include "common/ge/tbe_plugin_manager.h" @@ -41,8 +42,8 @@ namespace { const int32_t kMaxStrLen = 128; } -static bool kGeInitialized = false; -static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use +static bool g_ge_initialized = false; +static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use namespace ge { void GetOpsProtoPath(std::string &opsproto_path) { @@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) { opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } -Status CheckDumpAndReuseMemory(const std::map &options) { - const int kDecimal = 10; - auto dump_op_env = std::getenv("DUMP_OP"); - int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; - auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); - if (disableReuseMemoryIter != options.end()) { - if (disableReuseMemoryIter->second == "0") { - GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); - } - } else if (disableReuseMemoryIter->second == "1") { - GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); - } else { - GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); - return FAILED; - } - } else { - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with default reuse memory"); - } - } - return SUCCESS; -} - Status CheckOptionsValid(const std::map &options) { // check job_id is valid auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); @@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map &options) { } } - // Check ge.exec.disableReuseMemory and env DUMP_OP - if (CheckDumpAndReuseMemory(options) != SUCCESS) { - return FAILED; - } - return SUCCESS; } @@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map &options) { Status GEInitialize(const std::map &options) { GELOGT(TRACE_INIT, "GEInitialize start"); // 0.check init status - if (kGeInitialized) { + if (g_ge_initialized) { GELOGW("GEInitialize is called more than once"); return SUCCESS; } @@ -147,9 +118,9 @@ Status GEInitialize(const std::map &options) { } // 7.check return status, return - if (!kGeInitialized) { + if (!g_ge_initialized) { // Initialize success, first time calling initialize - kGeInitialized = true; + g_ge_initialized = true; } GELOGT(TRACE_STOP, "GEInitialize finished"); @@ -160,12 +131,12 @@ Status GEInitialize(const std::map &options) { Status GEFinalize() { GELOGT(TRACE_INIT, "GEFinalize start"); // check init status - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGW("GEFinalize is called before GEInitialize"); return SUCCESS; } - std::lock_guard lock(kGeReleaseMutex); + std::lock_guard lock(g_ge_release_mutex); // call Finalize Status ret = SUCCESS; Status middle_ret; @@ -187,10 +158,10 @@ Status GEFinalize() { ret = middle_ret; } - if (kGeInitialized && ret == SUCCESS) { + if (g_ge_initialized && ret == SUCCESS) { // Unified destruct rt_context - 
RtContextUtil::GetInstance().DestroyrtContexts(); - kGeInitialized = false; + RtContextUtil::GetInstance().DestroyAllRtContexts(); + g_ge_initialized = false; } GELOGT(TRACE_STOP, "GEFinalize finished"); @@ -202,7 +173,7 @@ Session::Session(const std::map &options) { GELOGT(TRACE_INIT, "Session Constructor start"); // check init status sessionId_ = 0; - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGE(GE_CLI_GE_NOT_INITIALIZED); return; } @@ -232,13 +203,13 @@ Session::Session(const std::map &options) { Session::~Session() { GELOGT(TRACE_INIT, "Session Destructor start"); // 0.check init status - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGW("GE is not yet initialized or is finalized."); return; } Status ret = FAILED; - std::lock_guard lock(kGeReleaseMutex); + std::lock_guard lock(g_ge_release_mutex); try { uint64_t session_id = sessionId_; // call DestroySession diff --git a/src/ge/common/convert/pb2json.cc b/src/ge/common/convert/pb2json.cc index 832a8278..0a5d24ee 100644 --- a/src/ge/common/convert/pb2json.cc +++ b/src/ge/common/convert/pb2json.cc @@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const set &black_fields, Json &json, bool enum2str) { - if (field == nullptr || reflection == nullptr) { - return; - } switch (field->type()) { case ProtobufFieldDescriptor::TYPE_MESSAGE: { const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); @@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr case ProtobufFieldDescriptor::TYPE_FLOAT: char str[kSignificantDigits]; - sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)); - json[field->name()] = str; + if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { + json[field->name()] = str; + } else { + json[field->name()] = reflection->GetFloat(message, field); + } + break; case ProtobufFieldDescriptor::TYPE_STRING: diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.cc b/src/ge/common/formats/format_transfers/datatype_transfer.cc index 0bd4b8e5..08c6889f 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.cc +++ b/src/ge/common/formats/format_transfers/datatype_transfer.cc @@ -29,7 +29,6 @@ namespace ge { namespace formats { - namespace { enum DataTypeTransMode { kTransferWithDatatypeFloatToFloat16, diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.h b/src/ge/common/formats/format_transfers/datatype_transfer.h index 0702592f..4d93fd6c 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.h +++ b/src/ge/common/formats/format_transfers/datatype_transfer.h @@ -27,7 +27,6 @@ namespace ge { namespace formats { - struct CastArgs { const uint8_t *data; size_t src_data_size; diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc index dc8e1033..76d8696a 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc @@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v } REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) - } // namespace formats } // namespace ge diff --git 
a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc index 11e3d270..9de2e3a0 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc @@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con } REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) - } // namespace formats } // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index ff7b84a4..65798f29 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap dst_shape.clear(); hw_shape.clear(); auto w0 = GetCubeSizeByDataType(data_type); - auto h0 = GetCubeSizeByDataType(data_type); + int64_t h0 = kCubeSize; switch (src_shape.size()) { case 1: dst_shape.push_back(Ceil(src_shape[0], w0)); diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index f3d06496..f2ec29da 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -19,6 +19,7 @@ #include #include +#include "common/debug/log.h" #include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" @@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t hw = h * w; int64_t chw = c * hw; - int64_t hwc0 = hw * c0; int64_t nchw = n * chw; + int64_t hwc0 = hw * c0; // horizontal fractal matrix count (N) int64_t hf_cnt = Ceil(n, static_cast(kNiSize)); @@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = total_ele_cnt * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { // vertical fractal matrix base index @@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - offset < static_cast(SECUREC_MEM_MAX_LEN) ? 
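// (Editor's note, kept as a comment because the conditional expression above
// continues right below: protected_size caps each write at SECUREC_MEM_MAX_LEN,
// the largest count the secure-C memset_s/memcpy_s routines accept. A minimal
// sketch of the same guard, assuming only that securec convention:
//   size_t writable = std::min<size_t>(dst_size - offset, SECUREC_MEM_MAX_LEN);
//   if (memset_s(dst.get() + offset, writable, 0, static_cast<size_t>(size)) != EOK) { /* report */ }
// The hunks that follow also replace memcpy_s on the non-padding path with an
// explicit protected_size check plus a bytewise copy, so oversized spans fail
// loudly instead of being rejected inside the securec call.)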
dst_size - offset : static_cast(SECUREC_MEM_MAX_LEN); - errno_t ret; + errno_t ret = EOK; if (need_pad_zero) { ret = memset_s(dst.get() + offset, static_cast(protected_size), 0, static_cast(size)); } else { - ret = memcpy_s(dst.get() + offset, static_cast(protected_size), args.data + src_offset * size, - static_cast(size)); + if (protected_size < size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, size); + return INTERNAL_ERROR; + } + char *dst_data = reinterpret_cast(dst.get() + offset); + const char *src_data = reinterpret_cast(args.data + src_offset * size); + for (int64_t index = 0; index < size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, @@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { ? 
dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); - errno_t ret; + errno_t ret = EOK; if (pad_zero) { ret = memset_s(dst.get() + dst_offset, static_cast(protected_size), 0, static_cast(data_size)); } else { + if (protected_size < data_size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, data_size); + return INTERNAL_ERROR; + } int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; - ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), - args.data + src_idx * data_size, static_cast(data_size)); + char *dst_data = reinterpret_cast(dst.get() + dst_offset); + const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + for (int64_t index = 0; index < data_size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", @@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { ? 
dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); - errno_t ret; + errno_t ret = EOK; if (pad_zero) { ret = memset_s(dst.get() + dst_offset, static_cast(protected_size), 0, static_cast(data_size)); } else { + if (protected_size < data_size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, data_size); + return INTERNAL_ERROR; + } int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); - ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), - args.data + src_idx * data_size, static_cast(data_size)); + char *dst_data = reinterpret_cast(dst.get() + dst_offset); + const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + for (int64_t index = 0; index < data_size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", @@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return PARAM_INVALID; } - if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { - return TransFormatFromNchwToFz(args, result); + if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { + return TransFormatNhwcToFz(args, result); } if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { return TransFormatHwcnToFz(args, result); } - if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { - return TransFormatNhwcToFz(args, result); + if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { + return TransFormatFromNchwToFz(args, result); } return UNSUPPORTED; @@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector 0 ? SUCCESS : UNSUPPORTED; } @@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { return NOT_CHANGED; } - /* prepare for padding in chw*/ + // prepare for padding in chw int64_t tmp = h * w * c; int64_t n_o = Ceil(n, static_cast(c0)); int64_t c_o = c0; @@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto } REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) - } // namespace formats } // namespace ge diff --git a/src/ge/common/formats/utils/formats_definitions.h b/src/ge/common/formats/utils/formats_definitions.h index d889c33c..2faa60e1 100644 --- a/src/ge/common/formats/utils/formats_definitions.h +++ b/src/ge/common/formats/utils/formats_definitions.h @@ -19,7 +19,6 @@ namespace ge { namespace formats { - static const int kCubeSize = 16; static const int kNiSize = 16; static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; @@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum }; enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum }; enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum }; - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ diff --git a/src/ge/common/formats/utils/formats_trans_utils.h b/src/ge/common/formats/utils/formats_trans_utils.h index a8fbd09b..7b902c3e 100644 --- a/src/ge/common/formats/utils/formats_trans_utils.h +++ b/src/ge/common/formats/utils/formats_trans_utils.h @@ -69,7 +69,6 @@ T Ceil(T n1, T n2) { } return (n2 != 0) ? 
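// (Editor's worked example, inline as a comment since the expression continues on
// the next line: for positive integers this is ceiling division, e.g.
// Ceil(17, 16) = (17 - 1) / 16 + 1 = 2 and Ceil(16, 16) = 15 / 16 + 1 = 1, which
// is how the fractal-Z transfers above count 16-wide cube blocks; the n2 == 0
// branch yields 0 rather than dividing by zero.)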
(n1 - 1) / n2 + 1 : 0; } - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_ diff --git a/src/ge/common/fp16_t.h b/src/ge/common/fp16_t.h index 34908b95..0fda2cd2 100644 --- a/src/ge/common/fp16_t.h +++ b/src/ge/common/fp16_t.h @@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) { } return len; } -}; // namespace ge +} // namespace ge #endif // GE_COMMON_FP16_T_H_ diff --git a/src/ge/common/ge/op_tiling_manager.cc b/src/ge/common/ge/op_tiling_manager.cc new file mode 100644 index 00000000..7fb7a8fc --- /dev/null +++ b/src/ge/common/ge/op_tiling_manager.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/ge/op_tiling_manager.h" +#include "framework/common/debug/log.h" +#include + +namespace { +const char *const kEnvName = "ASCEND_OPP_PATH"; +const std::string kDefaultPath = "/usr/local/Ascend/opp"; +const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so"; +const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so"; +const uint8_t kPrefixIndex = 9; +} // namespace + +namespace ge { +void OpTilingManager::ClearHandles() noexcept { + for (const auto &handle : handles_) { + if (dlclose(handle.second) != 0) { + GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror()); + } + } + handles_.clear(); +} + +OpTilingManager::~OpTilingManager() { ClearHandles(); } + +std::string OpTilingManager::GetPath() { + const char *opp_path_env = std::getenv(kEnvName); + std::string opp_path = kDefaultPath; + if (opp_path_env != nullptr) { + char resolved_path[PATH_MAX]; + if (realpath(opp_path_env, resolved_path) == NULL) { + GELOGE(PARAM_INVALID, "Failed to load tiling lib: env 'ASCEND_OPP_PATH'(%s) is an invalid path.", opp_path_env); + return std::string(); + } + opp_path = resolved_path; + } + return opp_path; +} + +void OpTilingManager::LoadSo() { + std::string opp_path = GetPath(); + if (opp_path.empty()) { + GELOGW("Skip load tiling lib."); + return; + } + std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath; + std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath; + std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex); + std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex); + + void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_bi == nullptr) { + GELOGW("Failed to dlopen %s!", dlerror()); + } else { + handles_[built_in_name] = handle_bi; + } + + void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ct == nullptr) { + GELOGW("Failed to dlopen %s!", dlerror()); + } else { + handles_[custom_name] = handle_ct; + } +} + +} // namespace ge diff --git a/src/ge/graph/passes/identify_reference_pass.h b/src/ge/common/ge/op_tiling_manager.h similarity index 62% rename from src/ge/graph/passes/identify_reference_pass.h rename to
src/ge/common/ge/op_tiling_manager.h index 5f284b4c..320e1411 100644 --- a/src/ge/graph/passes/identify_reference_pass.h +++ b/src/ge/common/ge/op_tiling_manager.h @@ -14,16 +14,25 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ -#define GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ +#ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_ +#define GE_COMMON_GE_OP_TILING_MANAGER_H_ -#include "graph/passes/base_pass.h" +#include namespace ge { -class IdentifyReferencePass : public BaseNodePass { +using SoToHandleMap = std::map; + +class OpTilingManager { public: - Status Run(NodePtr &node) override; + OpTilingManager() = default; + ~OpTilingManager(); + void LoadSo(); + + private: + static std::string GetPath(); + void ClearHandles() noexcept; + SoToHandleMap handles_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ +#endif // GE_COMMON_GE_OP_TILING_MANAGER_H_ diff --git a/src/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc index 2f95cbb1..e1f7c75f 100644 --- a/src/ge/common/helper/model_helper.cc +++ b/src/ge/common/helper/model_helper.cc @@ -17,6 +17,7 @@ #include "framework/common/helper/model_helper.h" #include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/log.h" #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" @@ -267,6 +268,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c } auto partition_table = reinterpret_cast(model_addr_tmp_); if (partition_table->num == kOriginalOmPartitionNum) { + model_addr_tmp_ = nullptr; GELOGE(FAILED, "om model is error,please use executable om model"); return FAILED; } @@ -390,107 +392,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo return out_model; } -// Transit func for model to ge_model. 
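(Editor's aside; the deleted comment above continues immediately after it. The OpTilingManager declared in the renamed header pairs with the op_tiling_manager.cc added earlier in this patch. A minimal, hypothetical call site, not part of the patch itself, sketched under the assumption that the manager simply has to outlive any user of the tiling symbols:

  #include "common/ge/op_tiling_manager.h"

  void LoadTilingLibraries() {
    ge::OpTilingManager tiling_manager;  // empty SoToHandleMap
    tiling_manager.LoadSo();             // resolves ASCEND_OPP_PATH, falls back to /usr/local/Ascend/opp,
                                         // then dlopen()s the built-in and custom liboptiling.so
  }                                      // ~OpTilingManager() -> ClearHandles() -> dlclose() on each handle

Scoping the manager this narrowly would unload the libraries at the closing brace, so a real caller would hold it for the process lifetime; the sketch only illustrates the RAII contract.)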
It will be removed when load and build support ge_model in future -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model, - GeModelPtr &ge_model) { - if (model == nullptr) { - GELOGE(FAILED, "Model is null"); - return FAILED; - } - ge_model = ge::MakeShared(); - GE_CHECK_NOTNULL(ge_model); - ge_model->SetGraph(model->GetGraph()); - ge_model->SetName(model->GetName()); - ge_model->SetVersion(model->GetVersion()); - ge_model->SetPlatformVersion(model->GetPlatformVersion()); - ge_model->SetAttr(model->MutableAttrMap()); - - // Copy weight info - auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); - // ge::Buffer weight; - ge::Buffer weight; - (void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); - ge_model->SetWeight(weight); - // Copy task info - if (model->HasAttr(MODEL_ATTR_TASKS)) { - ge::Buffer task_buffer; - GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, - "Get bytes failed."); - - std::shared_ptr task = ge::MakeShared(); - GE_CHECK_NOTNULL(task); - GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); - GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); - - GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast(task_buffer.GetSize()), task.get()), - return INTERNAL_ERROR, "ReadProtoFromArray failed."); - - ge_model->SetModelTaskDef(task); - } - // Copy tbe kernel info - // TBEKernelStore kernel_store; - TBEKernelStore kernel_store; - if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); - kernel_store.AddTBEKernel(tbe_kernel); - GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); - } - } - if (!kernel_store.Build()) { - GELOGE(FAILED, "TBE Kernels store build failed!"); - return FAILED; - } - ge_model->SetTBEKernelStore(kernel_store); - - return SUCCESS; -} - -// trasit func for ge_model to Model. 
will be removed when load and build support ge_model in future -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model, - ModelPtr &model) { - if (ge_model == nullptr) { - GELOGE(FAILED, "Ge_model is null"); - return FAILED; - } - model = ge::MakeShared(); - GE_CHECK_NOTNULL(model); - model->SetGraph(ge_model->GetGraph()); - model->SetName(ge_model->GetName()); - model->SetVersion(ge_model->GetVersion()); - model->SetPlatformVersion(ge_model->GetPlatformVersion()); - model->SetAttr(ge_model->MutableAttrMap()); - // Copy weight info - auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); - bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight()); - if (!ret) { - GELOGE(FAILED, "Copy weight buffer failed!"); - return FAILED; - } - // Copy task info - std::shared_ptr model_task = ge_model->GetModelTaskDefPtr(); - - if (model_task != nullptr) { - int size = model_task->ByteSize(); - ge::Buffer buffer(static_cast(size)); - if (buffer.GetSize() == 0) { - GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!"); - return MEMALLOC_FAILED; - } - // no need to check value - (void)model_task->SerializePartialToArray(buffer.GetData(), size); - ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer)); - if (!ret) { - GELOGE(FAILED, "Copy task buffer failed!"); - return FAILED; - } - } - return SUCCESS; -} - Status ModelHelper::ReleaseLocalModelData() noexcept { Status result = SUCCESS; if (model_addr_tmp_ != nullptr) { diff --git a/src/ge/common/math/fp16_math.h b/src/ge/common/math/fp16_math.h index 5bc9ac6d..c3a4eb28 100644 --- a/src/ge/common/math/fp16_math.h +++ b/src/ge/common/math/fp16_math.h @@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2); /// @brief Calculate the minimum fp16_t of fp1 and fp2 /// @return Returns minimum fp16_t of fp1 and fp2 fp16_t min(fp16_t fp1, fp16_t fp2); -}; // namespace ge +} // namespace ge #endif // GE_COMMON_MATH_FP16_MATH_H_ \ No newline at end of file diff --git a/src/ge/common/math_util.h b/src/ge/common/math_util.h index 5e783e81..a12be9e0 100644 --- a/src/ge/common/math_util.h +++ b/src/ge/common/math_util.h @@ -27,7 +27,6 @@ #include "mmpa/mmpa_api.h" namespace ge { - /** * @ingroup domi_calibration * @brief Initializes an input array to a specified value @@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { } return SUCCESS; } - } // end namespace ge #endif // GE_COMMON_MATH_UTIL_H_ diff --git a/src/ge/common/model_saver.cc b/src/ge/common/model_saver.cc index 11d9e804..821fde60 100644 --- a/src/ge/common/model_saver.cc +++ b/src/ge/common/model_saver.cc @@ -60,8 +60,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi mode_t mode = S_IRUSR | S_IWUSR; int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"filepath", "errMsg"}, {file_path, strerror(errno)}); - GELOGE(FAILED, "Open file failed. file path : %s, %s", file_path, strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); + GELOGE(FAILED, "Open file[%s] failed. 
%s", file_path, strerror(errno)); return FAILED; } const char *model_char = model_str.c_str(); @@ -69,8 +69,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi // Write data to file mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"mmpa_ret", "errMsg"}, - {std::to_string(mmpa_ret), strerror(errno)}); + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); ret = FAILED; diff --git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index 0944b5e0..04d23546 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -336,16 +336,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin std::string data; for (const auto &task : task_desc_info) { + std::string model_name = task.model_name; std::string op_name = task.op_name; uint32_t block_dim = task.block_dim; uint32_t task_id = task.task_id; uint32_t stream_id = task.stream_id; - data = op_name.append(" ").append(std::to_string(block_dim) - .append(" ") - .append(std::to_string(task_id)) - .append(" ") - .append(std::to_string(stream_id)) - .append("\n")); + data = model_name.append(" ").append(op_name).append(" ").append(std::to_string(block_dim) + .append(" ") + .append(std::to_string(task_id)) + .append(" ") + .append(std::to_string(stream_id)) + .append("\n")); Msprof::Engine::ReporterData reporter_data{}; reporter_data.deviceId = device_id; @@ -376,7 +377,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin std::string data; for (const auto &graph : compute_graph_desc_info) { - data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); + data.append("model_name:") + .append(graph.model_name) + .append(" op_name:") + .append(graph.op_name) + .append(" op_type:") + .append(graph.op_type); for (size_t i = 0; i < graph.input_format.size(); ++i) { data.append(" input_id:") .append(std::to_string(i)) diff --git a/src/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc index cf1ada05..0c2b1db6 100644 --- a/src/ge/common/properties_manager.cc +++ b/src/ge/common/properties_manager.cc @@ -20,15 +20,204 @@ #include #include +#include "common/ge/ge_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_types.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" #include "graph/utils/attr_utils.h" namespace ge { +namespace { +const string kEnableFlag = "1"; + +const uint32_t kAicoreOverflow = (0x1 << 0); +const uint32_t kAtomicOverflow = (0x1 << 1); +const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); +} // namespace + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { + CopyFrom(other); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( + const DumpProperties &other) { + CopyFrom(other); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void 
DumpProperties::InitByOptions() { + enable_dump_.clear(); + enable_dump_debug_.clear(); + dump_path_.clear(); + dump_step_.clear(); + dump_mode_.clear(); + is_op_debug_ = false; + op_debug_mode_ = 0; + + string enable_dump; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); + enable_dump_ = enable_dump; + + string enable_dump_debug; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); + enable_dump_debug_ = enable_dump_debug; + + if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { + string dump_path; + if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { + if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { + dump_path = dump_path + "/"; + } + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("Get dump path %s successfully", dump_path.c_str()); + SetDumpPath(dump_path); + } else { + GELOGW("DUMP_PATH is not set"); + } + } + + if (enable_dump_ == kEnableFlag) { + string dump_step; + if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { + GELOGD("Get dump step %s successfully", dump_step.c_str()); + SetDumpStep(dump_step); + } + string dump_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump mode %s successfully", dump_mode.c_str()); + SetDumpMode(dump_mode); + } + AddPropertyValue(DUMP_ALL_MODEL, {}); + } + + SetDumpDebugOptions(); +} + +// The following handles the new dump scenario for fusion operators +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( + const std::string &model, const std::set &layers) { + for (const std::string &layer : layers) { + GELOGI("Model %s is configured to dump layer %s", model.c_str(), layer.c_str()); + } + + model_dump_properties_map_[model] = layers; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + model_dump_properties_map_.erase(iter); + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { + std::set model_list; + for (auto &iter : model_dump_properties_map_) { + model_list.insert(iter.first); + } + + return model_list; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetPropertyValue( + const std::string &model) const { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + return iter->second; + } + return {}; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( + const std::string &model, const std::string &om_name, const std::string &op_name) const { + // if dump all + if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { + return true; + } + + // if this model needs dumping + auto om_name_iter = model_dump_properties_map_.find(om_name); + auto model_name_iter = model_dump_properties_map_.find(model); + if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { + // if no dump layer info, dump all layers in this model + auto model_iter = om_name_iter != model_dump_properties_map_.end() ? 
om_name_iter : model_name_iter; + if (model_iter->second.empty()) { + return true; + } + + return model_iter->second.find(op_name) != model_iter->second.end(); + } + + GELOGD("Model %s is not set to be dumped.", model.c_str()); + return false; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { + dump_path_ = path; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpPath() const { return dump_path_; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { + dump_step_ = step; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpStep() const { return dump_step_; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { + dump_mode_ = mode; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpMode() const { return dump_mode_; } + +void DumpProperties::CopyFrom(const DumpProperties &other) { + if (&other != this) { + enable_dump_ = other.enable_dump_; + enable_dump_debug_ = other.enable_dump_debug_; + dump_path_ = other.dump_path_; + dump_step_ = other.dump_step_; + dump_mode_ = other.dump_mode_; + + model_dump_properties_map_ = other.model_dump_properties_map_; + is_op_debug_ = other.is_op_debug_; + op_debug_mode_ = other.op_debug_mode_; + } +} + +void DumpProperties::SetDumpDebugOptions() { + if (enable_dump_debug_ == kEnableFlag) { + string dump_debug_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); + } else { + GELOGW("Dump debug mode is not set."); + return; + } + + if (dump_debug_mode == OP_DEBUG_AICORE) { + GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is enabled."); + is_op_debug_ = true; + op_debug_mode_ = kAicoreOverflow; + } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { + GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is enabled."); + is_op_debug_ = true; + op_debug_mode_ = kAtomicOverflow; + } else if (dump_debug_mode == OP_DEBUG_ALL) { + GELOGD("ge.exec.dumpDebugMode=all, op debug is enabled."); + is_op_debug_ = true; + op_debug_mode_ = kAllOverflow; + } else { + GELOGW("ge.exec.dumpDebugMode is invalid."); + } + } else { + GELOGI("ge.exec.enableDumpDebug is false or is not set."); + } +} + PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} PropertiesManager::~PropertiesManager() {} @@ -159,131 +348,22 @@ PropertiesManager::GetPropertyMap() { // Set separator FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) { + std::lock_guard lock(mutex_); delimiter = de; } -// The following is the new dump scenario of the fusion operator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::AddDumpPropertyValue( - const std::string &model, const std::set &layers) { - for (const std::string &layer : layers) { - GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); - } - - std::lock_guard lock(dump_mutex_); - model_dump_properties_map_[model] = layers; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::DeleteDumpPropertyValue( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - 
model_dump_properties_map_.erase(iter); - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::ClearDumpPropertyValue() { - std::lock_guard lock(dump_mutex_); - model_dump_properties_map_.clear(); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set PropertiesManager::GetAllDumpModel() { - std::set model_list; - std::lock_guard lock(dump_mutex_); - for (auto &iter : model_dump_properties_map_) { - model_list.insert(iter.first); - } - - return model_list; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set PropertiesManager::GetDumpPropertyValue( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - return iter->second; - } - return {}; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::IsLayerNeedDump(const std::string &model, - const std::string &om_name, - const std::string &op_name) { - std::lock_guard lock(dump_mutex_); - // if dump all - if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { - return true; - } - - // if this model need dump - auto om_name_iter = model_dump_properties_map_.find(om_name); - auto model_name_iter = model_dump_properties_map_.find(model); - if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { - // if no dump layer info, dump all layer in this model - auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; - if (model_iter->second.empty()) { - return true; - } - - return model_iter->second.find(op_name) != model_iter->second.end(); - } - - GELOGD("Model %s is not seated to be dump.", model.c_str()); - return false; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &PropertiesManager::GetDumpProperties( + uint64_t session_id) { + std::lock_guard lock(mutex_); + // If session_id is not found in dump_properties_map_, operator[] will insert one. 
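// (Editor's illustration of the operator[] behavior described above, kept as a
// comment so the return statement follows directly; it assumes nothing beyond
// standard std::map semantics:
//   std::map<uint64_t, DumpProperties> props;
//   DumpProperties &p = props[42];  // default-constructs and inserts an entry
//   // props.size() == 1, even though nothing was explicitly stored
// That default-insert is what lets GetDumpProperties() hand back a mutable
// per-session reference without a separate create step, while
// RemoveDumpProperties() below erases the entry when the session goes away.)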
+ return dump_properties_map_[session_id]; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::QueryModelDumpStatus( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - return true; - } else if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { - return true; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::RemoveDumpProperties(uint64_t session_id) { + std::lock_guard lock(mutex_); + auto iter = dump_properties_map_.find(session_id); + if (iter != dump_properties_map_.end()) { + dump_properties_map_.erase(iter); } - return false; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputModel( - const std::string &output_mode) { - std::lock_guard lock(dump_mutex_); - this->output_mode_ = output_mode; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputModel() { - std::lock_guard lock(dump_mutex_); - return this->output_mode_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputPath( - const std::string &output_path) { - std::lock_guard lock(dump_mutex_); - this->output_path_ = output_path; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputPath() { - std::lock_guard lock(dump_mutex_); - return this->output_path_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpStep(const std::string &dump_step) { - std::lock_guard lock(dump_mutex_); - this->dump_step_ = dump_step; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpStep() { - std::lock_guard lock(dump_mutex_); - return this->dump_step_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpMode(const std::string &dump_mode) { - std::lock_guard lock(dump_mutex_); - this->dump_mode_ = dump_mode; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpMode() { - std::lock_guard lock(dump_mutex_); - return this->dump_mode_; } } // namespace ge diff --git a/src/ge/common/properties_manager.h b/src/ge/common/properties_manager.h index 7cbb5949..3b1547f5 100644 --- a/src/ge/common/properties_manager.h +++ b/src/ge/common/properties_manager.h @@ -32,6 +32,50 @@ static const char *USE_FUSION __attribute__((unused)) = "FMK_USE_FUSION"; static const char *TIMESTAT_ENABLE __attribute__((unused)) = "DAVINCI_TIMESTAT_ENABLE"; static const char *ANNDROID_DEBUG __attribute__((unused)) = "ANNDROID_DEBUG"; +class DumpProperties { + public: + DumpProperties() = default; + ~DumpProperties() = default; + DumpProperties(const DumpProperties &dump); + DumpProperties &operator=(const DumpProperties &dump); + + void InitByOptions(); + + void AddPropertyValue(const std::string &model, const std::set &layers); + void DeletePropertyValue(const std::string &model); + + std::set GetAllDumpModel() const; + std::set GetPropertyValue(const std::string &model) const; + bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const; + + void SetDumpPath(const std::string &path); + std::string GetDumpPath() const; + + void SetDumpStep(const std::string &step); + std::string GetDumpStep() const; + + void SetDumpMode(const std::string &mode); + std::string GetDumpMode() const; + + bool IsOpDebugOpen() const { 
return is_op_debug_; } + uint32_t GetOpDebugMode() const { return op_debug_mode_; } + + private: + void CopyFrom(const DumpProperties &other); + void SetDumpDebugOptions(); + + string enable_dump_; + string enable_dump_debug_; + + std::string dump_path_; + std::string dump_step_; + std::string dump_mode_; + std::map> model_dump_properties_map_; + + bool is_op_debug_ = false; + uint32_t op_debug_mode_ = 0; +}; + class PropertiesManager { public: // Singleton @@ -81,21 +125,8 @@ class PropertiesManager { */ void SetPropertyDelimiter(const std::string &de); - void AddDumpPropertyValue(const std::string &model, const std::set &layers); - std::set GetAllDumpModel(); - std::set GetDumpPropertyValue(const std::string &model); - bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name); - void DeleteDumpPropertyValue(const std::string &model); - void ClearDumpPropertyValue(); - bool QueryModelDumpStatus(const std::string &model); - void SetDumpOutputModel(const std::string &output_model); - std::string GetDumpOutputModel(); - void SetDumpOutputPath(const std::string &output_path); - std::string GetDumpOutputPath(); - void SetDumpStep(const std::string &dump_step); - std::string GetDumpStep(); - void SetDumpMode(const std::string &dump_mode); - std::string GetDumpMode(); + DumpProperties &GetDumpProperties(uint64_t session_id); + void RemoveDumpProperties(uint64_t session_id); private: // Private construct, destructor @@ -119,12 +150,7 @@ class PropertiesManager { std::map properties_map_; std::mutex mutex_; - std::string output_mode_; - std::string output_path_; - std::string dump_step_; - std::string dump_mode_; - std::map> model_dump_properties_map_; // model_dump_layers_map_ - std::mutex dump_mutex_; + std::map dump_properties_map_; }; } // namespace ge diff --git a/src/ge/common/tbe_kernel_store.h b/src/ge/common/tbe_kernel_store.h index da231358..51d69af2 100644 --- a/src/ge/common/tbe_kernel_store.h +++ b/src/ge/common/tbe_kernel_store.h @@ -28,7 +28,6 @@ #include "graph/op_kernel_bin.h" namespace ge { - using TBEKernel = ge::OpKernelBin; using TBEKernelPtr = std::shared_ptr; diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index 97761dea..80dea8a0 100644 --- a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -26,6 +26,11 @@ const std::string DUMP_LAYER = "layer"; const std::string DUMP_FILE_PATH = "path"; const std::string DUMP_MODE = "dump_mode"; +// op debug mode +const std::string OP_DEBUG_AICORE = "aicore_overflow"; +const std::string OP_DEBUG_ATOMIC = "atomic_overflow"; +const std::string OP_DEBUG_ALL = "all"; + const int DEFAULT_FORMAT = static_cast(ge::FORMAT_NCHW); // Supported public property names const std::string PROP_OME_START_TIME = "ome_start_time"; // start time @@ -277,8 +282,8 @@ REGISTER_OPTYPE_DEFINE(GETSPAN, "GetSpan"); REGISTER_OPTYPE_DEFINE(STOPGRADIENT, "StopGradient"); REGISTER_OPTYPE_DEFINE(PREVENTGRADIENT, "PreventGradient"); REGISTER_OPTYPE_DEFINE(GUARANTEECONST, "GuaranteeConst"); -REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs") -REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs") +REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); +REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs"); REGISTER_OPTYPE_DEFINE(CONFUSIONMATRIX, "ConfusionMatrix"); REGISTER_OPTYPE_DEFINE(RANK, "Rank"); REGISTER_OPTYPE_DEFINE(PLACEHOLDER, "PlaceHolder"); @@ -286,6 +291,7 @@ REGISTER_OPTYPE_DEFINE(END, "End"); REGISTER_OPTYPE_DEFINE(BASICLSTMCELL, "BasicLSTMCell"); 
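// (Editor's note between the registrations: the OP_DEBUG_* strings defined above
// are matched in DumpProperties::SetDumpDebugOptions() against the overflow bit
// masks from properties_manager.cc; a sketch of the mapping, using those
// constants:
//   kAicoreOverflow = 0x1 << 0;  // selected by "aicore_overflow"
//   kAtomicOverflow = 0x1 << 1;  // selected by "atomic_overflow"
//   kAllOverflow = kAicoreOverflow | kAtomicOverflow;  // "all" == 0x3
// so a consumer can test (op_debug_mode_ & kAtomicOverflow) != 0 to decide
// whether atomic-overflow detection is armed.)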
REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DEFINE(INITDATA, "InitData"); +REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity"); /***************Ann special operator*************************/ REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean"); @@ -479,72 +485,72 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. #endif /// -///@brief Magic number of model file +/// @brief Magic number of model file /// const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number /// -///@brief Model head length +/// @brief Model head length /// const uint32_t MODEL_FILE_HEAD_LEN = 256; /// -///@ingroup domi_omg -///@brief Input node type +/// @ingroup domi_omg +/// @brief Input node type /// const std::string INPUT_TYPE = "Input"; /// -///@ingroup domi_omg -///@brief AIPP label, label AIPP conv operator +/// @ingroup domi_omg +/// @brief AIPP label, label AIPP conv operator /// const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; /// -///@ingroup domi_omg -///@brief AIPP label, label aipp data operator +/// @ingroup domi_omg +/// @brief AIPP label, label aipp data operator /// const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; /// -///@ingroup domi_omg -///@brief Record the w dimension of model input corresponding to dynamic AIPP +/// @ingroup domi_omg +/// @brief Record the w dimension of model input corresponding to dynamic AIPP /// const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; /// -///@ingroup domi_omg -///@brief Record the H dimension of model input corresponding to dynamic AIPP +/// @ingroup domi_omg +/// @brief Record the H dimension of model input corresponding to dynamic AIPP /// const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; /// -///@ingroup domi_omg -///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator +/// @ingroup domi_omg +/// @brief The tag of the data operator. 
Mark this input to the dynamic AIPP operator /// const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; /// -///@ingroup domi_omg -///@brief DATA node type +/// @ingroup domi_omg +/// @brief DATA node type /// const std::string DATA_TYPE = "Data"; /// -///@ingroup domi_omg -///@brief DATA node type +/// @ingroup domi_omg +/// @brief DATA node type /// const std::string AIPP_DATA_TYPE = "AippData"; /// -///@ingroup domi_omg -///@brief Frame operator type +/// @ingroup domi_omg +/// @brief Frame operator type /// const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; /// -///@ingroup domi_omg -///@brief Data node type +/// @ingroup domi_omg +/// @brief Data node type /// const std::string ANN_DATA_TYPE = "AnnData"; const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; @@ -552,136 +558,139 @@ const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; const std::string ANN_CONV_TYPE = "AnnConvolution"; const std::string ANN_FC_TYPE = "AnnFullConnection"; /// -///@ingroup domi_omg -///@brief Convolution node type +/// @ingroup domi_omg +/// @brief Convolution node type /// const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; +const std::string NODE_NAME_OP_DEBUG = "Node_OpDebug"; +const std::string OP_TYPE_OP_DEBUG = "Opdebug"; + /// -///@ingroup domi_omg -///@brief Convolution node type +/// @ingroup domi_omg +/// @brief Convolution node type /// const std::string OP_TYPE_CONVOLUTION = "Convolution"; /// -///@ingroup domi_omg -///@brief Add convolution node name to AIPP +/// @ingroup domi_omg +/// @brief Add convolution node name to AIPP /// const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; /// -///@ingroup domi_omg -///@brief Operator configuration item separator +/// @ingroup domi_omg +/// @brief Operator configuration item separator /// const std::string OP_CONF_DELIMITER = ":"; /// -///@ingroup domi_omg -///@brief attr value name +/// @ingroup domi_omg +/// @brief attr value name /// const std::string ATTR_NAME_VALUE1 = "value1"; /// -///@ingroup domi_omg -///@brief attr value name, 6d_2_4d C +/// @ingroup domi_omg +/// @brief attr value name, 6d_2_4d C /// const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; /// -///@ingroup domi_omg -///@brief alpha default value +/// @ingroup domi_omg +/// @brief alpha default value /// const float ALPHA_DEFAULT_VALUE = 1.0; /// -///@ingroup domi_omg -///@brief beta default value +/// @ingroup domi_omg +/// @brief beta default value /// const float BETA_DEFAULT_VALUE = 0.0; /// -///@ingroup domi_omg -///@brief coef default value +/// @ingroup domi_omg +/// @brief coef default value /// const float COEF_DEFAULT_VALUE = 0.0; /// -///@ingroup domi_omg -///@brief Relu6 coef value +/// @ingroup domi_omg +/// @brief Relu6 coef value /// const float RELU6_COEF = 6.0; /// -///@ingroup domi_omg -///@brief stride default value +/// @ingroup domi_omg +/// @brief stride default value /// const uint32_t STRIDE_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief pad default value +/// @ingroup domi_omg +/// @brief pad default value /// const uint32_t PAD_DEFAULT_VALUE = 0; /// -///@ingroup domi_omg -///@brief dilation default value +/// @ingroup domi_omg +/// @brief dilation default value /// const int DILATION_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief kernel default value +/// @ingroup domi_omg +/// @brief kernel default value /// const uint32_t KERNEL_DEFAULT_VALUE = 0; /// -///@ingroup domi_omg -///@brief defaule convolution group size +/// @ingroup domi_omg +/// @brief 
default convolution group size /// const uint32_t DEFAULT_CONV_GROUP = 1; /// -///@ingroup domi_omg -///@brief Default deconvolution adj +/// @ingroup domi_omg +/// @brief Default deconvolution adj /// const uint32_t DEFAULT_DECONV_ADJ = 0; /// -///@ingroup domi_omg -///@brief Represents value 1 +/// @ingroup domi_omg +/// @brief Represents value 1 /// const uint32_t NUM_ONE = 1; /// -///@ingroup domi_omg -///@brief spatial dim size default value +/// @ingroup domi_omg +/// @brief spatial dim size default value /// const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; /// -///@ingroup domi_omg -///@brief dim extended default value +/// @ingroup domi_omg +/// @brief dim extended default value /// const int32_t DIM_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief The first weight list in opdef is filter +/// @ingroup domi_omg +/// @brief The first weight list in opdef is filter /// const int32_t WEIGHT_FILTER_INDEX = 0; /// -///@ingroup domi_omg -///@brief The second weight list in opdef is bias +/// @ingroup domi_omg +/// @brief The second weight list in opdef is bias /// const int32_t WEIGHT_BIAS_INDEX = 1; const int32_t TENSOR_ND_SUPPORT_SIZE = 8; /// -///@ingroup domi_omg -///@brief NCHW index default value +/// @ingroup domi_omg +/// @brief NCHW index default value /// const uint32_t NCHW_DIM_N = 0; const uint32_t NCHW_DIM_C = 1; @@ -689,8 +698,8 @@ const uint32_t NCHW_DIM_H = 2; const uint32_t NCHW_DIM_W = 3; /// -///@ingroup domi_omg -///@brief KCHW index default value +/// @ingroup domi_omg +/// @brief KCHW index default value /// const uint32_t KCHW_DIM_K = 0; const uint32_t KCHW_DIM_C = 1; @@ -698,8 +707,8 @@ const uint32_t KCHW_DIM_H = 2; const uint32_t KCHW_DIM_W = 3; /// -///@ingroup domi_omg -///@brief HWCK index default value +/// @ingroup domi_omg +/// @brief HWCK index default value /// const uint32_t HWCK_DIM_H = 0; const uint32_t HWCK_DIM_W = 1; @@ -707,8 +716,8 @@ const uint32_t HWCK_DIM_C = 2; const uint32_t HWCK_DIM_K = 3; /// -///@ingroup domi_omg -///@brief NHWC index default value +/// @ingroup domi_omg +/// @brief NHWC index default value /// const uint32_t NHWC_DIM_N = 0; const uint32_t NHWC_DIM_H = 1; @@ -716,8 +725,8 @@ const uint32_t NHWC_DIM_W = 2; const uint32_t NHWC_DIM_C = 3; /// -///@ingroup domi_omg -///@brief CHWN index default value +/// @ingroup domi_omg +/// @brief CHWN index default value /// const uint32_t CHWN_DIM_N = 3; const uint32_t CHWN_DIM_C = 0; @@ -725,23 +734,23 @@ const uint32_t CHWN_DIM_H = 1; const uint32_t CHWN_DIM_W = 2; /// -///@ingroup domi_omg -///@brief CHW index default value +/// @ingroup domi_omg +/// @brief CHW index default value /// const uint32_t CHW_DIM_C = 0; const uint32_t CHW_DIM_H = 1; const uint32_t CHW_DIM_W = 2; /// -///@ingroup domi_omg -///@brief HWC index default value +/// @ingroup domi_omg +/// @brief HWC index default value /// const uint32_t HWC_DIM_H = 0; const uint32_t HWC_DIM_W = 1; const uint32_t HWC_DIM_C = 2; /// -///@ingroup domi_omg -///@brief Pad index default value +/// @ingroup domi_omg +/// @brief Pad index default value /// const uint32_t PAD_H_HEAD = 0; const uint32_t PAD_H_TAIL = 1; @@ -749,35 +758,35 @@ const uint32_t PAD_W_HEAD = 2; const uint32_t PAD_W_TAIL = 3; /// -///@ingroup domi_omg -///@brief window index default value +/// @ingroup domi_omg +/// @brief window index default value /// const uint32_t WINDOW_H = 0; const uint32_t WINDOW_W = 1; /// -///@ingroup domi_omg -///@brief stride index default value +/// @ingroup domi_omg +/// @brief stride index default value /// const uint32_t 
/// -///@ingroup domi_omg -///@brief dilation index default value +/// @ingroup domi_omg +/// @brief dilation index default value /// const uint32_t DILATION_H = 0; const uint32_t DILATION_W = 1; /// -///@ingroup domi_omg -///@brief the num of XRBG channel +/// @ingroup domi_omg +/// @brief the num of XRGB channel /// const uint32_t XRGB_CHN_NUM = 4; /// -///@ingroup domi_omg -///@brief global pooling default value +/// @ingroup domi_omg +/// @brief global pooling default value /// const bool DEFAULT_GLOBAL_POOLING = false; @@ -801,4 +810,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; -}; // namespace ge +} // namespace ge diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index 50ed2f33..69dc7442 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -56,6 +56,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M /// The maximum length of the file. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; +const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and Chinese characters"; } // namespace namespace ge { @@ -77,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); GELOGE(ge::FAILED, "Open real path[%s] failed.", file); return false; } @@ -90,7 +91,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co fs.close(); if (!ret) { - ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); GELOGE(ge::FAILED, "Parse file[%s] failed.", file); return ret; } @@ -114,17 +115,18 @@ long GetFileLength(const std::string &input_file) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); unsigned long long file_length = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, - ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file}); - return -1, "Open file[%s] failed", input_file.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); + return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno));
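// Editorial aside, not part of the patch: the reworked reports above all follow the
// ErrorManager convention of (message code, placeholder-name list, value list), with
// names and values paired by index. A hedged sketch of that calling pattern, where
// ReportOpenFailure is a hypothetical wrapper and the E19001 usage mirrors the hunk
// above (requires <cstring> and <cerrno> for strerror/errno):
static void ReportOpenFailure(const std::string &file) {
  ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"},
                                                  {file, strerror(errno)});
}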
%s", input_file.c_str(), strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), - ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"filepath"}, {input_file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); return -1, "File[%s] size is 0, not valid.", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( - "E10039", {"filepath", "filesize", "maxlen"}, + "E19016", {"filepath", "filesize", "maxlen"}, {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit); return static_cast(file_length); @@ -219,7 +221,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: if (ret != 0) { if (errno != EEXIST) { ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -230,7 +232,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: if (ret != 0) { if (errno != EEXIST) { ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -258,16 +260,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch "incorrect parameter. 
nullptr == file || nullptr == message"); std::string real_path = RealPath(file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), - ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file}); - return false, "Get path[%s]'s real path failed", file); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( + "E19000", {"path", "errmsg"}, {file, strerror(errno)}); + return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); std::ifstream fs(real_path.c_str(), std::ifstream::in); if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); GELOGE(ge::FAILED, "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); return false; @@ -275,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch google::protobuf::io::IstreamInputStream input(&fs); bool ret = google::protobuf::TextFormat::Parse(&input, message); - GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"protofile"}, {file}); + GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " "please check whether the file is a valid protobuf format file.", @@ -360,14 +362,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const // The specified path is empty std::map args_map; if (file_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); GELOGW("Input parameter's value is empty."); return false; } std::string real_path = RealPath(file_path.c_str()); // Unable to get absolute path (does not exist or does not have permission to access) if (real_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); + ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)}); GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -380,16 +382,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(real_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); - return false, - "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' 
'_' " - "and chinese character.", - atc_param.c_str(), real_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, real_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); // The absolute path points to a file that is not readable if (access(real_path.c_str(), R_OK) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); - GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)}); + GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -400,7 +400,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const const std::string &atc_param) { // The specified path is empty if (file_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); GELOGW("Input parameter's value is empty."); return false; } @@ -416,18 +416,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(real_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); - return false, - "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " - "and chinese character.", - atc_param.c_str(), real_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, real_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); // File is not readable or writable if (access(real_path.c_str(), W_OK | F_OK) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"}, - {real_path, file_path, strerror(errno)}); - GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(), - strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)}); + GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); return false; } } else { @@ -445,8 +441,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); // Determine whether the specified path is valid by creating the path if (CreateDirectory(prefix_path) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"path"}, {file_path}); - GELOGW("Can not create prefix path for path[%s].", file_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path}); + GELOGW("Can not create directory[%s].", file_path.c_str()); return false; } } diff --git a/src/ge/engine_manager/dnnengine_manager.cc b/src/ge/engine_manager/dnnengine_manager.cc index c8843c09..9afb207f 100644 --- a/src/ge/engine_manager/dnnengine_manager.cc +++ b/src/ge/engine_manager/dnnengine_manager.cc @@ -24,6 +24,7 @@ #include "common/debug/log.h" #include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" #include 
"framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "init/gelib.h" @@ -161,6 +162,10 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) { return false; } +void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); } + +const map &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; } + std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); return ""); @@ -194,15 +199,20 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { if (kernel_info_store != kernel_map.end()) { std::string unsupported_reason; // It will be replaced by engine' checksupport + uint64_t start_time = GetCurrentTimestap(); if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { + checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(), op_desc->GetName().c_str()); return it.engine; } else { + checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; bool is_custom_op = false; if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { + ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, + {kernel_name, op_desc->GetType(), op_desc->GetName()}); GELOGE(FAILED, "The custom operator registered by the user does not support the logic function delivered by this " "network. Check support failed, kernel_name is %s, op type is %s, op name is %s", @@ -221,9 +231,13 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { } } for (const auto &it : unsupported_reasons) { + ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, + {op_desc->GetType(), it.first, it.second}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); } + ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, + {op_desc->GetName(), op_desc->GetType()}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", op_desc->GetName().c_str(), op_desc->GetType().c_str()); return ""; @@ -384,7 +398,13 @@ Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle h return FAILED; } - ifs >> *json_file; + try { + ifs >> *json_file; + } catch (const json::exception &e) { + GELOGE(FAILED, "Read json file failed"); + ifs.close(); + return FAILED; + } ifs.close(); GELOGI("Read json file success"); return SUCCESS; diff --git a/src/ge/engine_manager/dnnengine_manager.h b/src/ge/engine_manager/dnnengine_manager.h index ab813398..15628ecf 100644 --- a/src/ge/engine_manager/dnnengine_manager.h +++ b/src/ge/engine_manager/dnnengine_manager.h @@ -63,6 +63,8 @@ class DNNEngineManager { // If can't find appropriate engine name, return "", report error string GetDNNEngineName(const OpDescPtr &op_desc); const map &GetSchedulers() const; + const map &GetCheckSupportCost() const; + void InitPerformanceStaistic(); private: DNNEngineManager(); @@ -78,6 +80,7 @@ class DNNEngineManager { std::map engines_map_; std::map engines_attrs_map_; std::map schedulers_; + std::map 
bool init_flag_; }; } // namespace ge diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index cddf25b7..0cdb00e2 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -26,6 +26,7 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "ge_executor.cc" + "../common/ge/op_tiling_manager.cc" "../common/ge/plugin_manager.cc" "../common/profiling/profiling_manager.cc" "../graph/execute/graph_execute.cc" @@ -59,7 +60,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/task_info.cc" "../graph/load/new_model_manager/tbe_handle_store.cc" "../graph/load/new_model_manager/zero_copy_task.cc" - "../graph/load/output/output.cc" "../graph/manager/graph_caching_allocator.cc" "../graph/manager/graph_manager_utils.cc" "../graph/manager/graph_mem_allocator.cc" diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index ad7ef1fe..098c57b6 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -452,7 +452,7 @@ Status GeExecutor::RunModel(const ge::RunModelData &input_data, ge::RunModelData // Get input and output descriptor Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector &input_desc, - std::vector &output_desc) { + std::vector &output_desc, bool new_model_desc) { GELOGI("get model desc info begin."); if (!isInit_) { GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); @@ -464,8 +464,8 @@ Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector input_formats; std::vector output_formats; - Status ret = - GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats, output_formats); + Status ret = GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats, + output_formats, new_model_desc); if (ret != domi::SUCCESS) { GELOGE(ret, "GetInputOutputDescInfo failed.
ret = %u", ret); return TransferDomiErrorCode(ret); @@ -854,5 +854,4 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, GELOGI("GetAllAippInputOutputDims succ."); return SUCCESS; } - } // namespace ge diff --git a/src/ge/executor/module.mk b/src/ge/executor/module.mk index efed8854..0eb87822 100644 --- a/src/ge/executor/module.mk +++ b/src/ge/executor/module.mk @@ -4,6 +4,7 @@ local_ge_executor_src_files := \ ge_executor.cc \ ../common/profiling/profiling_manager.cc \ ../common/ge/plugin_manager.cc \ + ../common/ge/op_tiling_manager.cc \ ../graph/load/graph_loader.cc \ ../graph/execute/graph_execute.cc \ ../omm/csa_interact.cc \ @@ -44,7 +45,6 @@ local_ge_executor_src_files := \ ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ - ../graph/load/output/output.cc \ ../single_op/single_op_manager.cc \ ../single_op/single_op_model.cc \ ../single_op/single_op.cc \ @@ -53,6 +53,7 @@ local_ge_executor_src_files := \ ../single_op/task/build_task_utils.cc \ ../single_op/task/tbe_task_builder.cc \ ../single_op/task/aicpu_task_builder.cc \ + ../single_op/task/aicpu_kernel_task_builder.cc \ ../hybrid/hybrid_davinci_model_stub.cc\ local_ge_executor_c_include := \ diff --git a/src/ge/ge_inference.mk b/src/ge/ge_inference.mk index e12989c0..f18f733a 100644 --- a/src/ge/ge_inference.mk +++ b/src/ge/ge_inference.mk @@ -1,5 +1,5 @@ LOCAL_PATH := $(call my-dir) - +include $(LOCAL_PATH)/stub/Makefile COMMON_LOCAL_SRC_FILES := \ proto/fusion_model.proto \ proto/optimizer_priority.proto \ @@ -32,6 +32,7 @@ COMMON_LOCAL_SRC_FILES := \ GRAPH_MANAGER_LOCAL_SRC_FILES := \ common/ge/plugin_manager.cc\ + common/ge/op_tiling_manager.cc\ init/gelib.cc \ session/inner_session.cc \ session/session_manager.cc \ @@ -91,6 +92,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/no_use_reshape_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ + graph/passes/mark_same_addr_pass.cc \ graph/common/omg_util.cc \ graph/common/bcast.cc \ graph/passes/dimension_compute_pass.cc \ @@ -145,6 +147,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/stop_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \ graph/passes/identity_pass.cc \ + graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/placeholder_with_default_pass.cc \ graph/passes/snapshot_pass.cc \ graph/passes/guarantee_const_pass.cc \ @@ -153,7 +156,9 @@ OMG_HOST_SRC_FILES := \ graph/passes/folding_pass.cc \ graph/passes/cast_translate_pass.cc \ graph/passes/prune_pass.cc \ - graph/passes/switch_op_pass.cc \ + graph/passes/merge_to_stream_merge_pass.cc \ + graph/passes/switch_to_stream_switch_pass.cc \ + graph/passes/attach_stream_label_pass.cc \ graph/passes/multi_batch_pass.cc \ graph/passes/next_iteration_pass.cc \ graph/passes/control_trigger_pass.cc \ @@ -173,7 +178,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/variable_op_pass.cc \ graph/passes/cast_remove_pass.cc \ graph/passes/transpose_transdata_pass.cc \ - graph/passes/identify_reference_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ graph/passes/flow_ctrl_pass.cc \ graph/passes/link_gen_mask_nodes_pass.cc \ @@ -199,7 +203,6 @@ OME_HOST_SRC_FILES := \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/cpu_queue_schedule.cc \ graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ graph/load/new_model_manager/data_dumper.cc \ 
graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/task_info/event_record_task_info.cc \ @@ -224,6 +227,7 @@ OME_HOST_SRC_FILES := \ single_op/task/build_task_utils.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ + single_op/task/aicpu_kernel_task_builder.cc \ single_op/single_op.cc \ single_op/single_op_model.cc \ single_op/stream_resource.cc \ @@ -355,6 +359,28 @@ LOCAL_LDFLAGS := -lrt -ldl include $(BUILD_HOST_SHARED_LIBRARY) +#compiler for host infer +include $(CLEAR_VARS) + +LOCAL_MODULE := stub/libge_compiler + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 +LOCAL_CFLAGS += -DFMK_HOST_INFER -DFMK_SUPPORT_DUMP +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) + +LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_ir_build.cc + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_SHARED_LIBRARY) + #compiler for device include $(CLEAR_VARS) diff --git a/src/ge/ge_runner.mk b/src/ge/ge_runner.mk index a9cfdd82..fe19de02 100644 --- a/src/ge/ge_runner.mk +++ b/src/ge/ge_runner.mk @@ -23,6 +23,7 @@ LIBGE_LOCAL_SRC_FILES := \ common/formats/utils/formats_trans_utils.cc \ common/fp16_t.cc \ common/ge/plugin_manager.cc\ + common/ge/op_tiling_manager.cc\ common/helper/model_cache_helper.cc \ common/profiling/profiling_manager.cc \ engine_manager/dnnengine_manager.cc \ @@ -77,7 +78,6 @@ LIBGE_LOCAL_SRC_FILES := \ graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ graph/manager/graph_context.cc \ graph/manager/graph_manager.cc \ graph/manager/graph_manager_utils.cc \ @@ -99,6 +99,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/aicpu_constant_folding_pass.cc \ graph/passes/assert_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ + graph/passes/mark_same_addr_pass.cc \ graph/partition/dynamic_shape_partition.cc \ graph/passes/base_pass.cc \ graph/passes/cast_remove_pass.cc \ @@ -158,8 +159,8 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/get_original_format_pass.cc \ graph/passes/guarantee_const_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ - graph/passes/identify_reference_pass.cc \ graph/passes/identity_pass.cc \ + graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/infershape_pass.cc \ graph/passes/isolated_op_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ @@ -191,7 +192,9 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/data_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_logic_remove_pass.cc \ - graph/passes/switch_op_pass.cc \ + graph/passes/merge_to_stream_merge_pass.cc \ + graph/passes/switch_to_stream_switch_pass.cc \ + graph/passes/attach_stream_label_pass.cc \ graph/passes/switch_dead_branch_elimination.cc \ graph/passes/replace_transshape_pass.cc \ graph/passes/transop_breadth_fusion_pass.cc \ @@ -230,6 +233,7 @@ LIBGE_LOCAL_SRC_FILES := \ single_op/task/op_task.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ + single_op/task/aicpu_kernel_task_builder.cc \ hybrid/common/tensor_value.cc \ hybrid/common/npu_memory_allocator.cc \ hybrid/executor/rt_callback_manager.cc \ @@ -239,12 +243,15 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/executor/hybrid_model_executor.cc \ hybrid/executor/hybrid_model_async_executor.cc \ hybrid/executor/hybrid_execution_context.cc \ + hybrid/executor/subgraph_context.cc \ + hybrid/executor/subgraph_executor.cc \ 
hybrid/executor/worker/task_compile_engine.cc \ hybrid/executor/worker/shape_inference_engine.cc \ hybrid/executor/worker/execution_engine.cc \ hybrid/model/hybrid_model.cc \ hybrid/model/hybrid_model_builder.cc \ hybrid/model/node_item.cc \ + hybrid/model/graph_item.cc \ hybrid/node_executor/aicore/aicore_node_executor.cc \ hybrid/node_executor/aicore/aicore_op_task.cc \ hybrid/node_executor/aicore/aicore_task_builder.cc \ @@ -253,6 +260,9 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/node_executor/aicpu/aicpu_node_executor.cc \ hybrid/node_executor/compiledsubgraph/known_node_executor.cc \ hybrid/node_executor/hostcpu/ge_local_node_executor.cc \ + hybrid/node_executor/controlop/control_op_executor.cc \ + hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \ + hybrid/node_executor/hccl/hccl_node_executor.cc \ hybrid/node_executor/node_executor.cc \ hybrid/node_executor/task_context.cc \ hybrid/hybrid_davinci_model.cc \ @@ -338,6 +348,28 @@ LOCAL_SHARED_LIBRARIES += \ include $(BUILD_HOST_SHARED_LIBRARY) +#compiler for GeRunner +include $(CLEAR_VARS) + +LOCAL_MODULE := stub/libge_runner + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + + +LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) + +LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_SHARED_LIBRARY) # add engine_conf.json to host include $(CLEAR_VARS) @@ -407,6 +439,7 @@ LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD LOCAL_CFLAGS += -g -O0 LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) + LOCAL_SRC_FILES := $(LIBGE_LOCAL_SRC_FILES) LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) diff --git a/src/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc index 59952e39..b6e43dd5 100644 --- a/src/ge/ge_runtime/model_runner.cc +++ b/src/ge/ge_runtime/model_runner.cc @@ -49,6 +49,15 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint return true; } +bool ModelRunner::LoadModelComplete(uint32_t model_id) { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + return false; + } + return model_iter->second->LoadComplete(); +} + const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const { auto model_iter = runtime_models_.find(model_id); if (model_iter == runtime_models_.end()) { @@ -60,6 +69,28 @@ const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const return model_iter->second->GetTaskIdList(); } +const std::vector &ModelRunner::GetStreamIdList(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + static const std::vector empty_ret; + return empty_ret; + } + + return model_iter->second->GetStreamIdList(); +} + +const std::map> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGW("Model id %u not found.", model_id); + static const std::map> empty_ret; + return empty_ret; + } + + return model_iter->second->GetRuntimeInfoMap(); +} + bool ModelRunner::UnloadModel(uint32_t model_id) { auto iter = runtime_models_.find(model_id); if (iter != 
runtime_models_.end()) { diff --git a/src/ge/ge_runtime/output.cc b/src/ge/ge_runtime/output.cc index 90c33bb4..5153f688 100644 --- a/src/ge/ge_runtime/output.cc +++ b/src/ge/ge_runtime/output.cc @@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde DataBuffer data_buf = rslt->blobs[data_begin + data_count]; bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); if (!ret) { - GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); + GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); return ret; } data_index = data_begin + data_count; diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index c89ced91..bdf8f2a6 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -28,7 +28,6 @@ namespace ge { namespace model_runner { - RuntimeModel::~RuntimeModel() { GELOGI("RuntimeModel destructor start"); @@ -116,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { return true; } -bool RuntimeModel::InitLabel(uint32_t batch_num) { - GELOGI("batch number:%u.", batch_num); - for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { - rtLabel_t rt_lLabel = nullptr; - rtError_t rt_ret = rtLabelCreate(&rt_lLabel); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); - return false; +bool RuntimeModel::InitLabel(std::shared_ptr &davinci_model) { + GELOGI("batch number:%u.", davinci_model->GetBatchNum()); + label_list_.resize(davinci_model->GetBatchNum()); + for (auto &task_info : davinci_model->GetTaskInfoList()) { + if (task_info == nullptr) { + GELOGE(PARAM_INVALID, "task_info is null."); + continue; + } + + if (task_info->type() != TaskInfoType::LABEL_SET) { + continue; } + auto label_set_task_info = std::static_pointer_cast(task_info); - if (rt_lLabel == nullptr) { - GELOGE(RT_FAILED, "rtLabel is nullptr!"); + if (label_set_task_info->stream_id() >= stream_list_.size()) { + GELOGE(PARAM_INVALID, "Invalid stream id."); return false; } - label_list_.emplace_back(rt_lLabel); + rtLabel_t rt_label = nullptr; + rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); + return false; + } + label_list_[label_set_task_info->label_id()] = rt_label; } + return true; } @@ -164,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr &davinci_model) { return false; } - if (!InitLabel(davinci_model->GetBatchNum())) { + if (!InitLabel(davinci_model)) { return false; } @@ -209,20 +219,41 @@ bool RuntimeModel::LoadTask() { return false; } task_id_list_.push_back(task_id); + stream_id_list_.push_back(stream_id); + if (task->Args() != nullptr) { + std::shared_ptr runtime_tuple = nullptr; + GE_MAKE_SHARED(runtime_tuple = std::make_shared(task_id, stream_id, task->Args()), return false); + auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); + if (!emplace_ret.second) { + GELOGW("Task name exist:%s", task->task_name().c_str()); + } + } } if (task_list_.empty()) { GELOGE(FAILED, "Task list is empty"); return false; } - GELOGI("Distribute task succ."); - auto rt_ret = rtModelLoadComplete(rt_model_handle_); + GELOGI("LoadTask succ."); + return true; +} + +bool RuntimeModel::LoadComplete() { + uint32_t task_id = 0; + uint32_t stream_id = 0; + auto rt_ret = 
rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); + return false; + } + task_id_list_.push_back(task_id); + stream_id_list_.push_back(stream_id); + + rt_ret = rtModelLoadComplete(rt_model_handle_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); return false; } - - GELOGI("LoadTask succ."); return true; } @@ -270,10 +301,14 @@ bool RuntimeModel::Run() { return false; } - GELOGI("Run rtModelExecute success"); + GELOGI("Run rtModelExecute success, ret = 0x%X", ret); ret = rtStreamSynchronize(rt_model_stream_); if (ret != RT_ERROR_NONE) { + if (ret == RT_ERROR_END_OF_SEQUENCE) { + GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); + return true; + } GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); return false; } @@ -433,7 +468,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model } if (constant->output_tensors[0].size < constant->weight_data.size()) { - GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, + GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size, constant->weight_data.size()); return false; } @@ -448,11 +483,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. - int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); - if (elem_num == 0 && constant->weight_tensors[0].size == 0) { - elem_num = 1; - } - + int64_t elem_num = + (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize();
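// Editorial sketch, not part of the patch: the ternary above encodes the convention
// spelled out in the comment block. GetShapeSize() returns 0 for a rank-0 (scalar)
// tensor that still holds one element, and unknown shapes (also 0) cannot reach here.
static int64_t ElemNumFromShapeSize(int64_t shape_size) {
  return (shape_size == 0) ? 1 : shape_size;  // 0 means scalar, i.e. exactly one element
}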
if (constant->weight_data.size() < sizeof(uint64_t)) { GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); return false; @@ -495,5 +527,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp const std::vector &RuntimeModel::GetTaskIdList() const { return task_id_list_; } +const std::vector &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h index e8ff4057..67535296 100644 --- a/src/ge/ge_runtime/runtime_model.h +++ b/src/ge/ge_runtime/runtime_model.h @@ -27,7 +27,7 @@ namespace ge { namespace model_runner { - +using RuntimeInfo = std::tuple; class Task; class RuntimeModel { public: @@ -35,7 +35,10 @@ class RuntimeModel { ~RuntimeModel(); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); + bool LoadComplete(); const std::vector &GetTaskIdList() const; + const std::vector &GetStreamIdList() const; + const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, @@ -48,7 +51,7 @@ class RuntimeModel { bool LoadTask(); bool InitStream(std::shared_ptr &davinci_model); bool InitEvent(uint32_t event_num); - bool InitLabel(uint32_t batch_num); + bool InitLabel(std::shared_ptr &davinci_model); bool InitDataInfo(std::shared_ptr &davinci_model); bool InitOutputInfo(std::shared_ptr &davinci_model); bool InitConstantInfo(std::shared_ptr &davinci_model); @@ -77,6 +80,8 @@ class RuntimeModel { std::vector> constant_info_list_{}; std::vector task_id_list_{}; + std::vector stream_id_list_{}; + std::map> runtime_info_map_; }; } // namespace model_runner diff --git a/src/ge/ge_runtime/task/aicpu_task.cc b/src/ge/ge_runtime/task/aicpu_task.cc index 4cb71866..9b126ec0 100644 --- a/src/ge/ge_runtime/task/aicpu_task.cc +++ b/src/ge/ge_runtime/task/aicpu_task.cc @@ -85,11 +85,15 @@ bool AicpuTask::Distribute() { return false; } - GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size, - io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data()); - rt_ret = rtCpuKernelLaunch(reinterpret_cast(task_info_->so_name().data()), - reinterpret_cast(task_info_->kernel_name().data()), 1, args_, args_size, - nullptr, stream_); + input_output_addr_ = reinterpret_cast(reinterpret_cast(args_) + io_addr_offset); + + auto dump_flag = task_info_->dump_flag() ?
RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; + GELOGI( + "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", + args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); + rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(task_info_->so_name().data()), + reinterpret_cast(task_info_->kernel_name().data()), 1, args_, + args_size, nullptr, stream_, dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; diff --git a/src/ge/ge_runtime/task/aicpu_task.h b/src/ge/ge_runtime/task/aicpu_task.h index f5cdc617..cc21af8a 100644 --- a/src/ge/ge_runtime/task/aicpu_task.h +++ b/src/ge/ge_runtime/task/aicpu_task.h @@ -18,6 +18,7 @@ #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ #include +#include #include "ge_runtime/task/task.h" namespace ge { @@ -30,12 +31,17 @@ class AicpuTask : public TaskRepeater { bool Distribute() override; + void *Args() override { return input_output_addr_; } + + std::string task_name() const override { return task_info_->op_name(); } + private: static void ReleaseRtMem(void **ptr) noexcept; std::shared_ptr task_info_; void *stream_; void *args_; + void *input_output_addr_; }; } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/task/hccl_task.cc b/src/ge/ge_runtime/task/hccl_task.cc index 54ae3bf3..3d5f8504 100644 --- a/src/ge/ge_runtime/task/hccl_task.cc +++ b/src/ge/ge_runtime/task/hccl_task.cc @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - (void)rtStreamDestroy(stream); return false; } diff --git a/src/ge/ge_runtime/task/label_goto_task.cc b/src/ge/ge_runtime/task/label_goto_task.cc new file mode 100644 index 00000000..d357accb --- /dev/null +++ b/src/ge/ge_runtime/task/label_goto_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "ge_runtime/task/label_goto_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelGotoTask::~LabelGotoTask() {} + +bool LabelGotoTask::Distribute() { + GELOGI("LabelGotoTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelGotoEx(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_goto_task.h b/src/ge/ge_runtime/task/label_goto_task.h new file mode 100644 index 00000000..4fd6d1bc --- /dev/null +++ b/src/ge/ge_runtime/task/label_goto_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelGotoTask : public TaskRepeater { + public: + LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelGotoTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_set_task.cc b/src/ge/ge_runtime/task/label_set_task.cc new file mode 100644 index 00000000..3ab5802c --- /dev/null +++ b/src/ge/ge_runtime/task/label_set_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_set_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelSetTask::~LabelSetTask() {} + +bool LabelSetTask::Distribute() { + GELOGI("LabelSetTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelSet(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_set_task.h b/src/ge/ge_runtime/task/label_set_task.h new file mode 100644 index 00000000..70bf1584 --- /dev/null +++ b/src/ge/ge_runtime/task/label_set_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSetTask : public TaskRepeater { + public: + LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSetTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_switch_task.cc b/src/ge/ge_runtime/task/label_switch_task.cc new file mode 100644 index 00000000..a3c2d41a --- /dev/null +++ b/src/ge/ge_runtime/task/label_switch_task.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_switch_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + all_label_resource_(), + label_info_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + + all_label_resource_ = model_context.label_list(); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + if (stream_id >= stream_list.size()) { + GELOGW("Stream id invalid."); + return; + } + stream_ = stream_list[stream_id]; +} + +LabelSwitchTask::~LabelSwitchTask() { + if (label_info_ != nullptr) { + rtError_t rt_ret = rtFree(label_info_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! 
ret: 0x%X.", rt_ret); + } + label_info_ = nullptr; + } +} + +bool LabelSwitchTask::Distribute() { + GELOGI("LabelSwitchTask Distribute start."); + if (!CheckParamValid()) { + return false; + } + + const std::vector &label_index_list = task_info_->label_list(); + std::vector label_list(task_info_->label_size(), nullptr); + + for (size_t i = 0; i < task_info_->label_size(); ++i) { + uint32_t label_index = label_index_list[i]; + if (label_index >= all_label_resource_.size()) { + GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, + all_label_resource_.size()); + return false; + } + label_list[i] = all_label_resource_[label_index]; + GELOGI("Case %zu: label id %zu.", i, label_index); + } + + uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); + rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +bool LabelSwitchTask::CheckParamValid() { + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + + if (task_info_->label_list().empty()) { + GELOGE(PARAM_INVALID, "label_list is empty."); + return false; + } + + if (task_info_->label_size() != task_info_->label_list().size()) { + GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), + task_info_->label_size()); + return false; + } + + if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { + GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); + return false; + } + + if (label_info_ != nullptr) { + GELOGE(PARAM_INVALID, "label_info_ has dirty data."); + return false; + } + + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_switch_task.h b/src/ge/ge_runtime/task/label_switch_task.h new file mode 100644 index 00000000..463faa31 --- /dev/null +++ b/src/ge/ge_runtime/task/label_switch_task.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSwitchTask : public TaskRepeater { + public: + LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSwitchTask() override; + + bool Distribute() override; + + private: + bool CheckParamValid(); + + std::shared_ptr task_info_; + void *stream_; + std::vector all_label_resource_; + void *label_info_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ diff --git a/src/ge/ge_runtime/task/stream_switch_task.cc b/src/ge/ge_runtime/task/stream_switch_task.cc index 91141139..2adcb4bd 100644 --- a/src/ge/ge_runtime/task/stream_switch_task.cc +++ b/src/ge/ge_runtime/task/stream_switch_task.cc @@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() { } if (static_cast(task_info_->true_stream_id()) >= stream_list_.size()) { - GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), + GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(), stream_list_.size()); return false; } diff --git a/src/ge/ge_runtime/task/task.h b/src/ge/ge_runtime/task/task.h index 7c748a7d..6c4df248 100644 --- a/src/ge/ge_runtime/task/task.h +++ b/src/ge/ge_runtime/task/task.h @@ -18,7 +18,9 @@ #define GE_GE_RUNTIME_TASK_TASK_H_ #include +#include #include +#include #include "runtime/rt_model.h" #include "ge_runtime/model_context.h" #include "ge_runtime/task_info.h" @@ -32,6 +34,10 @@ class Task { virtual ~Task() {} virtual bool Distribute() = 0; + + virtual void *Args() { return nullptr; } + + virtual std::string task_name() const { return ""; } }; template diff --git a/src/ge/ge_runtime/task/tbe_task.cc b/src/ge/ge_runtime/task/tbe_task.cc index 8a3c36a4..e7025ae8 100644 --- a/src/ge/ge_runtime/task/tbe_task.cc +++ b/src/ge/ge_runtime/task/tbe_task.cc @@ -95,15 +95,14 @@ bool TbeTask::Distribute() { return false; } - GELOGI("InitTbeTask end."); GELOGI("DistributeTbeTask start."); - rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_); + auto dump_flag = task_info_->dump_flag() ? 
RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; + rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); return false; } - - GELOGI("DistributeTbeTask end."); + GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); return true; } diff --git a/src/ge/ge_runtime/task/tbe_task.h b/src/ge/ge_runtime/task/tbe_task.h index 994ba5e2..a8ce6268 100644 --- a/src/ge/ge_runtime/task/tbe_task.h +++ b/src/ge/ge_runtime/task/tbe_task.h @@ -30,6 +30,10 @@ class TbeTask : public TaskRepeater { bool Distribute() override; + void *Args() override { return args_; } + + std::string task_name() const override { return task_info_->op_name(); } + private: std::shared_ptr task_info_; void *stream_; diff --git a/src/ge/ge_train.mk b/src/ge/ge_train.mk deleted file mode 100644 index 767ce86b..00000000 --- a/src/ge/ge_train.mk +++ /dev/null @@ -1,333 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -COMMON_LOCAL_SRC_FILES := \ - proto/fusion_model.proto \ - proto/optimizer_priority.proto \ - session/inner_session.cc \ - session/session_manager.cc \ - common/ge/plugin_manager.cc\ - common/fp16_t.cc \ - common/formats/utils/formats_trans_utils.cc \ - common/formats/format_transfers/datatype_transfer.cc \ - common/formats/format_transfers/format_transfer_transpose.cc \ - common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \ - common/formats/format_transfers/format_transfer_fractal_z.cc \ - common/formats/format_transfers/format_transfer_fractal_nz.cc \ - common/formats/format_transfers/format_transfer_fractal_zz.cc \ - common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ - common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc \ - common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc \ - common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc \ - common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc \ - common/formats/format_transfers/format_transfer_fracz_nchw.cc \ - common/formats/format_transfers/format_transfer_fracz_nhwc.cc \ - common/formats/format_transfers/format_transfer_fracz_hwcn.cc \ - common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ - common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ - common/formats/formats.cc \ - init/gelib.cc \ - engine_manager/dnnengine_manager.cc \ - opskernel_manager/ops_kernel_manager.cc \ - graph/manager/graph_manager.cc \ - graph/manager/graph_manager_utils.cc \ - graph/manager/graph_context.cc \ - graph/preprocess/graph_preprocess.cc \ - graph/preprocess/multi_batch_copy_graph.cc \ - graph/execute/graph_execute.cc \ - graph/load/graph_loader.cc \ - graph/optimize/graph_optimize.cc \ - graph/passes/folding_pass.cc \ - graph/optimize/summary_optimize.cc \ - graph/build/graph_builder.cc \ - graph/partition/engine_place.cc \ - graph/partition/graph_partition.cc \ - graph/partition/dynamic_shape_partition.cc \ - generator/ge_generator.cc \ - generator/generator_api.cc \ - common/profiling/profiling_manager.cc \ - ge_local_engine/engine/host_cpu_engine.cc \ - common/helper/model_cache_helper.cc \ - -OMG_HOST_SRC_FILES := \ - model/ge_model.cc \ - model/ge_root_model.cc \ - graph/common/transop_util.cc \ - graph/manager/graph_var_manager.cc \ - graph/manager/trans_var_data_utils.cc \ - omm/csa_interact.cc \ - graph/passes/pass_manager.cc \ - graph/passes/pass_utils.cc \ - 
graph/passes/base_pass.cc \ - graph/passes/resource_pair_add_control_pass.cc \ - graph/passes/resource_pair_remove_control_pass.cc \ - graph/passes/constant_folding_pass.cc \ - graph/passes/aicpu_constant_folding_pass.cc \ - graph/passes/reshape_remove_pass.cc \ - graph/passes/reshape_recovery_pass.cc \ - graph/passes/transop_breadth_fusion_pass.cc \ - graph/passes/transop_depth_fusion_pass.cc \ - graph/passes/same_transdata_breadth_fusion_pass.cc \ - graph/passes/transop_without_reshape_fusion_pass.cc \ - graph/passes/compile_nodes_pass.cc \ - graph/passes/transop_nearby_allreduce_fusion_pass.cc \ - graph/passes/variable_prepare_op_pass.cc \ - graph/passes/variable_ref_delete_op_pass.cc \ - graph/passes/variable_ref_useless_control_out_delete_pass.cc \ - graph/passes/variable_op_pass.cc \ - graph/passes/cast_remove_pass.cc \ - graph/passes/replace_transshape_pass.cc \ - graph/passes/transpose_transdata_pass.cc \ - graph/passes/identify_reference_pass.cc \ - graph/passes/variable_format_pass.cc \ - graph/passes/subgraph_pass.cc \ - graph/passes/data_pass.cc \ - graph/passes/net_output_pass.cc \ - graph/passes/constant_fuse_same_pass.cc \ - graph/passes/print_op_pass.cc \ - graph/passes/no_use_reshape_remove_pass.cc \ - graph/passes/iterator_op_pass.cc \ - graph/passes/atomic_addr_clean_pass.cc \ - graph/optimize/optimizer/allreduce_fusion_pass.cc \ - graph/common/omg_util.cc \ - graph/common/bcast.cc \ - graph/passes/dimension_compute_pass.cc \ - graph/passes/dimension_adjust_pass.cc \ - graph/passes/get_original_format_pass.cc \ - graph/passes/shape_operate_op_remove_pass.cc \ - graph/passes/unused_op_remove_pass.cc \ - graph/passes/assert_pass.cc \ - graph/passes/dropout_pass.cc \ - graph/passes/infershape_pass.cc \ - graph/passes/unused_const_pass.cc \ - graph/passes/isolated_op_remove_pass.cc \ - graph/passes/permute_pass.cc \ - graph/passes/ctrl_edge_transfer_pass.cc \ - host_kernels/broadcast_gradient_args_kernel.cc \ - host_kernels/greater_kernel.cc \ - host_kernels/gather_v2_kernel.cc \ - host_kernels/maximum_kernel.cc \ - host_kernels/floormod_kernel.cc \ - host_kernels/floordiv_kernel.cc \ - host_kernels/range_kernel.cc \ - host_kernels/shape_kernel.cc \ - host_kernels/size_kernel.cc \ - host_kernels/shape_n_kernel.cc \ - host_kernels/rank_kernel.cc \ - host_kernels/broadcast_args_kernel.cc \ - host_kernels/fill_kernel.cc \ - host_kernels/empty_kernel.cc \ - host_kernels/expanddims_kernel.cc \ - host_kernels/reshape_kernel.cc \ - host_kernels/squeeze_kernel.cc \ - host_kernels/kernel_utils.cc \ - host_kernels/cast_kernel.cc \ - host_kernels/transdata_kernel.cc \ - host_kernels/transpose_kernel.cc \ - host_kernels/permute_kernel.cc \ - host_kernels/pack_kernel.cc \ - host_kernels/concat_v2_kernel.cc \ - host_kernels/concat_offset_kernel.cc \ - host_kernels/strided_slice_kernel.cc \ - host_kernels/ssd_prior_box_kernel.cc \ - host_kernels/add_kernel.cc \ - host_kernels/unpack_kernel.cc \ - host_kernels/sub_kernel.cc \ - host_kernels/mul_kernel.cc \ - host_kernels/reduce_prod_kernel.cc \ - host_kernels/rsqrt_kernel.cc \ - host_kernels/slice_kernel.cc \ - host_kernels/slice_d_kernel.cc \ - host_kernels/dynamic_stitch_kernel.cc \ - graph/passes/stop_gradient_pass.cc \ - graph/passes/prevent_gradient_pass.cc \ - graph/passes/identity_pass.cc \ - graph/passes/placeholder_with_default_pass.cc \ - graph/passes/snapshot_pass.cc \ - graph/passes/guarantee_const_pass.cc \ - graph/passes/var_is_initialized_op_pass.cc \ - graph/passes/parallel_concat_start_op_pass.cc \ - 
graph/passes/cast_translate_pass.cc \ - graph/passes/addn_pass.cc \ - graph/passes/common_subexpression_elimination_pass.cc \ - graph/passes/transop_symmetry_elimination_pass.cc \ - graph/passes/save_pass.cc \ - graph/passes/switch_dead_branch_elimination.cc \ - graph/passes/merge_pass.cc \ - graph/passes/prune_pass.cc \ - graph/passes/flow_ctrl_pass.cc \ - graph/passes/control_trigger_pass.cc \ - graph/passes/switch_data_edges_bypass.cc \ - graph/passes/switch_op_pass.cc \ - graph/passes/multi_batch_pass.cc \ - graph/passes/switch_logic_remove_pass.cc \ - graph/passes/next_iteration_pass.cc \ - graph/passes/cond_pass.cc \ - graph/passes/cond_remove_pass.cc \ - graph/passes/for_pass.cc \ - graph/passes/enter_pass.cc \ - graph/passes/hccl_memcpy_pass.cc \ - graph/passes/link_gen_mask_nodes_pass.cc \ - graph/passes/replace_with_empty_const_pass.cc \ - graph/passes/hccl_group_pass.cc \ - -OME_SRC_FILES := \ - graph/manager/graph_mem_allocator.cc \ - graph/manager/graph_caching_allocator.cc \ - graph/manager/model_manager/event_manager.cc \ - graph/manager/util/debug.cc \ - graph/manager/util/rt_context_util.cc \ - graph/manager/util/variable_accelerate_ctrl.cc \ - graph/manager/util/hcom_util.cc \ - graph/load/new_model_manager/model_manager.cc \ - graph/load/new_model_manager/data_inputer.cc \ - graph/load/new_model_manager/davinci_model.cc \ - graph/load/new_model_manager/davinci_model_parser.cc \ - graph/load/new_model_manager/model_utils.cc \ - graph/load/new_model_manager/tbe_handle_store.cc \ - graph/load/new_model_manager/cpu_queue_schedule.cc \ - graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ - graph/load/new_model_manager/data_dumper.cc \ - graph/load/new_model_manager/task_info/task_info.cc \ - graph/load/new_model_manager/task_info/event_record_task_info.cc \ - graph/load/new_model_manager/task_info/event_wait_task_info.cc \ - graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ - graph/load/new_model_manager/task_info/fusion_stop_task_info.cc \ - graph/load/new_model_manager/task_info/hccl_task_info.cc \ - graph/load/new_model_manager/task_info/kernel_ex_task_info.cc \ - graph/load/new_model_manager/task_info/kernel_task_info.cc \ - graph/load/new_model_manager/task_info/label_set_task_info.cc \ - graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc \ - graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc \ - graph/load/new_model_manager/task_info/memcpy_async_task_info.cc \ - graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc \ - graph/load/new_model_manager/task_info/profiler_trace_task_info.cc \ - graph/load/new_model_manager/task_info/stream_active_task_info.cc \ - graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ - graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ - graph/load/new_model_manager/task_info/end_graph_task_info.cc \ - graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ - graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ - single_op/task/op_task.cc \ - single_op/task/build_task_utils.cc \ - single_op/task/tbe_task_builder.cc \ - single_op/task/aicpu_task_builder.cc \ - single_op/single_op.cc \ - single_op/single_op_model.cc \ - single_op/stream_resource.cc \ - single_op/single_op_manager.cc \ - hybrid/hybrid_davinci_model_stub.cc \ - - -COMMON_LOCAL_C_INCLUDES := \ - proto/om.proto \ - proto/task.proto \ - proto/insert_op.proto \ - proto/ge_ir.proto \ - proto/fwk_adapter.proto \ - 
proto/op_mapping_info.proto \ - proto/tensorflow/attr_value.proto \ - proto/tensorflow/function.proto \ - proto/tensorflow/graph.proto \ - proto/tensorflow/node_def.proto \ - proto/tensorflow/op_def.proto \ - proto/tensorflow/resource_handle.proto \ - proto/tensorflow/tensor.proto \ - proto/tensorflow/tensor_shape.proto \ - proto/tensorflow/types.proto \ - proto/tensorflow/versions.proto \ - $(LOCAL_PATH) ./ \ - $(TOPDIR)inc \ - $(TOPDIR)inc/external \ - $(TOPDIR)inc/external/graph \ - $(TOPDIR)inc/framework \ - $(TOPDIR)inc/framework/common \ - $(TOPDIR)inc/runtime \ - $(TOPDIR)libc_sec/include \ - $(TOPDIR)ops/built-in/op_proto/inc \ - third_party/json/include \ - third_party/protobuf/include \ - third_party/opencv/include \ - -NEW_OMG_HOST_SRC_FILES := \ - graph/preprocess/insert_op/util_insert_aipp_op.cc \ - graph/preprocess/insert_op/ge_aipp_op.cc \ - graph/build/model_builder.cc \ - graph/build/task_generator.cc \ - graph/build/stream_allocator.cc \ - graph/build/logical_stream_allocator.cc \ - graph/build/stream_graph_optimizer.cc \ - graph/build/run_context.cc \ - graph/build/label_allocator.cc \ - graph/label/label_maker.cc \ - graph/label/if_label_maker.cc \ - graph/label/case_label_maker.cc \ - graph/label/while_label_maker.cc \ - graph/label/partitioned_call_label_maker.cc \ - - - -#compiler for host train -include $(CLEAR_VARS) - -LOCAL_MODULE := libge_train - -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 -LOCAL_CFLAGS += -DDAVINCI_CLOUD -DDAVINCI_TRAIN -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -LOCAL_CFLAGS += -DFMK_SUPPORT_DEBUG -ifeq ($(DEBUG), 1) -LOCAL_CFLAGS += -g -O0 -endif - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) - -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) -LOCAL_SRC_FILES += $(OMG_HOST_SRC_FILES) -LOCAL_SRC_FILES += $(OME_SRC_FILES) -LOCAL_SRC_FILES += $(NEW_OMG_HOST_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := libge_memory \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - libmmpa \ - libgraph \ - libregister \ - libge_common \ - libhccl \ - libmsprof \ - - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_SHARED_LIBRARIES += \ - libruntime \ - libresource \ - -include $(BUILD_HOST_SHARED_LIBRARY) - -# add engine_conf.json to host -include $(CLEAR_VARS) - -LOCAL_MODULE := engine_conf.json - -LOCAL_SRC_FILES := engine_manager/engine_conf.json - -LOCAL_MODULE_CLASS := ETC - -LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/engine_conf.json -include $(BUILD_HOST_PREBUILT) - -# add optimizer_priority.pbtxt to host -include $(CLEAR_VARS) - -LOCAL_MODULE := optimizer_priority.pbtxt - -LOCAL_SRC_FILES := opskernel_manager/optimizer_priority.pbtxt - -LOCAL_MODULE_CLASS := ETC - -LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/optimizer_priority.pbtxt -include $(BUILD_HOST_PREBUILT) diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index b01f7591..4869eb40 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -207,6 +207,13 @@ class GeGenerator::Impl { GraphManager graph_manager_; SaveParam save_param_; bool is_offline_ = true; + + private: + static std::string Trim(const std::string &str); + bool ParseVersion(const std::string &line, std::string &version); + bool GetVersionFromPath(const std::string &file_path, std::string &version); + bool SetAtcVersionInfo(AttrHolder &obj); + bool SetOppVersionInfo(AttrHolder &obj); }; Status GeGenerator::Initialize(const map &options) { @@ -288,6 +295,124 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { return SUCCESS; 
} +// Remove the space and tab before and after the string +std::string GeGenerator::Impl::Trim(const std::string &str) { + if (str.empty()) { + return str; + } + + std::string::size_type start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return str; + } + + std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; + return str.substr(start, end - start); +} + +// Parse a version line of the form "Version=<value>" +bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) { + std::string flag = "Version="; + std::string temp = Trim(line); + + if (temp.empty()) { + GELOGW("line is empty."); + return false; + } + + std::string::size_type pos = temp.find(flag); + if (pos == std::string::npos) { + GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str()); + return false; + } + + if (temp.size() == flag.size()) { + GELOGW("version information is empty. %s", line.c_str()); + return false; + } + + version = temp.substr(pos + flag.size()); + GELOGI("Version=%s", version.c_str()); + + return true; +} + +bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) { + // Normalize the path + string resolved_file_path = RealPath(file_path.c_str()); + if (resolved_file_path.empty()) { + GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); + return false; + } + std::ifstream fs(resolved_file_path, std::ifstream::in); + if (!fs.is_open()) { + GELOGW("Open %s failed.", file_path.c_str()); + return false; + } + + std::string line; + if (getline(fs, line)) { + if (!ParseVersion(line, version)) { + GELOGW("Parse version failed. content is [%s].", line.c_str()); + fs.close(); + return false; + } + } else { + GELOGW("No version information found in the file path:%s", file_path.c_str()); + fs.close(); + return false; + } + + fs.close(); // close the file + return true; +} + +// Set the atc package version information in the model +bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) { + std::string path_base = ge::GELib::GetPath(); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + + std::string version_path = path_base + "version.info"; + GELOGI("version_path is %s", version_path.c_str()); + std::string version; + if (!GetVersionFromPath(version_path, version)) { + GELOGW("Get atc version information failed!"); + return false; + } + // set version info + if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) { + GELOGW("Ge model set atc version failed!"); + return false; + } + GELOGI("Ge model set atc version information success."); + return true; +} + +// Set the opp package version information in the model +bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env == nullptr) { + GELOGW("Get environment variable ASCEND_OPP_PATH failed!"); + return false; + } + std::string version_path = path_env; + version_path += "/version.info"; + GELOGI("version_path is %s", version_path.c_str()); + std::string version; + if (!GetVersionFromPath(version_path, version)) { + GELOGW("Get opp version information failed!"); + return false; + } + // set version info + if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) { + GELOGW("Ge model set opp version failed!"); + return false; + } + GELOGI("Ge model set opp version information success."); + return true; +} +
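A minimal standalone sketch of the parsing contract these helpers implement, assuming a one-line version.info of the form "Version=<value>"; the file name, path, and version value below are illustrative only and not taken from this patch:

#include <fstream>
#include <iostream>
#include <string>

// Hypothetical driver mirroring Trim + ParseVersion above: read the first line
// of version.info, trim surrounding whitespace, strip the "Version=" prefix.
int main() {
  std::ifstream fs("version.info");  // assumed content: Version=1.76.22.3.220
  std::string line;
  if (!std::getline(fs, line)) {
    return 1;  // no version information found
  }
  const std::string ws = " \t\r\n";
  const std::string flag = "Version=";
  std::string::size_type start = line.find_first_not_of(ws);
  if (start == std::string::npos) {
    return 1;  // line holds only whitespace
  }
  std::string::size_type end = line.find_last_not_of(ws) + 1;
  std::string trimmed = line.substr(start, end - start);
  std::string::size_type pos = trimmed.find(flag);
  if (pos == std::string::npos || trimmed.size() == flag.size()) {
    return 1;  // missing "Version=" marker or empty value
  }
  std::cout << trimmed.substr(pos + flag.size()) << std::endl;  // 1.76.22.3.220
  return 0;
}

Status GeGenerator::GenerateModel(const Graph &graph, const string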
&file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { rtContext_t ctx = nullptr; @@ -315,6 +440,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr string model_name = ""; Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name); if (name_ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); GELOGE(FAILED, "Get model_name failed. Param --output is invalid"); return PARAM_INVALID; } @@ -464,6 +590,14 @@ Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, c } Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) { + // set atc version + if (!SetAtcVersionInfo(*(model.get()))) { + GELOGW("SetPackageVersionInfo of atc failed!"); + } + // set opp version + if (!SetOppVersionInfo(*(model.get()))) { + GELOGW("SetPackageVersionInfo of ops failed!"); + } ModelHelper model_helper; model_helper.SetSaveMode(is_offline_); Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); @@ -526,5 +660,4 @@ Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &g return SUCCESS; } - } // namespace ge diff --git a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index f2fa4ada..abcc253e 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -18,11 +18,14 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/opskernel/ops_kernel_info_types.h" +#include "graph/build/logical_stream_allocator.h" #include "graph/build/run_context.h" #include "graph/build/stream_graph_optimizer.h" #include "graph/manager/graph_var_manager.h" +#include "graph/passes/mark_same_addr_pass.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "init/gelib.h" #include "model/ge_model.h" @@ -34,6 +37,21 @@ const int32_t kInvalidPerfLevel = -1; namespace ge { GraphBuilder::GraphBuilder() : build_mode_(BuildMode::GEN_TASK_WITH_FUSION), hcom_parallel_(false) {} +Status GraphBuilder::MarkGraph(ComputeGraphPtr &graph) { + GE_CHECK_NOTNULL(graph); + bool is_unknown_shape = false; + for (const auto &node : graph->GetDirectNode()) { + GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), + "Get node[%s] shape status failed!", node->GetName().c_str()); + if (is_unknown_shape) { + break; + } + } + graph->SetGraphUnknownFlag(is_unknown_shape); + GELOGD("mark graph [%s] unknown status success! 
value is %d", graph->GetName().c_str(), is_unknown_shape); + return SUCCESS; +} + void GraphBuilder::SetOptions(const ge::GraphManagerOptions &options) { stream_max_parallel_num_ = options.stream_max_parallel_num; hcom_parallel_ = options.hcom_parallel; @@ -54,7 +72,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { return GE_CLI_GE_NOT_INITIALIZED; } - for (const auto &node_ptr : graph->GetAllNodes()) { + for (const auto &node_ptr : graph->GetNodes(graph->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); if (kernel_lib_name.empty()) { @@ -102,11 +120,7 @@ Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph graph->GetName().c_str()); auto parent_op_desc = parent_node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(parent_op_desc); - bool is_unknown_shape = false; - if (!AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { - GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", parent_op_desc->GetName().c_str()); - return PARAM_INVALID; - } + bool is_unknown_shape = graph->GetGraphUnknownFlag(); if (is_unknown_shape) { GELOGI("Current graph[%s] is unknown, no need to update parent node[%s] output size.", graph->GetName().c_str(), parent_node_ptr->GetName().c_str()); @@ -121,14 +135,14 @@ Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph for (const auto &in_data_anchor : node_ptr->GetAllInDataAnchors()) { auto index = in_data_anchor->GetIdx(); ge::GeTensorDesc desc_temp = op_desc->GetInputDesc(index); - int64_t size = 0; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); uint32_t parent_index = 0; if (!AttrUtils::GetInt(desc_temp, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(INTERNAL_ERROR, "NetOutput input tensor %d, attr %s not found.", index, - ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return INTERNAL_ERROR; + GELOGI("NetOutput input tensor %d, attr %s not found.", index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); + continue; } + + int64_t size = 0; + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); ge::GeTensorDesc parent_desc_temp = parent_op_desc->GetOutputDesc(parent_index); ge::TensorUtils::SetSize(parent_desc_temp, size); GE_CHK_STATUS_RET(parent_op_desc->UpdateOutputDesc(parent_index, parent_desc_temp)); @@ -176,7 +190,7 @@ Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, auto subgraph_map = graph_partitioner_.GetSubGraphMap(); GE_TIMESTAMP_START(BuildSubgraph); - ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); GE_DUMP(comp_graph, "BeforePreBuildModel"); GE_TIMESTAMP_START(PreBuildModel); GE_CHK_STATUS_RET(builder.PreBuildModel(), "Graph[%s] builder PreBuildModel() return fail.", @@ -229,7 +243,7 @@ Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeMo GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); GE_DUMP(comp_graph, "AfterCalcOpParam"); Graph2SubGraphInfoList subgraph_map; - ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); ModelPtr model_ptr = MakeShared(); if 
(model_ptr == nullptr) { return MEMALLOC_FAILED; @@ -263,51 +277,38 @@ Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, uint64_t session_id) { GELOGI("Start to build BuildForDynamicShape for dynamic shape."); - for (const auto &node : comp_graph->GetDirectNode()) { + // mark unknown shape attr + for (auto &sub_graph : comp_graph->GetAllSubgraphs()) { + auto status = MarkGraph(sub_graph); + if (status != SUCCESS) { + GELOGE(FAILED, "mark graph failed!"); + return status; + } + } + // Update Root Graph Data size + for (auto &node : comp_graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + op_desc->SetStreamId(kInvalidStream); if (node->GetType() == DATA) { GE_CHK_STATUS_RET(CalcDynShapeRootGraphDataSize(op_desc), "Calc dynamic shape root graph data[%s] size failed.", op_desc->GetName().c_str()); } - - // ATTR_NAME_IS_UNKNOWN_SHAPE is set on "graph partion" stage, but afer fusion , the graph may - // be changed so here need to renew. For example , the scene followed: - // (known)partioncall(known) (known)partioncall(known) - // After fusion - // | --> - // (known)Unique(unknown)--->(unknow)Shape(unknown) (known)FuncDef(known) - // if scene like this , it should be process as known shape graph - bool is_unknown_shape = false; - GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), - "Get node[%s] shape status failed!", node->GetName().c_str()); - if (!is_unknown_shape) { - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), return FAILED, - "Renew node [%s] attr[%s] failed!", node->GetName().c_str(), ATTR_NAME_IS_UNKNOWN_SHAPE.c_str()); - GELOGD("renew node [%s] attr[%s] success! 
value is %d", node->GetName().c_str(), - ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), is_unknown_shape); - } - - vector subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto subgraph_name : subgraph_names) { - ComputeGraphPtr subgraph = comp_graph->GetSubgraph(subgraph_name); - bool is_unknown_shape = false; - if (!AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { - GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", op_desc->GetName().c_str()); - return PARAM_INVALID; - } - if (is_unknown_shape) { - // unknown shape build flow - GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(subgraph, ge_model_ptr, session_id), - "Build for unknown shape graph failed."); - } else { - // known shape build flow - GE_CHK_STATUS_RET(BuildForKnownShapeGraph(subgraph, subgraph_ptr_list, ge_model_ptr, session_id), - "Build for known shape graph failed."); - } - ge_root_model_ptr->SetSubgraphInstanceNameToModel(subgraph_name, ge_model_ptr); + } + // + for (auto &sub_graph : comp_graph->GetAllSubgraphs()) { + if (sub_graph->GetGraphUnknownFlag()) { + // unknown shape build flow + GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(sub_graph, ge_model_ptr, session_id), + "Build for unknown shape graph failed."); + } else { + // known shape build flow + GE_CHK_STATUS_RET(BuildForKnownShapeGraph(sub_graph, subgraph_ptr_list, ge_model_ptr, session_id), + "Build for known shape graph failed."); } + ge_root_model_ptr->SetSubgraphInstanceNameToModel(sub_graph->GetName(), ge_model_ptr); } + return SUCCESS; } @@ -327,8 +328,9 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr GELOGE(INTERNAL_ERROR, "Get weight memory size fail."); return INTERNAL_ERROR; } - auto *get_mem_base = - reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemMaxSize())); + + auto var_manager = VarManager::Instance(session_id); + auto *get_mem_base = reinterpret_cast(reinterpret_cast(var_manager->GetVarMemMaxSize())); uint8_t *get_weight_mem_base = get_mem_base; if (weight_size > 0) { get_weight_mem_base = get_mem_base + memory_size; @@ -354,11 +356,8 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr return ret; } GE_DUMP(comp_graph, "AfterOptimizeStreamedSubGraph"); - auto *get_var_mem_base = - reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemLogicBase())); - uint64_t var_size = (ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) - ? ge::VarManager::Instance(0)->GetVarMemMaxSize() - : 0; + auto *get_var_mem_base = reinterpret_cast(reinterpret_cast(var_manager->GetVarMemLogicBase())); + uint64_t var_size = (var_manager->GetVarMemSize(RT_MEMORY_HBM) > 0) ? 
var_manager->GetVarMemMaxSize() : 0; TaskGenerator task_generator(get_var_mem_base, var_size); ret = task_generator.GetTaskInfo(*model_ptr, comp_graph, session_id, run_context.GetRunContext()); @@ -368,6 +367,13 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { // set input_desc.size = src_node.output_desc.size if (node_ptr->GetType() == DATA) { + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), + "Get data node[%s] shape status failed!", node_ptr->GetName().c_str()); + if (is_unknown_shape) { + GELOGD("data node: %s is unknown shape, do not set input size!", node_ptr->GetName().c_str()); + return SUCCESS; + } if (UpdateDataInputSize(node_ptr) != SUCCESS) { GELOGE(FAILED, "Update data input size failed."); return FAILED; @@ -398,7 +404,7 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { GE_CHECK_NOTNULL(input_desc); ge::TensorUtils::SetSize(const_cast(*input_desc), size); GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); - GELOGD("%s input desc, dim_size: %zu, mem_size: %u, format: %s, type: %s.", node_ptr->GetName().c_str(), + GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); } diff --git a/src/ge/graph/build/graph_builder.h b/src/ge/graph/build/graph_builder.h index def3a28b..2597aa2a 100644 --- a/src/ge/graph/build/graph_builder.h +++ b/src/ge/graph/build/graph_builder.h @@ -67,6 +67,7 @@ class GraphBuilder { GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); + Status MarkGraph(ComputeGraphPtr &graph); int build_mode_; std::map stream_max_parallel_num_; diff --git a/src/ge/graph/build/label_allocator.cc b/src/ge/graph/build/label_allocator.cc index 46c092f5..f8fbe28b 100644 --- a/src/ge/graph/build/label_allocator.cc +++ b/src/ge/graph/build/label_allocator.cc @@ -24,7 +24,6 @@ #include "graph/label/label_maker.h" namespace ge { - LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { @@ -76,5 +75,4 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::setGetOpDesc(), kAttrNameParentOpType, parent_op_type)) { - if ((parent_op_type != CONSTANT) && (parent_op_type != CONSTANTOP)) { - return true; - } - } - } - } - - return false; -} - Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; @@ -133,21 +110,6 @@ Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; - if (IsHeadNodeExceeded(subgraphs)) { - int64_t &next_stream = context.next_stream; - for (const SubgraphPtr &subgraph : subgraphs) { - if (!HasAssignedStream(*subgraph)) { - subgraph->stream_id = next_stream; - changed = true; - } - } - if (changed) { - ++next_stream; - return SUCCESS; - } - return NOT_CHANGED; - } - map end_subgraph_map; map pld_subgraph_map; InitEndSubgraphMap(subgraphs, end_subgraph_map); @@ -190,24 +152,6 @@ 
Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector &subgraphs, -bool AssignByDependencyPass::IsHeadNodeExceeded(const vector &subgraphs) const { - size_t aicpu_node_num = 0; - for (const SubgraphPtr &subgraph : subgraphs) { - if (subgraph->engine_conf.id == kAICPUEngineName && !HasNonConstInputNode(*subgraph)) { - const SubGraphInfo &subgraph_info = subgraph->subgraph_info; - auto compute_graph = subgraph_info.GetSubGraph(); - aicpu_node_num += compute_graph->GetDirectNode().size() - subgraph_info.GetPld2EndMap().size() - - subgraph_info.GetEnd2PldMap().size(); - if (aicpu_node_num > kHeadNodeMaxNum) { - GELOGI("aicpu_node_num, %zu", aicpu_node_num); - return true; - } - } - } - - return false; -} - void AssignByDependencyPass::InitEndSubgraphMap(const vector &subgraphs, map &end_subgraph_map) { for (const auto &subgraph : subgraphs) { @@ -727,7 +671,7 @@ void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &gra int64_t stream_num = context_.next_stream; vector stream_has_node(stream_num); - for (const NodePtr &node : graph->GetAllNodes()) { + for (const NodePtr &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { if (node != nullptr) { auto op_desc = node->GetOpDesc(); if (op_desc != nullptr) { @@ -748,7 +692,7 @@ void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &gra } } - for (const NodePtr &node : graph->GetAllNodes()) { + for (const NodePtr &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); if (op_desc != nullptr) { int64_t stream_id = op_desc->GetStreamId(); diff --git a/src/ge/graph/build/logical_stream_allocator.h b/src/ge/graph/build/logical_stream_allocator.h index 71946630..280a4104 100644 --- a/src/ge/graph/build/logical_stream_allocator.h +++ b/src/ge/graph/build/logical_stream_allocator.h @@ -81,9 +78,6 @@ class LogicalStreamPass { bool HasStreamLabel(const Subgraph &subgraph) const; bool HasAssignedStream(const Subgraph &subgraph) const; - // Determine if the input of the subgraph is a constant.
- bool HasNonConstInputNode(const Subgraph &subgraph) const; - private: std::string name_; }; @@ -121,7 +118,6 @@ class AssignByDependencyPass : public LogicalStreamPass { void UpdateAssignedSubgraphs(Context &context); void UpdateReusedSubgraphs(); - bool IsHeadNodeExceeded(const std::vector &subgraphs) const; bool CouldReuse(const SubgraphPtr &subgraph, const SubgraphPtr &pred_subgraph, const std::map &pld_subgraph_map); diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index df7912fa..1910618d 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -18,6 +18,7 @@ #include #include +#include "external/ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" #include "graph/buffer.h" @@ -39,7 +40,6 @@ namespace { const char *const kAttrNameWorkspaceReuseFlag = "workspace_reuse_flag"; const char *const kL2FusionDynamicConvergeOp = "l2fusion_dynamic_converge_op"; const char *const kOpNoReuseMem = "no_reuse_mem_flag"; -const char *const kDisableReuseMemory = "ge.exec.disableReuseMemory"; const char *const OP_NO_REUSE_MEM = "OP_NO_REUSE_MEM"; const int kReuseMaxCount = 10; const int kReuseMaxOpNum = 10; @@ -133,21 +133,20 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { } bool CanNotLifeReuse(MemoryBlock *block) { - if (block == nullptr || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_ || - block->GetLifeEnd() == kMaxLifeTime) { + if ((block == nullptr) || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_) { return true; } return false; } -void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block) { +void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { return; } MemoryBlock *parent = nullptr; MemoryBlock *child = nullptr; // merge small block to large block - if ((block->GetLifeBegin() > GetLifeEnd()) && (block->stream_id_ == stream_id_)) { + if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { if ((child_offset_ + block->block_size_) <= block_size_) { parent = this; child = block; @@ -181,6 +180,87 @@ size_t MemoryBlock::GetLifeBegin() { return life_time; } +/// |-stream 1-| |-stream 2-| +/// |--block1--| |--block---| +/// |--block2--| |--block---| +/// |--block3--|\ |--block---| +/// |--block---| \ |--block---| +/// |--block---| \|--block---| +/// |--block---| |--block7--| +/// |--block---| |--block---| +/// block7's first node's input node's life begin > block2's life end, block7 can reuse block1~block2 +size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &total_node_depend_stream_life) { + AddDependLifeBegin(total_node_depend_stream_life); + auto it = depend_stream_life_.find(stream_id); + if (it == depend_stream_life_.end()) { + return 0; + } + return it->second; +} + +void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, + std::map &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { + GE_CHECK_NOTNULL_EXEC(node, return ); + auto node_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(node_desc, return ); + auto node_id = node_desc->GetId(); + auto stream_life = total_node_depend_stream_life.find(node_id); + if (stream_life != total_node_depend_stream_life.end()) { + for (auto &it : stream_life->second) { + if (depend_stream_life.find(it.first) == 
depend_stream_life.end()) { + depend_stream_life[it.first] = it.second; + } + } + return; + } + + for (const auto &in_anchor : node->GetAllInAnchors()) { + GE_CHECK_NOTNULL_EXEC(in_anchor, continue); + for (auto peer_out_anchor : in_anchor->GetPeerAnchors()) { + GE_CHECK_NOTNULL_EXEC(peer_out_anchor, continue); + auto peer_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL_EXEC(peer_node, continue); + auto peer_node_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(peer_node_desc, continue); + auto peer_node_stream_id = peer_node_desc->GetStreamId(); + if (peer_node_stream_id < 0) { + continue; + } + size_t peer_node_life_time = peer_node_desc->GetId(); + auto it = depend_stream_life.find(peer_node_stream_id); + if (it == depend_stream_life.end() || peer_node_life_time > it->second) { + depend_stream_life[peer_node_stream_id] = peer_node_life_time; + if (peer_node_stream_id != stream_id) { + GELOGI("Node:%s stream id:%ld depend node:%s stream id:%ld index[%d] life time[%zu].", + org_node->GetName().c_str(), stream_id, peer_node_desc->GetName().c_str(), peer_node_stream_id, + peer_out_anchor->GetIdx(), peer_node_life_time); + } + AddDependLife(org_node, peer_node, stream_id, depend_stream_life, total_node_depend_stream_life); + } + } + } + + // save on node to save next calculation + for (auto &it : depend_stream_life) { + if (total_node_depend_stream_life[node_id].find(it.first) == total_node_depend_stream_life[node_id].end()) { + total_node_depend_stream_life[node_id][it.first] = it.second; + } + } +} + +void MemoryBlock::AddDependLifeBegin(DependStreamLife &total_node_depend_stream_life) { + if (!depend_stream_life_.empty()) { + return; + } + if (!node_type_index_list_.empty()) { + auto node = node_type_index_list_.front().node; + if (node != nullptr) { + AddDependLife(node, node, stream_id_, depend_stream_life_, total_node_depend_stream_life); + } + } + depend_stream_life_[stream_id_] = GetLifeBegin(); +} + size_t MemoryBlock::GetLifeEnd() { if (!node_type_index_list_.empty()) { return node_type_index_list_.back().life_time_end; @@ -302,7 +382,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { if (iter1 == anchor_to_symbol_.end()) { continue; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = symbol_size_.find(symbol); if (iter2 == symbol_size_.end()) { symbol_size_[symbol] = size; @@ -317,7 +397,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); } GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); - for (auto &pair : symbol_size_) { + for (const auto &pair : symbol_size_) { all_memory_size.emplace_back(pair.second); } sort(all_memory_size.begin(), all_memory_size.end()); @@ -427,14 +507,6 @@ bool CanReuseBySize(const map &reusable_block_counts, const Me return can_reuse; } -bool CanReuseByStream(const std::unordered_set &reuse_stream, MemoryBlock &reusable_block) { - bool can_reuse = false; - if (reuse_stream.find(reusable_block.stream_id_) != reuse_stream.cend()) { - can_reuse = true; - } - return can_reuse; -} - bool BlockMemAssigner::IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, uint32_t &peer_input_index) { if (n == nullptr || n->GetAllOutDataAnchors().size() <= 0) { @@ -495,11 +567,11 @@ void BlockMemAssigner::InitReuseFlag() { ge::CONSTANT, ge::CONSTANTOP}; static const std::set kPostReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, 
ge::ENTER, ge::REFENTER, ge::NEXTITERATION, ge::REFNEXTITERATION}; - for (auto &pair : symbol_to_anchors_) { + for (const auto &pair : symbol_to_anchors_) { std::string symbol = pair.first; bool pre_reuse_flag = true; bool post_reuse_flag = true; - for (auto &node_index_io : pair.second) { + for (const auto &node_index_io : pair.second) { if (node_index_io.io_type_ == kIn) { continue; } @@ -513,13 +585,13 @@ void BlockMemAssigner::InitReuseFlag() { if (node_index_io.node_->GetOutDataNodes().empty()) { out_flg = true; } - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { if (IsDirectOutputNode(in_anchor->GetOwnerNode(), in_anchor->GetIdx())) { out_flg = true; break; } } - std::string type = out_anchor->GetOwnerNode()->GetType(); + const std::string &type = out_anchor->GetOwnerNode()->GetType(); pre_reuse_flag = pre_reuse_flag && !out_flg && (kPreReuseTypes.count(type) == 0); post_reuse_flag = post_reuse_flag && (kPostReuseTypes.count(type) == 0); if (!pre_reuse_flag && !post_reuse_flag) { @@ -552,7 +624,7 @@ bool BlockMemAssigner::IsPreReuse(const NodePtr &node, uint32_t out_index) const return false; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = pre_reuse_flag_.find(symbol); if (iter2 == pre_reuse_flag_.end()) { return false; @@ -570,7 +642,7 @@ bool BlockMemAssigner::IsPostReuse(const MemoryBlock *mem_block) const { if (mem_block == nullptr) { return false; } - for (auto &symbol : mem_block->SymbolList()) { + for (const auto &symbol : mem_block->SymbolList()) { auto iter = post_reuse_flag_.find(symbol); if (iter == post_reuse_flag_.end()) { continue; @@ -593,8 +665,7 @@ bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { if (iter == anchor_to_symbol_.end()) { return false; } - std::string symbol = iter->second; - return symbol_blocks_.find(symbol) != symbol_blocks_.end(); + return symbol_blocks_.find(iter->second) != symbol_blocks_.end(); } /// @@ -603,10 +674,10 @@ bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { /// @return void /// void BlockMemAssigner::PrintSymbolMap() { - for (auto &pair : symbol_to_anchors_) { + for (const auto &pair : symbol_to_anchors_) { GELOGD("symbol=%s, max_size=%zu, pre_reuse=%s, post_reuse=%s", pair.first.c_str(), symbol_size_[pair.first], pre_reuse_flag_[pair.first] ? "true" : "false", post_reuse_flag_[pair.first] ? 
"true" : "false"); - for (auto &node_index_io : pair.second) { + for (const auto &node_index_io : pair.second) { GELOGD("anchor:%s", node_index_io.ToString().c_str()); } } @@ -622,15 +693,14 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, bool is_reuse_memory = false; string ge_disable_reuse_mem_env = "0"; - (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); + (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); if (ge_disable_reuse_mem_env != "1") { bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem && (IsPreReuse(n, out_index)); auto stream_id = node_op_desc->GetStreamId(); - auto map_iter = reusable_streams_map_.find(stream_id); - if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { - for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { + if (is_reuse_memory) { + for (auto it = reusable_blocks_[stream_id].begin(); it != reusable_blocks_[stream_id].end(); ++it) { MemoryBlock *reusable_block = *it; if (!IsPostReuse(reusable_block)) { reusable_block->reuse_mem_ = false; @@ -640,10 +710,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, // A node can reuse blocks of the same stream and preorder streams auto id = GetAtomicAddrCleanId(); - if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous, id) && - CanReuseByStream(map_iter->second, *reusable_block)) { - GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, - stream_id); + if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous, id)) { reusable_block->AddNodeTypeIndex({n, mem_type, out_index, false}, real_size, no_align_size); if (mem_type == kOutput) { auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); @@ -654,7 +721,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, reusable_block->continuous_block_ = continuous; reusable_block->ref_count_++; ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); - reusable_blocks_.erase(it); + reusable_blocks_[stream_id].erase(it); return reusable_block; } } @@ -700,7 +767,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, "Get no align size failed"); if (IsSymbolExist(node_index_io)) { - std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; + const std::string &symbol = anchor_to_symbol_[node_index_io.ToString()]; block = symbol_blocks_[symbol]; block->AddNodeTypeIndex({n, kOutput, index, true}, size, no_align_size); block->ref_count_++; @@ -923,7 +990,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", - ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_);) + ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_[stream_id]);) for (uint32_t i = 0; i < static_cast(op_desc->GetOutputsSize()); i++) { int64_t size = 0; auto output_op_desc = op_desc->GetOutputDescPtr(i); @@ -977,10 +1044,7 @@ Status 
BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector /// @return Status result /// void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { - // Init reusable streams map - InitReusableStreamMap(); - - (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env_); + (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env_); GEEVENT("Reuse memory %s", ge_disable_reuse_mem_env_ == "1" ? "close" : "open"); string op_no_reuse_mem_str; const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM); @@ -1033,7 +1097,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); } - ReleaseInputNodeOutMemory(node_out_blocks_, reusable_blocks_, n); + ReleaseInputNodeOutMemory(node_out_blocks_, reusable_blocks_[stream_id], n); } GELOGD("Assigned memory blocks:"); @@ -1044,7 +1108,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { bool merge_dynamic_batch = false; GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks();) - GE_IF_BOOL_EXEC(!merge_dynamic_batch, ReuseBlocksByLifeTime();) + GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size());) AssignContinuousBlocks(); ResizeMemoryBlocks(); @@ -1221,7 +1285,11 @@ void BlockMemAssigner::AssignContinuousBlocks() { } } -void BlockMemAssigner::ReuseBlocksByLifeTime() { +void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { + // 1 means block size is same so no need to do this + if (range_size <= 1) { + return; + } for (size_t i = 0; i < memory_blocks_.size(); ++i) { auto parent = memory_blocks_[i]; if (parent == nullptr || parent->deleted_block_) { @@ -1231,7 +1299,7 @@ void BlockMemAssigner::ReuseBlocksByLifeTime() { parent->reuse_mem_ = false; } for (size_t j = i + 1; j < memory_blocks_.size(); ++j) { - parent->AddLifeReuseBlock(memory_blocks_[j]); + parent->AddLifeReuseBlock(memory_blocks_[j], total_node_depend_stream_life_); } } } @@ -1318,10 +1386,10 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, siz } GELOGI( "[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" - " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d] isref[%d].", + " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, - node_type.ref_input); + block->reuse_mem_, block->continuous_block_, block->deleted_block_, node_type.ref_input); } void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { @@ -1380,139 +1448,6 @@ Status BlockMemAssigner::Assign() { return SUCCESS; } -void BlockMemAssigner::InitReusableStreamMap() { - // save a stream's id and its first Node and last node. - map> stream_head_tail_node_map; - // save a stream's id and its directly child stream. - map> stream_dependency_map; - // save a stream's id and its occupied memory. - unordered_map stream_mem_map; - - // Find streams's first and last node. - FindHeadAndTailNodesForStream(stream_head_tail_node_map, stream_mem_map); - - // If streamB's first node is the output of streamA's last node, then B depends on A. 
- FindDependentStream(stream_head_tail_node_map, stream_dependency_map); - - // If a stream has more than one child stream, select the one that occupies the closest memory - for (const auto &iter : stream_dependency_map) { - if (iter.second.empty()) { - continue; - } - int64_t target_size = stream_mem_map[iter.first]; - int64_t min_size_gap = LONG_MAX; - int64_t target_reuse_stream_id = 0; - for (auto id : iter.second) { - if (labs(stream_mem_map[id] - target_size) < min_size_gap) { - target_reuse_stream_id = id; - min_size_gap = labs(stream_mem_map[id] - target_size); - } - } - // If b can reuse a, then b should also be able to reuse all blocks that a can reuse. - reusable_streams_map_[target_reuse_stream_id].insert(reusable_streams_map_[iter.first].begin(), - reusable_streams_map_[iter.first].end()); - } -} - -void BlockMemAssigner::FindHeadAndTailNodesForStream(map> &stream_head_tail_node_map, - unordered_map &stream_mem_map) { - for (const auto &n : compute_graph_->GetAllNodes()) { - GE_IF_BOOL_EXEC(n->GetOpDesc() == nullptr, GELOGW("Op desc is nullptr"); continue); - auto stream_id = n->GetOpDesc()->GetStreamId(); - // traverse to find streams's first and last node. - if (stream_head_tail_node_map.find(stream_id) == stream_head_tail_node_map.end()) { - stream_head_tail_node_map[stream_id] = std::make_pair(n, n); - reusable_streams_map_[stream_id].insert(stream_id); // a node can reuse blocks from same stream. - } else { - stream_head_tail_node_map[stream_id].second = n; - } - - // Accumulate the output size of the node in the stream. - for (size_t i = 0; i < n->GetOpDesc()->GetOutputsSize(); i++) { - int64_t size = 0; - if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast(i)), size) != SUCCESS) { - GELOGW("Get output size failed!"); - continue; - } - stream_mem_map[stream_id] += size; - } - // Accumulate the workspace size of the node in the stream. 
- for (auto size : n->GetOpDesc()->GetWorkspaceBytes()) { - stream_mem_map[stream_id] += size; - } - } -} - -void BlockMemAssigner::FindDependentStream(map> &stream_head_tail_node_map, - map> &stream_dependency_map) { - for (const auto &it1 : stream_head_tail_node_map) { - for (const auto &it2 : stream_head_tail_node_map) { - if (it1 == it2) { - continue; - } - NodePtr pre_node = it1.second.second; - NodePtr post_node = it2.second.first; - std::vector out_nodes; - // Direct link out_node - for (const auto &out_node : pre_node->GetOutNodes()) { - if ((out_node->GetOpDesc() == nullptr) || (post_node->GetOpDesc() == nullptr) || - (pre_node->GetOpDesc() == nullptr)) { - continue; - } - out_nodes.emplace_back(out_node); - } - - FindDependentStreamBetweenGraphs(pre_node, out_nodes); - - for (auto &out_node : out_nodes) { - if (out_node->GetOpDesc()->GetId() == post_node->GetOpDesc()->GetId()) { - stream_dependency_map[pre_node->GetOpDesc()->GetStreamId()].insert(post_node->GetOpDesc()->GetStreamId()); - } - } - } - } -} - -/// -/// @ingroup GE -/// @brief Find dependent link between parent_graph and sub_graph -/// @param [in] pre_node -/// @param [out] out_nodes -/// @return void -/// @author -/// -void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes) { - if ((pre_node == nullptr) || (pre_node->GetOpDesc() == nullptr)) { - return; - } - - // FunctionOp & subgraph input - std::vector subgraph_names = pre_node->GetOpDesc()->GetSubgraphInstanceNames(); - for (auto &subgraph_name : subgraph_names) { - ComputeGraphPtr subgraph = compute_graph_->GetSubgraph(subgraph_name); - if (subgraph == nullptr) { - continue; - } - for (auto &node : subgraph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - if (op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - out_nodes.emplace_back(node); - } - } - } - - // subgraph output & parent_node output - if (NodeUtils::IsSubgraphOutput(pre_node)) { - NodePtr parent_node = pre_node->GetOwnerComputeGraph()->GetParentNode(); - for (const auto &out_node : parent_node->GetOutNodes()) { - out_nodes.emplace_back(out_node); - } - } -} - bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h index 8ee4506e..4e9c3b05 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.h +++ b/src/ge/graph/build/memory/block_mem_assigner.h @@ -34,6 +34,8 @@ namespace ge { const size_t kMaxLifeTime = 0xffffffff; +using DependStreamLife = std::map>; + enum MemoryType { kOutput, kWorkspace }; struct NodeTypeIndex { @@ -116,7 +118,7 @@ class MemoryBlock { bool IsSameLabel(std::string &first_batch_label); - void AddLifeReuseBlock(MemoryBlock *block); + void AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &node_depend_stream_life); void SetLifeTimeEnd(size_t time); @@ -124,6 +126,10 @@ class MemoryBlock { size_t GetLifeEnd(); + void AddDependLifeBegin(DependStreamLife &node_depend_stream_life); + + size_t GetDependLifeBegin(int64_t stream_id, DependStreamLife &node_depend_stream_life); + int ref_count_; int64_t stream_id_; bool deleted_block_; @@ -196,47 +202,6 @@ class BlockMemAssigner : public MemAssigner { /// /// @ingroup GE - /// @brief Traversing the compute_graph_ to 
find the reuse relationship between streams - /// @param [in] reusable_stream_map map to save stream_id and its reusable stream_ids - /// @return void - /// @author - /// - void InitReusableStreamMap(); - - /// - /// @ingroup GE - /// @brief Traversing the compute_graph_ to find the first and last nodeptr of a stream. - /// @param [in] stream_head_tail_node_map map to save stream_id and its first and last nodeptr. - /// @param [in] stream_mem_map map to save stream_id and its memory capacity. - /// @return void - /// @author - /// - void FindHeadAndTailNodesForStream(std::map> &stream_head_tail_node_map, - std::unordered_map &stream_mem_map); - - /// - /// @ingroup GE - /// @brief Traversing the compute_graph_ to find the reuse relationship between streams. - /// @param [in] stream_head_tail_node_map map to save stream_id and its first and last nodeptr. - /// @param [in] stream_dependency_map map to save stream_id and stream_ids depends on it. - /// @return void - /// @author - /// - void FindDependentStream(std::map> &stream_head_tail_node_map, - std::map> &stream_dependency_map); - - /// - /// @ingroup GE - /// @brief Find dependent link between parent_graph and sub_graph - /// @param [in] pre_node - /// @param [out] out_nodes - /// @return void - /// @author - /// - void FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes); - - /// - /// @ingroup GE /// @brief Determine whether it is the type of zero memory node. /// @param [in] node type. /// @return bool true: is zero memory node; false: is not zero memory node @@ -395,9 +360,9 @@ class BlockMemAssigner : public MemAssigner { /// @return void /// @author /// - void ReuseBlocksByLifeTime(); + void ReuseBlocksByLifeTime(size_t range_size); - std::vector reusable_blocks_; + std::unordered_map> reusable_blocks_; std::map reusable_block_counts_; @@ -411,9 +376,6 @@ class BlockMemAssigner : public MemAssigner { std::unordered_map node_continuous_input_counts_; - // save stream_id and reusable stream_ids - std::unordered_map> reusable_streams_map_; - // reuse memory vector op_no_reuse_mem_vec_; @@ -426,6 +388,8 @@ class BlockMemAssigner : public MemAssigner { size_t life_time_; int64_t atomic_addr_clean_id_ = 0; + + DependStreamLife total_node_depend_stream_life_; }; } // namespace ge #endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_ diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index c4aca639..8393c474 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -222,9 +222,10 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offse mem_offset = memory_offset_[0].mem_offset_; - if (mem_offset > VarManager::Instance(0)->GetGraphMemoryMaxSize()) { + auto session_id = compute_graph_->GetSessionID(); + if (mem_offset > VarManager::Instance(session_id)->GetGraphMemoryMaxSize()) { GELOGE(ge::FAILED, "Current memoffset %zu is greater than memory manager malloc max size %zu", mem_offset, - VarManager::Instance(0)->GetGraphMemoryMaxSize()); + VarManager::Instance(session_id)->GetGraphMemoryMaxSize()); return ge::FAILED; } return SUCCESS; @@ -1222,10 +1223,16 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_out_anchor->GetIdx(), input_list.back()); } else { + int64_t output_offset = output_list.at(peer_out_anchor->GetIdx()); + if 
(peer_out_anchor->GetOwnerNode()->GetType() == CONSTANT) { + GeTensorDesc tensor_desc = tmp_op_desc->GetInputDesc(input_index); + GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, output_offset)); + } + GELOGI("node[%s] input[%d] is set from node[%s] out index[%d] offset[%ld]", tmp_op_desc->GetName().c_str(), input_index, peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_out_anchor->GetIdx(), - output_list.at(peer_out_anchor->GetIdx())); - input_list.emplace_back(output_list.at(peer_out_anchor->GetIdx())); + output_offset); + input_list.emplace_back(output_offset); } } } diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc index 111adc7a..a352cf65 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.cc +++ b/src/ge/graph/build/memory/var_mem_assign_util.cc @@ -299,21 +299,33 @@ Status VarMemAssignUtil::SetOutTransNodeToAssign(const ge::NodePtr &node, const Status VarMemAssignUtil::AssignMemory2HasRefAttrNode(ge::ComputeGraphPtr &compute_graph) { for (const ge::NodePtr &n : compute_graph->GetAllNodes()) { string ref_var_src_var_name; - GE_CHECK_NOTNULL(n->GetOpDesc()); - bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - GE_IF_BOOL_EXEC(is_ref, - GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); + auto op_desc = n->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (uint32_t idx = 0; idx < op_desc->GetOutputsSize(); idx += 1) { + const auto out_desc = op_desc->MutableOutputDesc(idx); + if (ge::AttrUtils::GetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)) { + GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID(), idx)); + } + } } return SUCCESS; } Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, const string &src_var_name, - uint64_t session_id) { - if (!TransOpUtil::IsTransOp(has_ref_attr_node)) { - return SUCCESS; - } + uint64_t session_id, uint32_t out_index) { // Get ref_var_src_var address - ge::NodePtr var_ref_src_var = has_ref_attr_node->GetOwnerComputeGraph()->FindNode(src_var_name); + auto root_graph = GraphUtils::FindRootGraph(has_ref_attr_node->GetOwnerComputeGraph()); + GE_CHECK_NOTNULL(root_graph); + ge::NodePtr var_ref_src_var = root_graph->FindNode(src_var_name); + if (var_ref_src_var == nullptr) { + for (auto sub_graph : root_graph->GetAllSubgraphs()) { + auto node_ptr = sub_graph->FindNode(src_var_name); + if (node_ptr != nullptr) { + var_ref_src_var = node_ptr; + break; + } + } + } GE_IF_BOOL_EXEC(var_ref_src_var == nullptr || var_ref_src_var->GetOpDesc() == nullptr, return FAILED); GeTensorDesc src_tensor_desc = var_ref_src_var->GetOpDesc()->GetOutputDesc(0); uint8_t *dev_ptr = nullptr; @@ -322,14 +334,8 @@ Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, vector ref_attr_node_output_list = has_ref_attr_node->GetOpDesc()->GetOutputOffset(); GE_CHECK_SIZE(ref_attr_node_output_list.size()); - int out_index = 0; - bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, out_index); - if (!is_get) { - GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); - } - - GE_CHK_BOOL_RET_STATUS(static_cast(out_index) < ref_attr_node_output_list.size(), FAILED, - "out_index %d >= ref_attr_node_output_list.size() %zu", out_index, + GE_CHK_BOOL_RET_STATUS(out_index < ref_attr_node_output_list.size(), FAILED, + "out_index %u 
>= ref_attr_node_output_list.size() %zu", out_index, ref_attr_node_output_list.size()); ref_attr_node_output_list[out_index] = static_cast(reinterpret_cast(dev_ptr)); diff --git a/src/ge/graph/build/memory/var_mem_assign_util.h b/src/ge/graph/build/memory/var_mem_assign_util.h index 036fed06..cb38af29 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.h +++ b/src/ge/graph/build/memory/var_mem_assign_util.h @@ -46,8 +46,8 @@ class VarMemAssignUtil { static Status DealTransNode(const ge::NodePtr &final_trans_node); static Status DealExportTransNode(const ge::NodePtr &node, const ge::NodePtr &final_trans_node); - static Status AssignData2VarRef(const ge::NodePtr &variable_ref, const std::string &src_var_name, - uint64_t session_id); + static Status AssignData2VarRef(const ge::NodePtr &variable_ref, const std::string &src_var_name, uint64_t session_id, + uint32_t out_index); static Status SetOutTransNodeToAssign(const ge::NodePtr &node, const ge::NodePtr &final_trans_node, size_t index); }; diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index 62abd4ab..a765d8e7 100644 --- a/src/ge/graph/build/model_builder.cc +++ b/src/ge/graph/build/model_builder.cc @@ -15,10 +15,10 @@ */ #include "graph/build/model_builder.h" +#include #include #include #include -#include #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" @@ -27,6 +27,7 @@ #include "graph/build/label_allocator.h" #include "graph/build/stream_allocator.h" #include "graph/common/omg_util.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" @@ -85,9 +86,11 @@ bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { } // namespace namespace ge { -ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, - const map &stream_max_parallel_num, bool hcom_parallel, int mode) - : mem_offset_(0), +ModelBuilder::ModelBuilder(uint64_t session_id, ge::ComputeGraphPtr compute_graph, + const Graph2SubGraphInfoList &subgraphs, const map &stream_max_parallel_num, + bool hcom_parallel, int mode) + : session_id_(session_id), + mem_offset_(0), weight_offset_(kWeightsStartOffset), compute_graph_(std::move(compute_graph)), subgraphs_(subgraphs), @@ -242,7 +245,7 @@ Status ModelBuilder::SetInputOutputDesc() { Status ret; GELOGI("Start to SetInputOutputDesc."); - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); @@ -291,7 +294,7 @@ Status ModelBuilder::SetInputOutputDesc() { } void ModelBuilder::AddNodeInputProperty() { - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); vector src_name_list; @@ -318,7 +321,7 @@ void ModelBuilder::AddNodeInputProperty() { node_op_desc->SetSrcIndex(src_index_list); } - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); 
GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); @@ -356,7 +359,7 @@ void ModelBuilder::AddNodeInputProperty() { Status ModelBuilder::AdjustInputTensorFlag() { GELOGI("Start to AdjustInputTensorFlag."); - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { if ((n->GetType() == DATA_TYPE) || (n->GetType() == AIPP_DATA_TYPE)) { GELOGD("Data node: %s.", n->GetName().c_str()); for (const auto &anchor : n->GetAllOutDataAnchors()) { @@ -432,6 +435,21 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(&model, ATTR_NAME_SWITCH_FOR_L1_FUSION, is_l1_fusion_enable_), GELOGE(FAILED, "SetBool of ATTR_NAME_SWITCH_FOR_L1_FUSION failed."); return FAILED); + const DumpProperties &dump_properties = PropertiesManager::Instance().GetDumpProperties(session_id_); + bool is_op_debug = dump_properties.IsOpDebugOpen(); + GELOGI("Get op debug:%d", is_op_debug); + if (is_op_debug) { + if (!ge::AttrUtils::SetBool(&model, ATTR_OP_DEBUG_FLAG, is_op_debug)) { + GELOGE(FAILED, "SetBool of ATTR_OP_DEBUG_FLAG failed."); + return FAILED; + } + uint32_t op_debug_mode = dump_properties.GetOpDebugMode(); + GELOGI("Get op debug mode:%u", op_debug_mode); + if (!ge::AttrUtils::SetInt(&model, ATTR_OP_DEBUG_MODE, op_debug_mode)) { + GELOGE(FAILED, "SetInt of ATTR_OP_DEBUG_MODE failed."); + return FAILED; + } + } model.SetName(compute_graph_->GetName()); model.SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph_)); @@ -448,7 +466,7 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { } void ModelBuilder::ClearOriginalFormat() { - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); if (node_op_desc != nullptr) { if (node_op_desc->HasAttr(ATTR_NAME_FORMAT)) { @@ -487,7 +505,7 @@ Status ModelBuilder::MergeWeights() { weight_buffer_ = buffer; auto base_addr = weight_buffer_.GetData(); - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); if (node->GetType() != CONSTANT) { @@ -527,8 +545,8 @@ Status ModelBuilder::MergeWeights() { weight_data.size()); return FAILED; } - uintptr_t dst_ptr = (uintptr_t)base_addr + offset; - uintptr_t src_ptr = (uintptr_t)weight_data.data(); + uintptr_t dst_ptr = reinterpret_cast(base_addr) + offset; + uintptr_t src_ptr = reinterpret_cast(weight_data.data()); size_t left_size = weight_data.size(); while (left_size > SECUREC_MEM_MAX_LEN) { auto err = memcpy_s(reinterpret_cast(dst_ptr), SECUREC_MEM_MAX_LEN, reinterpret_cast(src_ptr), @@ -565,7 +583,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { // Add TBE Kernels std::set name_set; - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); @@ -659,7 +677,7 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { // Compile single op in graph build stage GE_TIMESTAMP_START(CompileSingleOp); 
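// --- Illustrative sketch (not part of this patch) --------------------------
// Read-back side of the two op-debug attributes BuildModelDef() sets above.
// The function name ReadOpDebugAttrs is hypothetical; the Get calls simply
// mirror the AttrUtils::SetBool / AttrUtils::SetInt calls in the hunk.
void ReadOpDebugAttrs(const ge::Model &model) {
  bool is_op_debug = false;
  int64_t op_debug_mode = 0;
  if (ge::AttrUtils::GetBool(&model, ATTR_OP_DEBUG_FLAG, is_op_debug) && is_op_debug) {
    (void)ge::AttrUtils::GetInt(&model, ATTR_OP_DEBUG_MODE, op_debug_mode);
    GELOGI("Op debug enabled, mode:%ld", op_debug_mode);
  }
}
// ---------------------------------------------------------------------------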
GE_CHK_STATUS_RET(CompileSingleOp(), "ATC builder CompileSingleOp() return fail."); - GE_TIMESTAMP_END(CompileSingleOp, "GraphBuilder::CompileSingleOp"); + GE_TIMESTAMP_EVENT_END(CompileSingleOp, "GraphBuilder::CompileSingleOp"); // Refresh real streams and insert event nodes. GE_TIMESTAMP_START(RefreshRealStream); @@ -700,7 +718,7 @@ Status ModelBuilder::CompileSingleOp() { GE_TIMESTAMP_CALLNUM_START(BatchCompileOp); std::unordered_map> node_vector_map; - for (auto &node : compute_graph_->GetAllNodes()) { + for (auto &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); if (op_desc == nullptr) { continue; @@ -737,7 +755,7 @@ Status ModelBuilder::CompileSingleOp() { GE_CHECK_NOTNULL(kernel_info); GE_TIMESTAMP_RESTART(BatchCompileOp); auto ret = kernel_info->CompileOp(node_vector); - GEEVENT("[GEPERFTRACE] The node size of compile op of %s is %zu", kernel_lib_name.c_str(), node_vector.size()); + GELOGI("[GEPERFTRACE] The node size of compile op of %s is %zu", kernel_lib_name.c_str(), node_vector.size()); GE_TIMESTAMP_ADD(BatchCompileOp); if (ret != ge::SUCCESS) { GELOGE(ret, "Compile op failed, kernel lib name is %s", kernel_lib_name.c_str()); diff --git a/src/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h index 21e611ee..86b34c6d 100644 --- a/src/ge/graph/build/model_builder.h +++ b/src/ge/graph/build/model_builder.h @@ -37,7 +37,7 @@ namespace ge { class ModelBuilder { public: - ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, + ModelBuilder(uint64_t session_id, ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode = static_cast(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); @@ -82,6 +82,8 @@ class ModelBuilder { Status CompileSingleOp(); + uint64_t session_id_; + size_t mem_offset_; size_t weight_offset_; diff --git a/src/ge/graph/build/run_context.cc b/src/ge/graph/build/run_context.cc index f2a41271..cece31ea 100644 --- a/src/ge/graph/build/run_context.cc +++ b/src/ge/graph/build/run_context.cc @@ -173,5 +173,4 @@ Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &gra } RunContext &RunContextUtil::GetRunContext() { return run_context_; } - } // namespace ge diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index f6323434..d49bb61b 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -146,12 +146,6 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return status; } - status = AddActiveEntryStream(); - if (status != SUCCESS) { - GELOGE(status, "AddActiveEntryStream failed!"); - return status; - } - status = RefreshContinuousEvents(); if (status != SUCCESS) { GELOGE(status, "RefreshContinuousEvents failed!"); @@ -167,7 +161,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu DumpEvents(); GE_DUMP(whole_graph_, "AfterRefreshRealStream"); - for (const NodePtr &node : whole_graph_->GetAllNodes()) { + for (const NodePtr &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); auto stream_id = node->GetOpDesc()->GetStreamId(); if (stream_id == kInvalidStream) { @@ -199,7 +193,7 @@ Status StreamAllocator::AssignSingleStream() { } int64_t task_count = 0; - for (const NodePtr &node : whole_graph_->GetAllNodes()) { + for (const NodePtr &node : 
whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { string op_type = node->GetType(); if (IsHcclOp(op_type)) { task_count += kTaskNumPerHcclNode; @@ -236,7 +230,7 @@ Status StreamAllocator::AssignSingleStream() { } Status StreamAllocator::SetActiveStreamsByLabel() { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); string stream_label; @@ -248,7 +242,7 @@ Status StreamAllocator::SetActiveStreamsByLabel() { } } - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); vector activated_label_list; if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, activated_label_list) || @@ -326,7 +320,7 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { // Insert the send/recv event id to the graph Status StreamAllocator::InsertSyncEvents() { - for (const auto &cur_node : whole_graph_->GetAllNodes()) { + for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { // Take the adjacent points, then judge whether need to insert the event for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { @@ -380,6 +374,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const return SUCCESS; } + if ((cur_node->GetType() == ENTER) || (cur_node->GetType() == REFENTER)) { + GELOGD("No need to insert event after enter_node %s.", cur_node->GetName().c_str()); + return SUCCESS; + } + if (next_stream_id == kInvalidStream) { GELOGE(FAILED, "Stream id of next_node %s should not be %ld", next_node->GetName().c_str(), kInvalidStream); return FAILED; @@ -446,7 +445,7 @@ Status StreamAllocator::InsertEventsForSubgraph() { Status StreamAllocator::OptimizeSyncEvents() { map> stream_nodes; - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); int64_t stream_id = node->GetOpDesc()->GetStreamId(); stream_nodes[stream_id].emplace_back(node); @@ -671,7 +670,7 @@ Status StreamAllocator::SplitStreams(vector> &split_streams) { GE_CHK_STATUS_RET(GetMaxStreamAndTask(false, max_stream_count, max_task_count), "Get max stream and task count failed."); - for (const auto &cur_node : whole_graph_->GetAllNodes()) { + for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(cur_node); auto op_desc = cur_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -774,42 +773,23 @@ bool StreamAllocator::NeedSpiltNewStream(int64_t stream_node_num, int64_t max_no Status StreamAllocator::UpdateActiveStreams(const vector> &split_streams) { UpdateLabelStreams(split_streams); - for (auto &node : whole_graph_->GetAllNodes()) { + for (auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { if ((node->GetType() == STREAMSWITCH) || (node->GetType() == STREAMSWITCHN)) { - if (InsertActiveNodesAfterSwitch(node) != SUCCESS) { - GELOGE(FAILED, "Insert active nodes after switch node failed."); + if (UpdateActiveStreamsForSwitchNode(node) != SUCCESS) { + GELOGE(FAILED, "Update active streams for switch node: %s failed.", node->GetName().c_str()); return FAILED; } } else { - vector active_streams; - 
GE_CHECK_NOTNULL(node->GetOpDesc()); - if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { - vector new_active_streams = active_streams; - for (const uint32_t logical_stream : active_streams) { - if (static_cast(logical_stream) >= split_streams.size()) { - GELOGE(FAILED, "logical stream is out of range."); - return FAILED; - } - const set &new_split_streams = split_streams[logical_stream]; - if (!new_split_streams.empty()) { - for (int64_t split_stream : new_split_streams) { - new_active_streams.emplace_back(static_cast(split_stream)); - GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", split_stream, - node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); - } - } - } - if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { - GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); - return FAILED; - } + if (UpdateActiveStreamsForActiveNode(split_streams, node) != SUCCESS) { + GELOGE(FAILED, "Update active streams for active node: %s failed.", node->GetName().c_str()); + return FAILED; } } } Status status = UpdateActiveStreamsForSubgraphs(); if (status != SUCCESS) { - GELOGE(status, "SetActiveStreamsForSubgraph failed!"); + GELOGE(status, "Update active streams for subgraphs failed!"); return status; } @@ -840,7 +820,7 @@ void StreamAllocator::UpdateLabelStreams(const vector> &split_strea } } -Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node) { +Status StreamAllocator::UpdateActiveStreamsForSwitchNode(NodePtr &switch_node) { vector active_nodes; if (InsertActiveNodesAfterSwitch(switch_node, active_nodes) != SUCCESS) { GELOGE(FAILED, "Insert active nodes after node %s failed.", switch_node->GetName().c_str()); @@ -906,6 +886,38 @@ Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node, vecto return SUCCESS; } +Status StreamAllocator::UpdateActiveStreamsForActiveNode(const vector> &split_streams, NodePtr &node) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + vector active_streams; + if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + vector new_active_streams = active_streams; + for (uint32_t logical_stream : active_streams) { + if (static_cast(logical_stream) >= split_streams.size()) { + GELOGE(FAILED, "logical stream is out of range."); + return FAILED; + } + const set &new_split_streams = split_streams[logical_stream]; + for (int64_t split_stream : new_split_streams) { + for (const auto &node_stream : node_split_stream_map_) { + if (split_stream == node_stream.second) { + if (node_stream.first->GetOwnerComputeGraph() == node->GetOwnerComputeGraph()) { + new_active_streams.emplace_back(static_cast(split_stream)); + GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", split_stream, + node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); + } + break; + } + } + } + } + if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { // Update active stream list for active nodes for (auto &node_stream_pair : node_split_stream_map_) { @@ -926,14 +938,19 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { } const auto &active_node = it->second; GE_CHECK_NOTNULL(active_node); - auto 
op_desc = active_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); + auto active_op = active_node->GetOpDesc(); + GE_CHECK_NOTNULL(active_op); vector active_streams; - (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); + (void)AttrUtils::GetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); set new_active_streams(active_streams.begin(), active_streams.end()); - new_active_streams.emplace(static_cast(node_stream_pair.second)); + // specific_activated_streams_ has already contained new split activated stream + int64_t new_split_stream = node_stream_pair.second; + if (IsActivated(new_split_stream)) { + continue; + } + new_active_streams.emplace(static_cast(new_split_stream)); active_streams.assign(new_active_streams.begin(), new_active_streams.end()); - if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + if (!AttrUtils::SetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { GELOGE(FAILED, "Set active streams for node %s failed.", active_node->GetName().c_str()); return FAILED; } @@ -942,6 +959,20 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { return SUCCESS; } +bool StreamAllocator::IsActivated(int64_t stream_id) const { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { + auto op_desc = node->GetOpDesc(); + vector active_streams; + if (op_desc == nullptr || !AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + continue; + } + if (std::find(active_streams.begin(), active_streams.end(), stream_id) != active_streams.end()) { + return true; + } + } + return false; +} + Status StreamAllocator::SetActiveStreamsForLoop() { vector loop_active_streams; for (int64_t stream_id = 0; stream_id < stream_num_; stream_id++) { @@ -950,7 +981,7 @@ Status StreamAllocator::SetActiveStreamsForLoop() { } } // Set the stream that needs to be activated - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); bool is_loop_active = false; if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) { @@ -973,7 +1004,7 @@ Status StreamAllocator::SetActiveStreamsForLoop() { } Status StreamAllocator::CheckStreamActived() const { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); vector active_streams; if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { @@ -989,108 +1020,6 @@ Status StreamAllocator::CheckStreamActived() const { return SUCCESS; } -// Add active entry stream for special env. -Status StreamAllocator::AddActiveEntryStream() { - auto gelib = GELib::GetInstance(); - bool head_stream = (gelib == nullptr) ? false : gelib->HeadStream(); - GELOGI("Configured head stream: %u", head_stream); - if (!head_stream) { - return SUCCESS; - } - - // Collect streams active by StreamSwitch/StreamActive node. - std::set deactive_stream; - for (ge::NodePtr &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - Status ret = CollectDeactiveStream(node->GetOpDesc(), deactive_stream); - if (ret != SUCCESS) { - return ret; - } - } - - // Collect default active stream, Add to active entry stream. 
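// --- Illustrative note (not part of this patch) -----------------------------
// What the IsActivated() guard added above prevents: if a newly split stream
// is already named in some node's ATTR_NAME_ACTIVE_STREAM_LIST (for example a
// StreamSwitch already activates it), re-adding it from the subgraph's active
// node would activate the same stream twice, so that pair is now skipped:
//
//   int64_t new_split_stream = node_stream_pair.second;
//   if (IsActivated(new_split_stream)) {
//     continue;  // another node already activates this split stream
//   }
//   new_active_streams.emplace(static_cast<uint32_t>(new_split_stream));
// ----------------------------------------------------------------------------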
- std::vector active_stream_list; - for (int64_t stream_id = 0; stream_id < stream_num_; ++stream_id) { - if (deactive_stream.count(stream_id) == 0) { - active_stream_list.push_back(stream_id); - } - } - - int64_t new_stream_id = stream_num_; - stream_num_++; - return InsertActiveEntryStream(active_stream_list, new_stream_id); -} - -// Collect deactive stream from flowctrl op. -Status StreamAllocator::CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const { - GE_CHECK_NOTNULL(op_desc); - std::string op_type = op_desc->GetType(); - if (op_type == STREAMSWITCH) { - std::vector active_stream_list; - // If GetListInt fail, active_stream_list is empty. - (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list); - if (active_stream_list.size() != kMaxSwitchStreamNum) { - GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kMaxSwitchStreamNum); - return INTERNAL_ERROR; - } - - deactive_streams.insert(active_stream_list[0]); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), active_stream_list[0]); - } else if (op_type == STREAMACTIVE) { - if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { - std::vector active_stream_list; - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { - GELOGE(INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM fail."); - return INTERNAL_ERROR; - } - - for (uint32_t deactive_stream : active_stream_list) { - deactive_streams.insert(deactive_stream); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), deactive_stream); - } - } - } - - return SUCCESS; -} - -// Insert StreamActive Op for Entry Stream. -Status StreamAllocator::InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id) { - string node_name = whole_graph_->GetName() + "_ActiveEntryStream_" + string(STREAMACTIVE); - OpDescPtr op_desc = ge::MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Failed to new opdesc."); - return FAILED; - } - GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); - - GE_CHK_BOOL_EXEC( - AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())), - GELOGE(FAILED, "SetListStr failed."); - return FAILED); - - NodePtr active_node = whole_graph_->AddNodeFront(op_desc); - GE_IF_BOOL_EXEC(active_node == nullptr, - GELOGE(FAILED, "Create StreamActive op: %s failed.", op_desc->GetName().c_str()); - return INTERNAL_ERROR); - GE_CHECK_NOTNULL(active_node->GetOpDesc()); - // Add one stream for ActiveEntryStream Task. 
- active_node->GetOpDesc()->SetStreamId(stream_id); - - GE_CHK_BOOL_EXEC(AttrUtils::SetBool(op_desc, "is_aicpu_stream", true), GELOGE(FAILED, "SetBool failed."); - return FAILED); - GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams), - GELOGE(FAILED, "SetListInt failed."); - return FAILED); - - std::vector group_names; - GE_CHK_BOOL_EXEC(AttrUtils::SetListStr(active_node->GetOpDesc(), ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, group_names), - GELOGE(FAILED, "SetLisStr failed."); - return FAILED); - - return SUCCESS; -} - // Refresh events to continuous events Status StreamAllocator::RefreshContinuousEvents() { // Establish a mapping relationship from old to new event id @@ -1136,7 +1065,7 @@ Status StreamAllocator::RefreshContinuousEvents() { // Insert the real send/recv node in the graph Status StreamAllocator::InsertSyncEventNodes() { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { // Add the node corresponding to the recv event vector recv_event_id_list; GetRecvEventIdList(node, recv_event_id_list); @@ -1223,7 +1152,7 @@ Status StreamAllocator::ReorderEventNodes() const { void StreamAllocator::DumpEvents() { map> after_refresh_stream_nodes; - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); int64_t stream_id = node->GetOpDesc()->GetStreamId(); after_refresh_stream_nodes[stream_id].emplace_back(node); diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index ae79430a..a201a138 100644 --- a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -59,18 +59,16 @@ class StreamAllocator { Status SplitStreams(std::vector> &split_streams); bool NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc) const; - Status UpdateActiveStreams(const std::vector> &splited_streams); + Status UpdateActiveStreams(const std::vector> &split_streams); void UpdateLabelStreams(const std::vector> &split_streams); - Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); + Status UpdateActiveStreamsForSwitchNode(NodePtr &switch_node); Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector &switch_active_nodes); + Status UpdateActiveStreamsForActiveNode(const std::vector> &split_streams, NodePtr &node); Status UpdateActiveStreamsForSubgraphs() const; + bool IsActivated(int64_t stream_id) const; Status SetActiveStreamsForLoop(); Status CheckStreamActived() const; - Status AddActiveEntryStream(); - Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const; - Status InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id); - Status RefreshContinuousEvents(); Status InsertSyncEventNodes(); diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index 2ce4e89d..41a845a2 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -29,6 +29,7 @@ #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "init/gelib.h" using domi::LogTimeStampDef; @@ -47,7 +48,6 @@ const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; const 
char *const kProfilingFpPoint = "FP_POINT"; const char *const kProfilingBpPoint = "BP_POINT"; -const char *const kOffOptimize = "off_optimize"; const uint32_t kProfilingArStep = 2; const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -75,21 +75,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t std::vector task_def_list; std::map op_name_map; GE_DUMP(graph, "GenerateTaskBefore"); - bool is_unknown_shape = false; - NodePtr parent_node = graph->GetParentNode(); - if (parent_node != nullptr) { - auto op_desc = parent_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - (void)AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); - } - Status ret = SUCCESS; - if (is_unknown_shape) { - GELOGI("Beign to generate unknown shape task. Graph name is %s.", graph->GetName().c_str()); - ret = GenerateUnknownShapeTask(run_context, graph, task_def_list, op_name_map); - } else { - GELOGI("Beign to generate known shape task. Graph name is %s.", graph->GetName().c_str()); - ret = GenerateTask(run_context, graph, task_def_list, op_name_map); - } + Status ret = GenerateTask(run_context, graph, task_def_list, op_name_map); GE_DUMP(graph, "GenerateTaskAfter"); if (ret != SUCCESS) { @@ -109,7 +95,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t GELOGE(FAILED, "SetListStr failed."); return FAILED); - GELOGI("Generate task success, task_def_list.size:%zu, op_name_map.size:%zu", task_def_list.size(), + GELOGI("Call GenerateTask Success, task_def_list.size:%zu, op_name_map.size:%zu", task_def_list.size(), op_name_map.size()); // Init and serialize model_task_def @@ -131,7 +117,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t return ret; } - GELOGI("Get TaskInfo success. session id is %lu", session_id); + GELOGI("Get TaskInfo success. 
session_id=%lu", session_id); return SUCCESS; } @@ -198,7 +184,7 @@ Status TaskGenerator::UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t sessi Status TaskGenerator::SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph) { std::map nodes_with_group_attr; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); int64_t group_id = kInvalidGroupId; @@ -249,12 +235,13 @@ Status TaskGenerator::SaveFusionNodes(map> &fusion Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, vector &task_def_list, map &op_name_map) { + GELOGD("Beign to generate task, graph name is %s.", graph->GetName().c_str()); std::shared_ptr ge_lib = GELib::GetInstance(); if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); return GE_CLI_GE_NOT_INITIALIZED; } - GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "Mark node and set index failed."); + GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkNodeAndSetIndex failed."); ProfilingPoint profiling_point; vector all_reduce_nodes; GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); @@ -264,15 +251,21 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_TIMESTAMP_CALLNUM_START(GenerateTask); // map store fusion nodes map> fusion_nodes; - string buffer_optimize = kOffOptimize; + string buffer_optimize = "off_optimize"; (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); - if (buffer_optimize != kOffOptimize) { + if (buffer_optimize != "off_optimize") { GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); } std::unordered_set fusion_nodes_seen; int64_t group_key; uint32_t node_index = 0; - for (auto &node : graph->GetAllNodes()) { + rtStream_t stream = nullptr; + bool is_unknown_shape = graph->GetGraphUnknownFlag(); + if (is_unknown_shape) { + GE_CHK_STATUS_RET(SetUnknownShapeStream(run_context, stream), "Set unknown shape stream failed."); + } + + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); node_index++; @@ -302,7 +295,6 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); continue; } - OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); if (kernel_info_store == nullptr) { GELOGE(INTERNAL_ERROR, "No ops kernel store found. 
node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), @@ -311,18 +303,17 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra } GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(), type.c_str()); - int64_t op_id = op_desc->GetId(); - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id < 0 || stream_id >= static_cast(run_context.graphStreamList.size())) { - GELOGE(INTERNAL_ERROR, "node[name:%s(%s), id:%ld] stream id is invalid, stream list size=%zu", name.c_str(), - type.c_str(), op_id, run_context.graphStreamList.size()); - return INTERNAL_ERROR; - } - // Profiling task size_t task_list_size_before = task_def_list.size(); GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - run_context.stream = run_context.graphStreamList[stream_id]; + int64_t op_id = op_desc->GetId(); + // Compatible with dynamic shape scenes, the default is 0 + int64_t stream_id = 0; + if (!is_unknown_shape) { + stream_id = op_desc->GetStreamId(); + GE_CHK_STATUS_RET(SetKnownShapeStream(run_context, stream_id), "node[name:%s(%s), id:%ld] stream id is invalid.", + name.c_str(), type.c_str(), op_id); + } GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); GE_TIMESTAMP_RESTART(GenerateTask); @@ -355,131 +346,14 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_CHECK_NOTNULL(task_def_ptr); task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); } - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_after - task_list_size_before); } - GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); - return SUCCESS; -} - -Status TaskGenerator::GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, - vector &task_def_list, - map &op_name_map) { - std::shared_ptr ge_lib = GELib::GetInstance(); - if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); - return GE_CLI_GE_NOT_INITIALIZED; - } - GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "Mark node and set index failed."); - ProfilingPoint profiling_point; - vector all_reduce_nodes; - GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); - - const OpsKernelManager &ops_kernel_manager = ge_lib->OpsKernelManagerObj(); - - GE_TIMESTAMP_CALLNUM_START(GenerateTask); - // map store fusion nodes - map> fusion_nodes; - string buffer_optimize = kOffOptimize; - (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); - if (buffer_optimize != kOffOptimize) { - GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); - } - std::unordered_set fusion_nodes_seen; - int64_t group_key; - uint32_t node_index = 0; - rtStream_t stream = nullptr; - GE_CHK_RT_RET(rtStreamCreate(&stream, 0)); - run_context.stream = stream; - if (rtModelBindStream(run_context.model, stream, 0) != RT_ERROR_NONE) { - GELOGE(FAILED, "Call rt api failed."); - GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; - } - for (auto &node : graph->GetAllNodes()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - node_index++; - string name = node->GetName(); - string type = node->GetType(); - bool attr_notask = false; - bool 
get_attr_notask_flag = ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask); - GE_IF_BOOL_EXEC(get_attr_notask_flag && attr_notask, - GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - continue); - - GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID())); - string op_kernel_lib_name = op_desc->GetOpKernelLibName(); - // For fusion ddb pass, task def must be continuous. - // Part2: Call - auto fusion_task_info = - FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, - ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes}; - GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen), - "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str()); - // continue directly - if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key)) { - GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); - continue; - } - if (op_kernel_lib_name.empty()) { - GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - continue; - } - OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); - if (kernel_info_store == nullptr) { - GELOGE(INTERNAL_ERROR, "No ops kernel store found. node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), - type.c_str(), op_kernel_lib_name.c_str()); - return INTERNAL_ERROR; - } - GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(), - type.c_str()); - int64_t op_id = op_desc->GetId(); - int64_t stream_id = op_desc->GetStreamId(); - // Profiling task - size_t task_list_size_before = task_def_list.size(); - GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), - name.c_str(), type.c_str(), op_id, stream_id); - GE_TIMESTAMP_RESTART(GenerateTask); - auto ret = kernel_info_store->GenerateTask(*node, run_context, task_def_list); - GE_TIMESTAMP_ADD(GenerateTask); - if (ret != SUCCESS) { - GELOGE(ret, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task failed.", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); - return ret; - } - // Profiling task - GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - size_t task_list_size_after = task_def_list.size(); - // If tasks is reduced - if (task_list_size_after < task_list_size_before) { - GELOGE(FAILED, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task. 
but task num from %zu to %zu.", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_before, - task_list_size_after); - return FAILED; - } - - // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream - void *ops_kernel_info_store_ptr = kernel_info_store.get(); - for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) { - op_name_map[idx] = name; - // Set opsKernelInfoStorePtr and op_index, the two fields be use in DistributeTask and InitTaskInfo - TaskDef *task_def_ptr = &task_def_list[idx]; - GE_CHECK_NOTNULL(task_def_ptr); - task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); - } - - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, - task_list_size_after - task_list_size_before); + if (is_unknown_shape) { + GE_CHK_STATUS_RET(DestroyUnknownShapeStream(run_context, stream), "Destroy unknown shape stream failed."); } - GE_CHK_RT(rtModelUnbindStream(run_context.model, stream)); - GE_CHK_RT(rtStreamDestroy(stream)); - GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); + GE_TIMESTAMP_CALLNUM_EVENT_END(GenerateTask, "GraphBuild::GenerateTask"); return SUCCESS; } @@ -628,7 +502,11 @@ Status TaskGenerator::MarkNodeAndSetIndex(ComputeGraphPtr &graph) { return GE_CLI_GE_NOT_INITIALIZED; } - const auto all_nodes = graph->GetAllNodes(); + const auto all_nodes = graph->GetNodes(graph->GetGraphUnknownFlag()); + if (all_nodes.empty()) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Graph's node is empty"); + return GE_GRAPH_GRAPH_NODE_NULL; + } int64_t node_index = 0; for (auto &node : all_nodes) { @@ -715,7 +593,7 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP OpDescPtr fp_op_desc = nullptr; uint32_t current_idx = 0; uint32_t first_fp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); string op_kernel_lib_name = op_desc->GetOpKernelLibName(); @@ -742,7 +620,7 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP return SUCCESS; } GELOGI("Find fp_op_desc is %s, id is %ld", fp_op_desc->GetName().c_str(), fp_op_desc->GetId()); - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -763,7 +641,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP uint32_t last_bp = 0; uint32_t iter_end = 0; uint32_t current_idx = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -807,7 +685,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP GE_CHECK_NOTNULL(bp_op_desc); current_idx = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -826,7 +704,7 @@ Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::strin GELOGI("Start FindFpOfEnv"); uint32_t current_idx = 0; uint32_t first_fp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : 
graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(node->GetOpDesc()); current_idx++; @@ -851,7 +729,7 @@ Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::strin uint32_t current_idx = 0; uint32_t iter_end = 0; uint32_t last_bp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(node->GetOpDesc()); current_idx++; @@ -927,10 +805,10 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi bool train_graph = graph->GetNeedIteration(); if (profiling_point.fp_index == 0 && train_graph) { - GELOGE(FAILED, "First forward op name can't be found in graph for training trace."); + GELOGW("First forward op name can't be found in graph for training trace."); } if (profiling_point.bp_index == 0 && train_graph) { - GELOGE(FAILED, "Last backward op name can't be found in graph for training trace."); + GELOGW("Last backward op name can't be found in graph for training trace."); } return SUCCESS; } @@ -1068,4 +946,31 @@ bool TaskGenerator::IsProfPoint(const OpDescPtr &op, const std::string &name) { return false; } +Status TaskGenerator::SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream) { + GE_CHK_RT_RET(rtStreamCreate(&stream, 0)); + run_context.stream = stream; + rtError_t rt_ret = rtModelBindStream(run_context.model, stream, 0); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + GE_CHK_RT_RET(rtStreamDestroy(stream)); + return FAILED; + } + return SUCCESS; +} + +Status TaskGenerator::DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream) { + GE_CHK_RT(rtModelUnbindStream(run_context.model, stream)); + GE_CHK_RT_RET(rtStreamDestroy(stream)); + return SUCCESS; +} + +Status TaskGenerator::SetKnownShapeStream(RunContext &run_context, int64_t stream_id) { + if (stream_id < 0 || stream_id >= static_cast(run_context.graphStreamList.size())) { + GELOGE(INTERNAL_ERROR, "Stream id[%ld] is invalid, stream list size=%zu", stream_id, + run_context.graphStreamList.size()); + return INTERNAL_ERROR; + } + run_context.stream = run_context.graphStreamList[stream_id]; + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h index 02721e00..b2ca4470 100644 --- a/src/ge/graph/build/task_generator.h +++ b/src/ge/graph/build/task_generator.h @@ -94,18 +94,6 @@ class TaskGenerator { std::map &op_name_map); /// - /// call engine to generate unknown shape task. 
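// --- Illustrative sketch (not part of this patch) ---------------------------
// The runtime pairing behind SetUnknownShapeStream()/DestroyUnknownShapeStream()
// introduced above: one temporary stream is created and bound to the model for
// the whole unknown-shape task-generation pass, then unbound and destroyed.
// DoGenerateTasks is a hypothetical stand-in for the per-node loop in
// TaskGenerator::GenerateTask().
Status RunWithTemporaryStream(rtModel_t model, RunContext &run_context) {
  rtStream_t stream = nullptr;
  GE_CHK_RT_RET(rtStreamCreate(&stream, 0));
  run_context.stream = stream;
  if (rtModelBindStream(model, stream, 0) != RT_ERROR_NONE) {
    GE_CHK_RT(rtStreamDestroy(stream));
    return FAILED;
  }
  Status ret = DoGenerateTasks(run_context);  // hypothetical node loop
  GE_CHK_RT(rtModelUnbindStream(model, stream));
  GE_CHK_RT(rtStreamDestroy(stream));
  return ret;
}
// -----------------------------------------------------------------------------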
- /// @param run_context run context - /// @param graph compute graph - /// @param task_def_list task def list generate by engine - /// @param op_name_map relation of task index and op - /// @return SUCCESS:seccess - /// Other: failed - /// - Status GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, - std::vector &task_def_list, std::map &op_name_map); - - /// /// AddModelTaskToModel /// @param model_task_def model task /// @param model_def model @@ -154,6 +142,12 @@ class TaskGenerator { Status SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph); + Status SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream); + + Status DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream); + + Status SetKnownShapeStream(RunContext &run_context, int64_t stream_id); + uint8_t *var_mem_base_ = nullptr; uint64_t var_mem_size_ = 0; }; diff --git a/src/ge/graph/common/ge_call_wrapper.h b/src/ge/graph/common/ge_call_wrapper.h index a21d642e..a2bb6b88 100644 --- a/src/ge/graph/common/ge_call_wrapper.h +++ b/src/ge/graph/common/ge_call_wrapper.h @@ -18,6 +18,41 @@ #define GE_GE_CALL_WRAPPER_H_ #include "framework/common/debug/ge_log.h" +#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() + +#define GE_TIMESTAMP_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define GE_TIMESTAMP_EVENT_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define GE_TIMESTAMP_CALLNUM_START(stage) \ + uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ + uint64_t call_num_of##stage = 0; \ + uint64_t time_of##stage = 0 + +#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) + +#define GE_TIMESTAMP_ADD(stage) \ + time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ + call_num_of##stage++ + +#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ + GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ + call_num_of##stage) + +#define GE_TIMESTAMP_CALLNUM_EVENT_END(stage, stage_name) \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ + call_num_of##stage) + #define RUN_WITH_TIMESTAMP_NAME(var_name, prefix, func, ...) \ do { \ GE_TIMESTAMP_START(var_name); \ @@ -29,10 +64,23 @@ } \ } while (0) +#define RUN_WITH_PERF_TIMESTAMP_NAME(var_name, prefix, func, ...) \ + do { \ + GE_TIMESTAMP_START(var_name); \ + auto ret_inner_macro = func(__VA_ARGS__); \ + GE_TIMESTAMP_EVENT_END(var_name, #prefix "::" #func) \ + if (ret_inner_macro != ge::SUCCESS) { \ + GELOGE(ret_inner_macro, "Failed to process " #prefix "_" #func); \ + return ret_inner_macro; \ + } \ + } while (0) + #define JOIN_NAME_INNER(a, b) a##b #define JOIN_NAME(a, b) JOIN_NAME_INNER(a, b) #define COUNTER_NAME(a) JOIN_NAME(a, __COUNTER__) #define GE_RUN(prefix, func, ...) \ RUN_WITH_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) +#define GE_RUN_PERF(prefix, func, ...) 
\ + RUN_WITH_PERF_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) #endif // GE_GE_CALL_WRAPPER_H_ diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 9293b9af..5ff89c07 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -120,7 +120,7 @@ Status GraphExecutor::FreeInOutBuffer() { } } -Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { +Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { if (malloc_flag_) { auto all_size_same = true; if (buffer_size.size() == buffer_size_.size()) { @@ -169,7 +169,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor graph_input_data.timestamp = 0; std::size_t inputSize = input_tensor.size(); std::size_t output_size = output_desc.size(); - std::vector bufferSizeVec; + std::vector bufferSizeVec; std::vector addrVec; for (std::size_t i = 0; i < inputSize; ++i) { @@ -211,7 +211,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor for (std::size_t j = 0; j < output_size; j++) { auto desc = output_desc[j]; - uint32_t buffer_size = desc.size; + uint64_t buffer_size = desc.size; DataBuffer out_data_buf; out_data_buf.data = reinterpret_cast(addrVec[inputSize + j]); @@ -225,6 +225,13 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + if (model_manager->IsDynamicShape(model_id)) { + GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id); + return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor); + } + // Prepare input and output std::vector inputs_desc; std::vector output_desc; @@ -450,11 +457,13 @@ Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc, - std::vector &input_formats, std::vector &out_formats) { + std::vector &input_formats, std::vector &out_formats, + bool new_model_desc) { try { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc, input_formats, out_formats); + Status ret = model_manager->GetInputOutputDescInfo(model_id, input_desc, output_desc, input_formats, out_formats, + new_model_desc); if (ret != SUCCESS) { GELOGE(ret, "GetInputOutputDescInfo failed."); CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); @@ -573,5 +582,4 @@ Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t inde return SUCCESS; } - } // namespace ge diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index ae467515..6919a439 100644 --- a/src/ge/graph/execute/graph_execute.h +++ b/src/ge/graph/execute/graph_execute.h @@ -71,7 +71,7 @@ class GraphExecutor { static Status GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc, std::vector &input_formats, - std::vector &output_formats); + std::vector &output_formats, bool new_model_desc = false); static Status GetAIPPInfo(uint32_t model_id, uint32_t index, AippConfigInfo &aipp_info); @@ -110,7 +110,7 @@ class GraphExecutor { Status FreeInOutBuffer(); - Status 
MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); + Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); bool init_flag_; @@ -129,7 +129,7 @@ class GraphExecutor { bool malloc_flag_; std::vector buffer_addr_; - std::vector buffer_size_; + std::vector buffer_size_; }; } // namespace ge diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index 1f4cbcf9..4a986308 100644 --- a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -350,7 +350,8 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { return RT_FAILED; } // Add small page memory size - free = static_cast(free_mem + VarManager::Instance(0)->GetUseMaxMemorySize() - total_mem); + free = + static_cast(free_mem + VarManager::Instance(GetContext().SessionId())->GetUseMaxMemorySize() - total_mem); GELOGI("GetMemoryInfo free[%zu], total[%zu], return free[%ld]", free_mem, total_mem, free); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc index 06111015..a0011b34 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc @@ -339,7 +339,7 @@ Status CpuTaskActiveEntry::Distribute() { return RT_FAILED; } - GELOGI("Cpu kernel launch wait end task success."); + GELOGI("Cpu kernel launch active entry task success."); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index 653a3fa1..a4fe8898 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -21,7 +21,6 @@ #include #include -#include "common/debug/log.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -37,9 +36,36 @@ namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; +const int64_t kOpDebugSize = 2048; +const int64_t kOpDebugShape = 2048; +const int8_t kDecimal = 10; +const uint32_t kAddrLen = sizeof(void *); const char *const kDumpOutput = "output"; const char *const kDumpInput = "input"; const char *const kDumpAll = "all"; + +// parse for format like nodename:input:index +static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, std::string &input_or_output, + size_t &index) { + auto sep = node_name_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + auto index_str = node_name_index.substr(sep + 1); + index = static_cast(std::strtol(index_str.c_str(), nullptr, kDecimal)); + auto node_name_without_index = node_name_index.substr(0, sep); + sep = node_name_without_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + node_name = node_name_without_index.substr(0, sep); + input_or_output = node_name_without_index.substr(sep + 1); + return !(input_or_output != kDumpInput && input_or_output != kDumpOutput); +} + +static bool IsTensorDescWithSkipDumpAddrType(bool has_mem_type_attr, vector v_memory_type, size_t i) { + return has_mem_type_attr && (v_memory_type[i] == RT_MEMORY_L1); +} } // namespace static int32_t GetIrDataType(ge::DataType data_type) { @@ -138,6 +164,13 @@ void DataDumper::SaveEndGraphId(uint32_t task_id, uint32_t stream_id) { end_graph_stream_id_ = stream_id; } +void DataDumper::SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug) { + 
op_debug_task_id_ = task_id; + op_debug_stream_id_ = stream_id; + op_debug_addr_ = op_debug_addr; + is_op_debug_ = is_op_debug; +} + void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { if (op_desc == nullptr) { @@ -202,56 +235,121 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin } } -Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { - GELOGI("Start dump output"); - if (inner_dump_info.is_task) { - // tbe or aicpu op - const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); - const auto input_size = inner_dump_info.op->GetAllInputsDesc().size(); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); - if (output_descs.size() != output_addrs.size()) { - GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), - inner_dump_info.op->GetName().c_str(), output_descs.size()); - return PARAM_INVALID; - } +Status DataDumper::GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index) { + output.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); + output.set_format(static_cast(tensor_descs.at(index).GetFormat())); - for (size_t i = 0; i < output_descs.size(); ++i) { - aicpu::dump::Output output; - output.set_data_type(static_cast(GetIrDataType(output_descs.at(i).GetDataType()))); - output.set_format(static_cast(output_descs.at(i).GetFormat())); + for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { + output.mutable_shape()->add_dim(dim); + } + int64_t output_size = 0; + if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size failed"); + return PARAM_INVALID; + } + GELOGD("Get output size in dump is %ld", output_size); + std::string origin_name; + int32_t origin_output_index = -1; + (void)AttrUtils::GetStr(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); + (void)AttrUtils::GetInt(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + output.set_size(output_size); + output.set_original_name(origin_name); + output.set_original_output_index(origin_output_index); + output.set_original_output_format(static_cast(tensor_descs.at(index).GetOriginFormat())); + output.set_original_output_data_type(static_cast(tensor_descs.at(index).GetOriginDataType())); + output.set_address(static_cast(addr)); + return SUCCESS; +} - for (auto dim : output_descs.at(i).GetShape().GetDims()) { - output.mutable_shape()->add_dim(dim); - } +Status DataDumper::DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output &output, + size_t i, const std::string &node_name_index) { + std::string dump_op_name; + std::string input_or_output; + size_t index; + // parse and find which node's input or output tensor desc is chosen for dump info + if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { + GELOGE(PARAM_INVALID, "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); + return PARAM_INVALID; + } + GE_CHECK_NOTNULL(compute_graph_); + auto replace_node = compute_graph_->FindNode(dump_op_name); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, + "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF 
attr[%s]," + " cannot find redirect node[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), + dump_op_name.c_str()); + auto replace_opdesc = replace_node->GetOpDesc(); + GE_CHECK_NOTNULL(replace_opdesc); + auto iter = ref_info_.find(replace_opdesc); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), + "Op [%s] output desc[%zu] cannot find any saved redirect node[%s]'s info.", + inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); + GE_CHECK_NOTNULL(iter->second); + auto addr = reinterpret_cast(iter->second); + if (input_or_output == kDumpInput) { + const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); + addr += kAddrLen * index; + GE_CHK_STATUS_RET(GenerateOutput(output, replace_input_descs, addr, index), "Generate output failed"); + } else if (input_or_output == kDumpOutput) { + const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); + const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); + addr += (index + replace_input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateOutput(output, replace_output_descs, addr, index), "Generate output failed"); + } + GELOGD("Op [%s] output desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", + inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); + return SUCCESS; +} - int64_t output_size = 0; - if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get output size filed"); - return PARAM_INVALID; - } - GELOGI("Get output size in dump is %ld", output_size); - std::string origin_name; - int32_t origin_output_index = -1; - (void)AttrUtils::GetStr(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - (void)AttrUtils::GetInt(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - GE_IF_BOOL_EXEC(output_size <= 0, GELOGE(PARAM_INVALID, "Output size %ld is less than zero", output_size); - return PARAM_INVALID) - output.set_size(output_size); - output.set_original_name(origin_name); - output.set_original_output_index(origin_output_index); - output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); - output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); - output.set_address(static_cast(inner_dump_info.args + (i + input_size) * sizeof(void *))); - - task.mutable_output()->Add(std::move(output)); +Status DataDumper::DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); + if (output_descs.size() != output_addrs.size()) { + GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), + inner_dump_info.op->GetName().c_str(), output_descs.size()); + return PARAM_INVALID; + } + std::vector v_memory_type; + bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && (v_memory_type.size() != output_descs.size()), + "DumpOutputWithTask[%s], output size[%zu], output memory type size[%zu]", + inner_dump_info.op->GetName().c_str(), output_descs.size(), + v_memory_type.size()); + + for (size_t i = 0; i < output_descs.size(); ++i) { + aicpu::dump::Output output; + std::string 
node_name_index; + const auto &output_desc = output_descs.at(i); + // check dump output tensor desc is redirected by attr ATTR_DATA_DUMP_REF + if (AttrUtils::GetStr(&output_desc, ATTR_DATA_DUMP_REF, node_name_index)) { + GE_CHK_STATUS_RET(DumpRefOutput(inner_dump_info, output, i, node_name_index), "DumpRefOutput failed"); + } else { + GE_IF_BOOL_EXEC( + IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), + GELOGD("DumpOutputWithTask[%s] output[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); + continue;); + + const auto input_size = inner_dump_info.op->GetInputsSize(); + auto addr = inner_dump_info.args + (i + input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateOutput(output, output_descs, addr, i), "Generate output failed"); } - return SUCCESS; + task.mutable_output()->Add(std::move(output)); } + return SUCCESS; +} +Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + GELOGI("Start dump output"); + if (inner_dump_info.is_task) { + // tbe or aicpu op, these ops are with task + return DumpOutputWithTask(inner_dump_info, task); + } // else data, const or variable op aicpu::dump::Output output; auto output_tensor = inner_dump_info.op->GetOutputDescPtr(inner_dump_info.output_anchor_index); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); if (output_tensor == nullptr) { GELOGE(PARAM_INVALID, "output_tensor is null, index: %d, size: %zu.", inner_dump_info.output_anchor_index, inner_dump_info.op->GetOutputsSize()); @@ -269,9 +367,6 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: int32_t origin_output_index = -1; (void)AttrUtils::GetStr(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); (void)AttrUtils::GetInt(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - GE_IF_BOOL_EXEC(inner_dump_info.data_size <= 0, - GELOGE(PARAM_INVALID, "The size of data %ld is less than zero", inner_dump_info.data_size); - return PARAM_INVALID) output.set_size(inner_dump_info.data_size); output.set_original_name(origin_name); output.set_original_output_index(origin_output_index); @@ -282,7 +377,7 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: GELOGE(FAILED, "Index is out of range."); return FAILED; } - auto data_addr = inner_dump_info.args + sizeof(void *) * static_cast(inner_dump_info.input_anchor_index); + auto data_addr = inner_dump_info.args + kAddrLen * static_cast(inner_dump_info.input_anchor_index); output.set_address(static_cast(data_addr)); task.mutable_output()->Add(std::move(output)); @@ -290,37 +385,98 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: return SUCCESS; } +Status DataDumper::GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index) { + input.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); + input.set_format(static_cast(tensor_descs.at(index).GetFormat())); + + for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { + input.mutable_shape()->add_dim(dim); + } + int64_t input_size = 0; + if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { + GELOGI("Get aipp input size according to attr is %ld", input_size); + } else if 
(TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), input_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get input size failed"); + return PARAM_INVALID; + } + GELOGD("Get input size in dump is %ld", input_size); + input.set_size(input_size); + input.set_address(static_cast(addr)); + return SUCCESS; +} + +Status DataDumper::DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, + const std::string &node_name_index) { + std::string dump_op_name; + std::string input_or_output; + size_t index; + // parse and find which node's input or output tensor desc is chosen for dump info + if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { + GELOGE(PARAM_INVALID, "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); + return PARAM_INVALID; + } + GE_CHECK_NOTNULL(compute_graph_); + auto replace_node = compute_graph_->FindNode(dump_op_name); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, + "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s]," + " cannot find redirect node[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), + dump_op_name.c_str()); + auto replace_opdesc = replace_node->GetOpDesc(); + GE_CHECK_NOTNULL(replace_opdesc); + auto iter = ref_info_.find(replace_opdesc); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), + "Op [%s] input desc[%zu] cannot find any saved redirect node[%s]'s info.", + inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); + GE_CHECK_NOTNULL(iter->second); + auto addr = reinterpret_cast(iter->second); + if (input_or_output == kDumpInput) { + const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); + addr += kAddrLen * index; + GE_CHK_STATUS_RET(GenerateInput(input, replace_input_descs, addr, index), "Generate input failed"); + } else if (input_or_output == kDumpOutput) { + const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); + const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); + addr += (index + replace_input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateInput(input, replace_output_descs, addr, index), "Generate input failed"); + } + GELOGD("Op [%s] input desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", + inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); + return SUCCESS; +} + Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { GELOGI("Start dump input"); const auto &input_descs = inner_dump_info.op->GetAllInputsDesc(); - const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op, false); + const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op); if (input_descs.size() != input_addrs.size()) { GELOGE(PARAM_INVALID, "Invalid input desc addrs size %zu, op %s has %zu input desc.", input_addrs.size(), inner_dump_info.op->GetName().c_str(), input_descs.size()); return PARAM_INVALID; } + std::vector v_memory_type; + bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && (v_memory_type.size() != input_descs.size()), + "DumpInput[%s], input size[%zu], input memory type size[%zu]", + inner_dump_info.op->GetName().c_str(), input_descs.size(), v_memory_type.size()); for
(size_t i = 0; i < input_descs.size(); ++i) { aicpu::dump::Input input; - input.set_data_type(static_cast(GetIrDataType(input_descs.at(i).GetDataType()))); - input.set_format(static_cast(input_descs.at(i).GetFormat())); - - for (auto dim : input_descs.at(i).GetShape().GetDims()) { - input.mutable_shape()->add_dim(dim); + std::string node_name_index; + // check dump input tensor desc is redirected by attr ATTR_DATA_DUMP_REF + if (AttrUtils::GetStr(&input_descs.at(i), ATTR_DATA_DUMP_REF, node_name_index)) { + GE_CHK_STATUS_RET(DumpRefInput(inner_dump_info, input, i, node_name_index), "DumpRefInput failed"); + // normal dump without attr + } else { + GE_IF_BOOL_EXEC(IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), + GELOGD("DumpInput[%s] input[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); + continue;); + + auto addr = inner_dump_info.args + kAddrLen * i; + GE_CHK_STATUS_RET(GenerateInput(input, input_descs, addr, i), "Generate input failed"); } - - int64_t input_size = 0; - if (AttrUtils::GetInt(&input_descs.at(i), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { - GELOGI("Get aipp input size according to attr is %ld", input_size); - } else if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get input size filed"); - return PARAM_INVALID; - } - GELOGI("Get input size in dump is %ld", input_size); - GE_IF_BOOL_EXEC(input_size <= 0, GELOGE(PARAM_INVALID, "Input size %ld is less than zero", input_size); - return PARAM_INVALID;) - input.set_size(input_size); - input.set_address(static_cast(inner_dump_info.args + sizeof(void *) * i)); task.mutable_input()->Add(std::move(input)); } return SUCCESS; @@ -400,36 +556,38 @@ Status DataDumper::ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_ GELOGI("UnloadDumpInfo success, proto size is: %zu.", proto_size); return SUCCESS; } + Status DataDumper::LoadDumpInfo() { std::string dump_list_key; PrintCheckLog(dump_list_key); if (op_list_.empty()) { - return SUCCESS; + GELOGW("op_list_ is empty"); } aicpu::dump::OpMappingInfo op_mapping_info; - auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); - op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); + auto dump_path = dump_properties_.GetDumpPath() + std::to_string(device_id_) + "/"; + op_mapping_info.set_dump_path(dump_path); op_mapping_info.set_model_name(dump_list_key); op_mapping_info.set_model_id(model_id_); op_mapping_info.set_flag(kAicpuLoadFlag); - op_mapping_info.set_dump_step(PropertiesManager::Instance().GetDumpStep()); + op_mapping_info.set_dump_step(dump_properties_.GetDumpStep()); SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); - GELOGI("Dump step is %s and dump path is %s in load dump info", PropertiesManager::Instance().GetDumpStep().c_str(), + GELOGI("Dump step is %s and dump path is %s in load dump info", dump_properties_.GetDumpStep().c_str(), dump_path.c_str()); for (const auto &op_iter : op_list_) { - aicpu::dump::Task task; auto op_desc = op_iter.op; + GELOGD("Op %s in model %s begin to add task in op_mapping_info", op_desc->GetName().c_str(), dump_list_key.c_str()); + aicpu::dump::Task task; task.set_end_graph(false); task.set_task_id(op_iter.task_id); task.set_stream_id(op_iter.stream_id); task.mutable_op()->set_op_name(op_desc->GetName()); task.mutable_op()->set_op_type(op_desc->GetType()); - if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput) { + if 
(dump_properties_.GetDumpMode() == kDumpOutput) { if (DumpOutput(op_iter, task) != SUCCESS) { GELOGE(FAILED, "Dump output failed"); return FAILED; @@ -437,7 +595,7 @@ Status DataDumper::LoadDumpInfo() { op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (PropertiesManager::Instance().GetDumpMode() == kDumpInput) { + if (dump_properties_.GetDumpMode() == kDumpInput) { if (op_iter.is_task) { if (DumpInput(op_iter, task) != SUCCESS) { GELOGE(FAILED, "Dump input failed"); @@ -447,7 +605,7 @@ Status DataDumper::LoadDumpInfo() { op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + if (dump_properties_.GetDumpMode() == kDumpAll) { auto ret = DumpOutput(op_iter, task); if (ret != SUCCESS) { GELOGE(FAILED, "Dump output failed when in dumping all"); @@ -467,19 +625,22 @@ Status DataDumper::LoadDumpInfo() { SetEndGraphIdToAicpu(end_graph_task_id_, end_graph_stream_id_, op_mapping_info); - auto ret = ExecuteLoadDumpInfo(op_mapping_info); - if (ret != SUCCESS) { - GELOGE(FAILED, "Execute load dump info failed"); - return FAILED; + SetOpDebugIdToAicpu(op_debug_task_id_, op_debug_stream_id_, op_debug_addr_, op_mapping_info); + + if (!op_list_.empty() || is_op_debug_) { + auto ret = ExecuteLoadDumpInfo(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "Execute load dump info failed"); + return FAILED; + } } return SUCCESS; } void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info) { - if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput || - PropertiesManager::Instance().GetDumpMode() == kDumpInput || - PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + if (dump_properties_.GetDumpMode() == kDumpOutput || dump_properties_.GetDumpMode() == kDumpInput || + dump_properties_.GetDumpMode() == kDumpAll) { GELOGI("Add end_graph_info to aicpu, task_id is %u, stream_id is %u", end_graph_task_id_, end_graph_stream_id_); aicpu::dump::Task task; task.set_end_graph(true); @@ -491,6 +652,37 @@ void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, } } +void DataDumper::SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, + aicpu::dump::OpMappingInfo &op_mapping_info) { + if (is_op_debug_) { + GELOGI("add op_debug_info to aicpu, task_id is %u, stream_id is %u", task_id, stream_id); + aicpu::dump::Task task; + task.set_end_graph(false); + task.set_task_id(task_id); + task.set_stream_id(stream_id); + task.mutable_op()->set_op_name(NODE_NAME_OP_DEBUG); + task.mutable_op()->set_op_type(OP_TYPE_OP_DEBUG); + + // set output + aicpu::dump::Output output; + output.set_data_type(DT_UINT8); + output.set_format(FORMAT_ND); + + output.mutable_shape()->add_dim(kOpDebugShape); + + output.set_original_name(NODE_NAME_OP_DEBUG); + output.set_original_output_index(0); + output.set_original_output_format(FORMAT_ND); + output.set_original_output_data_type(DT_UINT8); + // due to lhisi virtual addr bug, cannot use args now + output.set_address(static_cast(reinterpret_cast(op_debug_addr))); + output.set_size(kOpDebugSize); + + task.mutable_output()->Add(std::move(output)); + op_mapping_info.mutable_task()->Add(std::move(task)); + } +} + Status DataDumper::UnloadDumpInfo() { if (!load_flag_) { GELOGI("No need to UnloadDumpInfo."); @@ -517,15 +709,17 @@ Status DataDumper::UnloadDumpInfo() { } void DataDumper::PrintCheckLog(string &dump_list_key) { - std::set model_list = 
PropertiesManager::Instance().GetAllDumpModel(); + std::set model_list = dump_properties_.GetAllDumpModel(); if (model_list.empty()) { GELOGI("No model need dump."); return; } - GELOGI("%zu op need dump in %s.", op_list_.size(), model_name_.c_str()); bool not_find_by_omname = model_list.find(om_name_) == model_list.end(); bool not_find_by_modelname = model_list.find(model_name_) == model_list.end(); + dump_list_key = not_find_by_omname ? model_name_ : om_name_; + GELOGI("%zu op need dump in %s.", op_list_.size(), dump_list_key.c_str()); + if (model_list.find(DUMP_ALL_MODEL) == model_list.end()) { if (not_find_by_omname && not_find_by_modelname) { std::string model_list_str; @@ -533,12 +727,12 @@ void DataDumper::PrintCheckLog(string &dump_list_key) { model_list_str += "[" + model + "]."; } - GELOGW("Model %s will not be set to dump, dump list: %s", model_name_.c_str(), model_list_str.c_str()); + GELOGW("Model %s will not be set to dump, dump list: %s", dump_list_key.c_str(), model_list_str.c_str()); return; } } - dump_list_key = not_find_by_omname ? model_name_ : om_name_; - std::set config_dump_op_list = PropertiesManager::Instance().GetDumpPropertyValue(dump_list_key); + + std::set config_dump_op_list = dump_properties_.GetPropertyValue(dump_list_key); std::set dump_op_list; for (auto &inner_dump_info : op_list_) { // oplist value OpDescPtr is not nullptr diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index ee5b3241..0648a8ce 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -23,7 +23,9 @@ #include #include "framework/common/ge_inner_error_codes.h" +#include "common/properties_manager.h" #include "graph/node.h" +#include "graph/compute_graph.h" #include "proto/ge_ir.pb.h" #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" @@ -44,7 +46,9 @@ class DataDumper { device_id_(0), global_step_(0), loop_per_iter_(0), - loop_cond_(0) {} + loop_cond_(0), + compute_graph_(nullptr), + ref_info_() {} ~DataDumper(); @@ -56,6 +60,10 @@ class DataDumper { void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }; + + void SetRefInfo(const std::map &ref_info) { ref_info_ = ref_info; }; + void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); void SaveDumpInput(const std::shared_ptr &node); @@ -65,11 +73,15 @@ class DataDumper { void SaveEndGraphId(uint32_t task_id, uint32_t stream_id); void SetOmName(const std::string &om_name) { om_name_ = om_name; } + void SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug); Status LoadDumpInfo(); Status UnloadDumpInfo(); + void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } + const DumpProperties &GetDumpProperties() const { return dump_properties_; } + private: void ReleaseDevMem(void **ptr) noexcept; @@ -97,12 +109,32 @@ class DataDumper { uintptr_t global_step_; uintptr_t loop_per_iter_; uintptr_t loop_cond_; + ComputeGraphPtr compute_graph_; + std::map ref_info_; + + uint32_t op_debug_task_id_ = 0; + uint32_t op_debug_stream_id_ = 0; + void *op_debug_addr_ = nullptr; + bool is_op_debug_ = false; + + DumpProperties dump_properties_; Status DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output 
&output, size_t i, + const std::string &node_name_index); + Status DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); Status DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, + const std::string &node_name_index); Status ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); void SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info); + void SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, + aicpu::dump::OpMappingInfo &op_mapping_info); Status ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); + Status GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index); + Status GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index); }; struct DataDumper::InnerDumpInfo { uint32_t task_id; diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index 45acee07..c43c37eb 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -42,11 +42,11 @@ #include "graph/graph.h" #include "graph/load/new_model_manager/cpu_queue_schedule.h" #include "graph/load/new_model_manager/tbe_handle_store.h" -#include "graph/load/output/output.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/manager/util/debug.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/model_serialize.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" @@ -59,6 +59,7 @@ #include "runtime/event.h" #include "runtime/mem.h" #include "runtime/stream.h" +#include "runtime/rt_model.h" #include "securec.h" // create std::thread, catch exceptions using try/catch @@ -80,7 +81,6 @@ const uint32_t kOutputNum = 1; const uint32_t kTrueBranchStreamNum = 1; const uint32_t kThreadNum = 16; const uint32_t kAddrLen = sizeof(void *); -const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; @@ -89,42 +89,10 @@ const char *const kDefaultBatchLable = "Batch_default"; inline bool IsDataOp(const std::string &node_type) { return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; } -inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { - bool skip_task_generate = false; - (void)ge::AttrUtils::GetBool(op_desc, ATTR_NO_TASK_AND_DUMP_NEEDED, skip_task_generate); - return skip_task_generate; -} - -void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input) { - uint32_t n, c, h, w; - n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; - c = format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C; - h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; - w = format == FORMAT_NHWC ? 
NHWC_DIM_W : NCHW_DIM_W; - - if (!op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { - if (op_desc->GetInputDescPtr(0)->GetShape().GetDimNum() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = op_desc->GetInputDescPtr(0)->GetShape().GetDim(n); - input.shape_info.height = op_desc->GetInputDescPtr(0)->GetShape().GetDim(h); - input.shape_info.width = op_desc->GetInputDescPtr(0)->GetShape().GetDim(w); - input.shape_info.channel = op_desc->GetInputDescPtr(0)->GetShape().GetDim(c); - } - for (size_t k = 0; k < op_desc->GetInputDescPtr(0)->GetShape().GetDimNum(); k++) { - input.shape_info.dims.push_back(op_desc->GetInputDescPtr(0)->GetShape().GetDim(k)); - } - } else { - vector origin_input_dims; - (void)AttrUtils::GetListInt(op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); - if (origin_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { - input.shape_info.num = origin_input_dims[n]; - input.shape_info.height = origin_input_dims[h]; - input.shape_info.width = origin_input_dims[w]; - input.shape_info.channel = origin_input_dims[c]; - } - for (size_t k = 0; k < origin_input_dims.size(); ++k) { - input.shape_info.dims.push_back(origin_input_dims[k]); - } - } +inline bool IsNoTaskAndDumpNeeded(const OpDescPtr &op_desc) { + bool save_dump_info = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NO_TASK_AND_DUMP_NEEDED, save_dump_info); + return save_dump_info; } } // namespace @@ -157,10 +125,10 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptrGetModelTaskDefPtr(); return SUCCESS; } +/// +/// @ingroup ge +/// @brief Reduce memory usage after task sink. +/// @return: void +/// +void DavinciModel::Shrink() { + ge_model_.reset(); // delete object. + + // Old dump needs the op list; clear it when dump is closed. + char *ge_dump_env = std::getenv("DUMP_OP"); + int dump_op_switch = (ge_dump_env != nullptr) ?
std::strtol(ge_dump_env, nullptr, kDecimal) : 0; + if (dump_op_switch == 0) { + op_list_.clear(); + } +} + Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { if (is_model_has_inited_) { GELOGI("call InitModelMem more than once ."); return FAILED; } is_model_has_inited_ = true; - std::size_t data_size = TotalMemSize(); - ge::Buffer weights = ge_model_->GetWeight(); - uint8_t *weights_addr = weights.GetData(); + std::size_t data_size = TotalMemSize(); + const Buffer &weights = ge_model_->GetWeight(); std::size_t weights_size = weights.GetSize(); - GE_CHECK_LE(weights_size, ALLOC_MEMORY_MAX_SIZE); if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { @@ -312,7 +308,7 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p } GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[W] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, weights_mem_base_, weights_size); - GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights_addr, weights_size, RT_MEMCPY_HOST_TO_DEVICE)) + GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights.GetData(), weights_size, RT_MEMCPY_HOST_TO_DEVICE)); GELOGI("copy weights data to device"); } @@ -367,19 +363,15 @@ void DavinciModel::InitRuntimeParams() { session_id_ = runtime_param_.session_id; GELOGI( - "InitRuntimeParams(), memory_size:%lu, weight_size:%lu, session_id:%u, var_size:%lu, logic_var_base:%lu, " - "logic_mem_base:%lu.", - runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.session_id, runtime_param_.var_size, - runtime_param_.logic_var_base, runtime_param_.logic_mem_base); - - GELOGI("InitRuntimeParams(), stream_num:%lu, event_num:%u, label_num:%u", runtime_param_.stream_num, - runtime_param_.event_num, runtime_param_.label_num); + "InitRuntimeParams(), session_id:%u, stream_num:%lu, event_num:%u, label_num:%u, " + "logic_mem_base:0x%lx, logic_weight_base:0x%lx, logic_var_base:0x%lx, " + "memory_size:%lu, weight_size:%lu, var_size:%lu", + runtime_param_.session_id, runtime_param_.stream_num, runtime_param_.event_num, runtime_param_.label_num, + runtime_param_.logic_mem_base, runtime_param_.logic_weight_base, runtime_param_.logic_var_base, + runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.var_size); } void DavinciModel::CheckHasHcomOp() { - // definiteness queue schedule, all stream by TS. - GE_IF_BOOL_EXEC(!input_queue_ids_.empty() || !output_queue_ids_.empty(), return ); - Graph graph = ge_model_->GetGraph(); auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph == nullptr) { @@ -395,11 +387,6 @@ void DavinciModel::CheckHasHcomOp() { (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); - - bool is_aicpu_stream = false; - GE_IF_BOOL_EXEC(AttrUtils::GetBool(op_desc, "is_aicpu_stream", is_aicpu_stream) && is_aicpu_stream, - uint32_t stream_id = static_cast(op_desc->GetStreamId()); - (void)aicpu_streams_.emplace(stream_id); GELOGD("aicpu stream: %u.", stream_id); continue); } } @@ -410,20 +397,13 @@ void DavinciModel::CheckHasHcomOp() { /// Status DavinciModel::BindModelStream() { // Stream not in active_stream_indication_ is active stream. - if (!input_queue_ids_.empty() || !output_queue_ids_.empty()) { - // Asynchronous Queue, need add S0, deactive all model stream. 
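In the rewritten branch just below, queue-based I/O and cross-thread AICPU deployment take the same path: every stream not already marked active is deactivated by default and collected for explicit activation by the head stream. A minimal sketch of that selection, assuming rtStream_t is the runtime handle type (SelectStreamsToActivate is an illustrative name, not a GE function):

    #include <set>
    #include <vector>

    using rtStream_t = void *;  // stand-in for the runtime stream handle type

    // Every stream index not yet in active_stream_indication is marked there
    // (deactivated by default) and collected for later explicit activation.
    static std::vector<rtStream_t> SelectStreamsToActivate(const std::vector<rtStream_t> &stream_list,
                                                           std::set<size_t> &active_stream_indication) {
      std::vector<rtStream_t> to_activate;
      for (size_t i = 0; i < stream_list.size(); ++i) {
        if (active_stream_indication.insert(i).second) {  // true only for newly marked indices
          to_activate.push_back(stream_list[i]);
        }
      }
      return to_activate;
    }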
+ if ((!input_queue_ids_.empty() || !output_queue_ids_.empty()) || (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD)) { for (size_t i = 0; i < stream_list_.size(); ++i) { if (active_stream_indication_.count(i) == 0) { active_stream_list_.push_back(stream_list_[i]); active_stream_indication_.insert(i); // deactive all model stream. } } - } else { - for (size_t i = 0; i < stream_list_.size(); ++i) { - if (active_stream_indication_.count(i) == 0) { - active_stream_list_.push_back(stream_list_[i]); - } - } } for (size_t i = 0; i < stream_list_.size(); ++i) { @@ -441,23 +421,29 @@ Status DavinciModel::BindModelStream() { Status DavinciModel::DoTaskSink() { // task sink is supported as model_task_def is set - if (model_task_def_) { - GELOGI("do task_sink."); - GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); + if (model_task_def == nullptr) { + return SUCCESS; + } - if (known_node_) { - GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); - } + GE_CHK_RT_RET(rtGetAicpuDeploy(&deploy_type_)); + GELOGI("do task_sink. AiCpu deploy type is: %x.", deploy_type_); - GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); + GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); - GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); + if (known_node_) { + GE_CHK_STATUS_RET(MallocKnownArgs(), "Malloc known node args failed."); + } - GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); + GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); - GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); - } + GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); + + GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); + + GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); + SetCopyOnlyOutput(); return SUCCESS; } @@ -475,12 +461,96 @@ Status DavinciModel::SetTSDevice() { return SUCCESS; } +Status DavinciModel::OpDebugRegister() { + bool is_op_debug = false; + (void)ge::AttrUtils::GetBool(ge_model_, ATTR_OP_DEBUG_FLAG, is_op_debug); + GELOGD("The value of op_debug in ge_model_ is %d.", is_op_debug); + if (is_op_debug) { + debug_reg_mutex_.lock(); + rtError_t rt_ret = rtMalloc(&op_debug_addr_, kOpDebugMemorySize, RT_MEMORY_DDR); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); + return RT_FAILED; + } + + uint64_t debug_addrs_tmp = static_cast(reinterpret_cast(op_debug_addr_)); + + // For data dump, aicpu needs the pointer to pointer that saves the real debug address.
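+ // Layout sketch, inferred from the two copies below rather than from any GE documentation:
+ //   op_debug_addr_  : device buffer the debug task writes into (kOpDebugMemorySize bytes);
+ //   p2p_debug_addr_ : an 8-byte device slot holding the value of op_debug_addr_, which is
+ //                     what SaveOpDebugId later hands to aicpu for dereferencing.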
+ rt_ret = rtMalloc(&p2p_debug_addr_, kDebugP2pSize, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); + return RT_FAILED; + } + rt_ret = rtMemcpy(p2p_debug_addr_, sizeof(uint64_t), &debug_addrs_tmp, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "rtMemcpy to p2p_addr error: 0x%X", rt_ret); + return FAILED; + } + + uint32_t op_debug_mode = 0; + (void)ge::AttrUtils::GetInt(ge_model_, ATTR_OP_DEBUG_MODE, op_debug_mode); + GELOGD("The value of op_debug_mode in ge_model_ is %u.", op_debug_mode); + uint32_t debug_task_id = 0; + uint32_t debug_stream_id = 0; + rt_ret = rtDebugRegister(rt_model_handle_, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtDebugRegister error, ret: 0x%X", rt_ret); + return RT_FAILED; + } + GELOGI("debug_task_id:%u, debug_stream_id:%u", debug_task_id, debug_stream_id); + is_op_debug_reg_ = true; + + data_dumper_.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, is_op_debug); + } + + return SUCCESS; +} + +void DavinciModel::OpDebugUnRegister() { + GELOGI("OpDebugUnRegister, is_op_debug_reg_ = %d", is_op_debug_reg_); + if (is_op_debug_reg_) { + debug_reg_mutex_.unlock(); + + rtError_t rt_ret = RT_ERROR_NONE; + if (rt_model_handle_ != nullptr) { + rt_ret = rtDebugUnRegister(rt_model_handle_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtDebugUnRegister failed, ret: 0x%X", rt_ret); + } + } + + if (op_debug_addr_ != nullptr) { + rt_ret = rtFree(op_debug_addr_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtFree failed, ret: 0x%X", rt_ret); + } + op_debug_addr_ = nullptr; + } + + if (p2p_debug_addr_ != nullptr) { + rt_ret = rtFree(p2p_debug_addr_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtFree failed, ret: 0x%X", rt_ret); + } + p2p_debug_addr_ = nullptr; + } + + is_op_debug_reg_ = false; + } + + return; +} + // initialize op sequence and call initialization function of each op respectively Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { // validating params GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(priority_ < 0 || priority_ > 7, return PARAM_INVALID, "Priority must between 0-7, now is %d", priority_); GE_CHK_BOOL_RET_STATUS(ge_model_ != nullptr, PARAM_INVALID, "GeModel is null."); + Graph graph = ge_model_->GetGraph(); + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_RET_STATUS(compute_graph != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); + // Initializing runtime_param_ InitRuntimeParams(); @@ -509,8 +579,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size if (hcom_streams_.find(i) != hcom_streams_.end()) { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_FORCE_COPY)); - } else if (aicpu_streams_.find(i) != aicpu_streams_.end()) { - GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_AICPU)); } else { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags)); } @@ -531,20 +599,19 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); GE_CHK_RT_RET(rtModelGetId(rt_model_handle_, &runtime_model_id_)); + // inference will use default graph_id 0; + runtime_param_.graph_id = compute_graph->GetGraphID(); - Graph graph = ge_model_->GetGraph(); - compute_graph_ =
GraphUtils::GetComputeGraph(graph); - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); - - runtime_param_.graph_id = compute_graph_->GetGraphID(); + // op debug register + GE_CHK_STATUS_RET(OpDebugRegister(), "OpDebugRegister failed"); GE_TIMESTAMP_START(TransAllVarData); - GE_CHK_STATUS_RET(TransAllVarData(compute_graph_, runtime_param_.graph_id), "TransAllVarData failed."); + GE_CHK_STATUS_RET(TransAllVarData(compute_graph, runtime_param_.graph_id), "TransAllVarData failed."); GE_TIMESTAMP_END(TransAllVarData, "GraphLoader::TransAllVarData"); - GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_), "copy var data failed."); GE_TIMESTAMP_START(InitModelMem); - GELOGI("known_node is %d", known_node_); + GELOGI("Known node is %d", known_node_); if (!known_node_) { GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); data_inputer_ = new (std::nothrow) DataInputer(); @@ -552,14 +619,16 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size } GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); + for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + GetFixedAddrAttr(op_desc); + GE_IF_BOOL_EXEC(op_desc->GetType() != VARIABLE, continue); GE_IF_BOOL_EXEC(IsBroadCastOpData(node), - (void)ge::AttrUtils::SetStr(node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); + (void)ge::AttrUtils::SetStr(op_desc, VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); } // for profiling - op_name_map_ = compute_graph_->GetGraphOpName(); + op_name_map_ = compute_graph->GetGraphOpName(); vector op_name; GE_IF_BOOL_EXEC(ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_TASK_INDEX_OP_NAME, op_name), @@ -568,14 +637,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size for (size_t idx = 0; idx < op_name.size(); idx++) { op_name_map_[idx] = op_name[idx]; } - GELOGI("infer profiling: op_name_size(%zu)", op_name.size()); + GELOGI("Infer profiling: op_name_size(%zu)", op_name.size()); } - if (InitNodes(compute_graph_) != SUCCESS) { + if (InitNodes(compute_graph) != SUCCESS) { return FAILED; } - SetDataDumperArgs(); + SetDataDumperArgs(compute_graph); GE_TIMESTAMP_START(DoTaskSink); auto ret = DoTaskSink(); GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); @@ -583,22 +652,23 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size /// In zero copy model, if a aicpu operator is connected to the first or last layer, before model execution, /// the aicpu opertor needs to destroy history record, and update operator memory address. /// The model with specified aicpu operators is only marked here, and destruction is in ModelManager::ExecuteModel(). 
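The hunk below replaces MarkSpecifiedAicpuKernel() with a cached member flag. A plausible motivation, inferred from this patch rather than stated in it: ge_model_ is now released by Shrink() once task sink finishes, so a boolean attribute written onto it would no longer be readable at execute time. In outline (the attribute name comes from the removed kNeedDestroySpecifiedAicpuKernel constant):

    // Before: round-trip through a GeModel attribute, read back in ModelManager::ExecuteModel().
    //   (void)ge::AttrUtils::SetBool(ge_model_, "need_destroy_specified_aicpu_kernel", result);
    // After: cached directly on the DavinciModel object.
    need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer();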
- if (MarkSpecifiedAicpuKernel() != SUCCESS) { - GELOGE(FAILED, "Mark model with specified aicpu operators failed."); - return FAILED; - } + need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer(); + (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name_); // collect profiling for ge if (ProfilingManager::Instance().ProfilingOn()) { std::vector compute_graph_desc_info; - Status ret1 = GetComputeGraphInfo(compute_graph_desc_info); + Status ret1 = GetComputeGraphInfo(compute_graph, compute_graph_desc_info); if (ret1 != SUCCESS) { GELOGE(ret1, "GetComputeGraphInfo failed."); return ret1; } ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); + GE_CHK_STATUS(SinkModelProfile(), "Sink model profile failed."); } - GELOGI("davinci model init success."); + + Shrink(); + GELOGI("Davinci model init success."); return ret; } @@ -655,26 +725,14 @@ bool DavinciModel::IsAicpuKernelConnectSpecifiedLayer() { return false; } -/// -/// @ingroup ge -/// @brief mark ge model with specified aicpu operators . -/// @return Status -/// -Status DavinciModel::MarkSpecifiedAicpuKernel() { - bool result = IsAicpuKernelConnectSpecifiedLayer(); - if (!result) { - // No aicpu operator needing destroy. - GELOGD("No specified aicpu operator that connects to data or netoutput."); - return SUCCESS; - } - bool ret = ge::AttrUtils::SetBool(ge_model_, kNeedDestroySpecifiedAicpuKernel, result); - if (!ret) { - GELOGW("Add attr[%s] in ge model failed, and may lead to specified aicpu operators destruction failure.", - kNeedDestroySpecifiedAicpuKernel); +Status DavinciModel::UpdateSessionId(uint64_t session_id) { + GE_CHECK_NOTNULL(ge_model_); + if (!AttrUtils::SetInt(ge_model_, MODEL_ATTR_SESSION_ID, static_cast(session_id))) { + GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); } - GELOGI("Mark ge model success, the model has specified aicpu operators, ge model name: %s.", - ge_model_->GetName().c_str()); + + GELOGD("Update session id: %lu.", session_id); return SUCCESS; } @@ -721,12 +779,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } - if (IsCallDumpInputOp(op_desc)) { - GELOGI("node[%s] is no task op , call SaveDumpInput to save it's output node info", op_desc->GetName().c_str()); - data_dumper_.SaveDumpInput(node); - continue; - } - if (op_desc->GetType() == NETOUTPUT) { if (InitNetOutput(node) != SUCCESS) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); @@ -744,6 +796,29 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } + if (IsNoTaskAndDumpNeeded(op_desc)) { + GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str()); + const RuntimeParam &rts_param = GetRuntimeParam(); + const vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + vector tensor_device_addrs; + tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + void *addr = nullptr; + auto size = kAddrLen * 
tensor_device_addrs.size(); + GE_CHK_RT_RET(rtMalloc(&addr, size, RT_MEMORY_HBM)); + + rtError_t rt_ret = rtMemcpy(addr, size, tensor_device_addrs.data(), size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "rtMemcpy error"); + GE_CHK_RT(rtFree(addr)); + return FAILED; + } + saved_task_addrs_.emplace(op_desc, addr); + } + GE_TIMESTAMP_RESTART(InitTbeHandle); uint32_t run_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode) && @@ -773,7 +848,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. -/// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // op_desc Checked by Init: Data, valid. @@ -801,7 +875,7 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // Make information for copy input data. const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc, false); + const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), output_size_list.size(), virtual_addr_list.size()); @@ -877,7 +951,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { output_op_list_.push_back(op_desc); // Make information for copy output data. const vector input_size_list = ModelUtils::GetInputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); + const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc); if (input_size_list.empty() && virtual_addr_list.empty()) { GELOGI("NetOutput[%s] is empty.", op_desc->GetName().c_str()); return SUCCESS; @@ -890,7 +964,15 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { size_t num = output_data_info_.size(); for (size_t idx = 0; idx < input_size_list.size(); ++idx) { - output_data_info_[num + idx] = {input_size_list[idx], virtual_addr_list[idx]}; + int64_t size = input_size_list[idx]; + auto tensor_desc = op_desc->GetInputDescPtr(idx); + if ((tensor_desc == nullptr) || (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "GetTensorSizeInBytes failed!"); + return FAILED; + } + + GELOGI("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); + output_data_info_[num + idx] = {size, virtual_addr_list[idx]}; } SetOutputOutsideAddr(virtual_addr_list); @@ -1000,7 +1082,7 @@ Status DavinciModel::InitVariable(const OpDescPtr &op_desc) { Status DavinciModel::SetQueIds(const std::vector &input_queue_ids, const std::vector &output_queue_ids) { if (input_queue_ids.empty() && output_queue_ids.empty()) { - GELOGE(PARAM_INVALID, "Para is empty"); + GELOGE(PARAM_INVALID, "Param is empty"); return PARAM_INVALID; } @@ -1033,11 +1115,7 @@ Status DavinciModel::LoadWithQueue() { return PARAM_INVALID; } - // create stream instance which rt_model_handel is running on, this is S0. 
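The queue-load path below swaps the inline S0 setup (the removed lines around this point) for a single AddHeadStream() call. Its body is not shown in this patch; a sketch of what it presumably does, built only from the runtime calls visible in the removed code (rt_head_stream_ and rt_entry_stream_ appear later in the hunk; the flag choice is an assumption):

    // Hypothetical outline of AddHeadStream(), mirroring the removed inline code:
    Status DavinciModel::AddHeadStream() {
      GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_head_stream_, priority_, RT_STREAM_AICPU));
      GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_head_stream_, RT_HEAD_STREAM));
      rt_entry_stream_ = rt_head_stream_;  // assumption: CPU entry tasks are issued on the head stream
      return SUCCESS;
    }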
- GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_model_stream_, priority_, RT_STREAM_AICPU)); - is_inner_model_stream_ = true; - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, RT_HEAD_STREAM)); - + GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); // Binding input_queue and Data Op. GE_CHK_STATUS_RET(BindInputQueue(), "Launch bind input queue failed."); GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, input_outside_addrs_), "Launch zero copy failed."); @@ -1046,7 +1124,7 @@ Status DavinciModel::LoadWithQueue() { GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, output_outside_addrs_), "Launch zero copy failed."); - GE_CHK_STATUS_RET(CpuActiveStream(active_stream_list_), "Launch active entry stream failed."); + GE_CHK_STATUS_RET(CpuActiveStream(), "Launch active entry stream failed."); GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); GE_CHK_STATUS_RET(BindEnqueue(), "Launch enqueue failed."); GE_CHK_STATUS_RET(CpuModelRepeat(), "Launch model repeat failed."); @@ -1090,7 +1168,7 @@ Status DavinciModel::BindInputQueue() { /// @return: 0 for success / others for failed Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { GELOGI("Set CpuKernel model dequeue task enter."); - std::shared_ptr dequeue_task = MakeShared(rt_model_stream_); + std::shared_ptr dequeue_task = MakeShared(rt_entry_stream_); if (dequeue_task == nullptr) { GELOGE(FAILED, "Make CpuTaskModelDequeue task failed."); return FAILED; @@ -1111,7 +1189,7 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { Status DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, std::map> &outside_addrs) { GELOGI("Set CpuKernel model zero_copy task enter."); - std::shared_ptr zero_copy = MakeShared(rt_model_stream_); + std::shared_ptr zero_copy = MakeShared(rt_entry_stream_); if (zero_copy == nullptr) { GELOGE(FAILED, "Make CpuTaskZeroCopy task failed."); return FAILED; @@ -1156,7 +1234,6 @@ Status DavinciModel::BindOutputQueue() { /// @ingroup ge /// @brief definiteness queue schedule, bind output queue to task. -/// @param [in] queue_id: output queue id from user. /// @param [in] addr: NetOutput Op input tensor address. /// @param [in] size: NetOutput Op input tensor size. /// @return: 0 for success / others for failed @@ -1167,7 +1244,7 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { return FAILED; } - std::shared_ptr prepare_output = MakeShared(rt_model_stream_); + std::shared_ptr prepare_output = MakeShared(rt_entry_stream_); if (prepare_output == nullptr) { GELOGE(FAILED, "Make CpuTaskPrepareOutput task failed."); return FAILED; @@ -1187,25 +1264,21 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. -/// @param [in] streams: streams will active by S0. 
/// @return: 0 for success / others for failed /// -Status DavinciModel::CpuActiveStream(const std::vector &stream_list) { - GELOGI("Set CpuKernel active stream task:%zu enter.", stream_list.size()); - for (auto s : stream_list) { - std::shared_ptr active_entry = MakeShared(rt_model_stream_); - if (active_entry == nullptr) { - GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); - return FAILED; - } - - if (active_entry->Init(s) != SUCCESS) { - return FAILED; - } +Status DavinciModel::CpuActiveStream() { + GELOGI("Set CpuKernel active stream task enter."); + std::shared_ptr active_entry = MakeShared(rt_entry_stream_); + if (active_entry == nullptr) { + GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); + return FAILED; + } - cpu_task_list_.push_back(active_entry); + if (active_entry->Init(rt_head_stream_) != SUCCESS) { + return FAILED; } + cpu_task_list_.push_back(active_entry); GELOGI("Set CpuKernel active stream task success."); return SUCCESS; } @@ -1215,7 +1288,7 @@ Status DavinciModel::CpuActiveStream(const std::vector &stream_list) /// @return: 0 for success / others for failed Status DavinciModel::CpuWaitEndGraph() { GELOGI("Set CpuKernel wait end graph task enter."); - std::shared_ptr wait_endgraph = MakeShared(rt_model_stream_); + std::shared_ptr wait_endgraph = MakeShared(rt_entry_stream_); if (wait_endgraph == nullptr) { GELOGE(FAILED, "Make CpuTaskWaitEndGraph task failed."); return FAILED; @@ -1248,7 +1321,7 @@ Status DavinciModel::BindEnqueue() { Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { GELOGI("Set CpuKernel model enqueue task enter."); - std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); + std::shared_ptr model_enqueue = MakeShared(rt_entry_stream_); if (model_enqueue == nullptr) { GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); return FAILED; @@ -1267,7 +1340,7 @@ Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { /// @return: 0 for success / others for failed Status DavinciModel::CpuModelRepeat() { GELOGI("Set CpuKernel repeat task enter."); - std::shared_ptr model_repeat = MakeShared(rt_model_stream_); + std::shared_ptr model_repeat = MakeShared(rt_entry_stream_); if (model_repeat == nullptr) { GELOGE(FAILED, "Make CpuTaskModelRepeat task failed."); return FAILED; @@ -1319,36 +1392,8 @@ Status DavinciModel::GetInputOutputDescInfo(vector &input_d /// @param [out] batch_info /// @return execute result /// -Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info) { - for (auto &iter : op_list_) { - OpDescPtr op_desc = iter.second; - if (op_desc == nullptr) { - GELOGE(FAILED, "op_desc is null, index=%u.", iter.first); - return FAILED; - } - - if (op_desc->GetType() != STREAMSWITCHN) { - continue; - } - - batch_info.clear(); - uint32_t batch_num = 0; - if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); - return FAILED; - } - std::vector batch_shape; - for (uint32_t i = 0; i < batch_num; i++) { - batch_shape.clear(); - const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); - if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); - return FAILED; - } - batch_info.emplace_back(batch_shape); - } - break; - } +Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info) const { + batch_info = batch_info_; return 
SUCCESS; } @@ -1447,6 +1492,55 @@ Status DavinciModel::GetInputOutputDescInfoForZeroCopy(vector &input_desc, +void DavinciModel::CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input) { + uint32_t n, c, h, w; + n = format == FORMAT_NHWC ? NHWC_DIM_N : NCHW_DIM_N; + c = format == FORMAT_NHWC ? NHWC_DIM_C : NCHW_DIM_C; + h = format == FORMAT_NHWC ? NHWC_DIM_H : NCHW_DIM_H; + w = format == FORMAT_NHWC ? NHWC_DIM_W : NCHW_DIM_W; + + if (is_new_model_desc_ && op_desc->HasAttr(ATTR_NAME_INPUT_DIMS)) { + // When static aipp is set, need to get the model input dims which are processed by aipp + vector model_input_dims; + (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_DIMS, model_input_dims); + if (model_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = model_input_dims[n]; + input.shape_info.height = model_input_dims[h]; + input.shape_info.width = model_input_dims[w]; + input.shape_info.channel = model_input_dims[c]; + } + for (size_t k = 0; k < model_input_dims.size(); ++k) { + input.shape_info.dims.push_back(model_input_dims[k]); + } + is_new_model_desc_ = false; + return; + } + + if (!op_desc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + if (op_desc->GetInputDescPtr(0)->GetShape().GetDimNum() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = op_desc->GetInputDescPtr(0)->GetShape().GetDim(n); + input.shape_info.height = op_desc->GetInputDescPtr(0)->GetShape().GetDim(h); + input.shape_info.width = op_desc->GetInputDescPtr(0)->GetShape().GetDim(w); + input.shape_info.channel = op_desc->GetInputDescPtr(0)->GetShape().GetDim(c); + } + for (size_t k = 0; k < op_desc->GetInputDescPtr(0)->GetShape().GetDimNum(); k++) { + input.shape_info.dims.push_back(op_desc->GetInputDescPtr(0)->GetShape().GetDim(k)); + } + } else { + vector origin_input_dims; + (void)AttrUtils::GetListInt(op_desc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims); + if (origin_input_dims.size() == static_cast(NORMAL_TENSOR_SIZE)) { + input.shape_info.num = origin_input_dims[n]; + input.shape_info.height = origin_input_dims[h]; + input.shape_info.width = origin_input_dims[w]; + input.shape_info.channel = origin_input_dims[c]; + } + for (size_t k = 0; k < origin_input_dims.size(); ++k) { + input.shape_info.dims.push_back(origin_input_dims[k]); + } + } +} + Status DavinciModel::GetInputDescInfo(vector &input_desc, std::vector &formats) { for (size_t index = 0; index < data_op_list_.size(); ++index) { InputOutputDescInfo input; @@ -1455,6 +1549,7 @@ Status DavinciModel::GetInputDescInfo(vector &input_desc, s Format format = data_op_list_[index]->GetInputDescPtr(0)->GetFormat(); CreateInputDimsInfo(data_op_list_[index], format, input); + input.data_type = data_op_list_[index]->GetInputDescPtr(0)->GetDataType(); input.name = data_op_list_[index]->GetName(); int64_t input_size = 0; @@ -1511,7 +1606,7 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD int64_t tensor_size = 0; (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); // no need to check value - output.size = static_cast(tensor_size); + output.size = static_cast(tensor_size); output.data_type = op_desc->GetInputDescPtr(index)->GetDataType(); } @@ -1520,9 +1615,6 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, for (size_t i = 0; i < output_op_list_.size(); i++) { auto &op_desc = output_op_list_[i]; uint32_t out_size = static_cast(op_desc->GetInputsSize()); - // get real out nodes from model - vector out_node_name; - (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name); for (uint32_t index = 0; index < out_size; index++) { string output_name; InputOutputDescInfo output; @@ -1534,11 +1626,11 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, "construct output_name failed."); // forward compatbility,
if old om has no out_node_name, need to return output follow origin way - if (out_size == out_node_name.size()) { + if (out_size == out_node_name_.size()) { // neweast plan, the index will add to name during generate model. - bool contains_colon = out_node_name[index].find(":") != std::string::npos; + bool contains_colon = out_node_name_[index].find(":") != std::string::npos; output_name = - contains_colon ? out_node_name[index] : out_node_name[index] + ":" + std::to_string(src_index[index]); + contains_colon ? out_node_name_[index] : out_node_name_[index] + ":" + std::to_string(src_index[index]); } else { output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); @@ -1572,12 +1664,12 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data const DataBuffer &data_buf = blobs[data.first]; void *mem_addr = data.second.second; - uint32_t mem_size = static_cast(data.second.first); + uint64_t mem_size = static_cast(data.second.first); GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, - "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, + "input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length, mem_size); - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%u] datasize[%u]", + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", runtime_param_.graph_id, data.first, mem_addr, data_buf.data, mem_size, data_buf.length); if (data_buf.length == 0) { GELOGW("No data need to memcpy!"); @@ -1625,15 +1717,9 @@ inline int64_t SumSize(const vector &size_list) { } Status DavinciModel::SinkModelProfile() { - // not support non-sink model - GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); - // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - GELOGI("Profiling report is nullptr!"); - return SUCCESS; - } + GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); GELOGI("Start collect model load profiling data."); @@ -1645,15 +1731,19 @@ Status DavinciModel::SinkModelProfile() { return FAILED, "Sink model tag memcpy error."); // Model Header - string name = this->Name(); - int32_t name_len = name.size(); + string name; + if (!om_name_.empty()) { + name = om_name_; + } else { + name = name_; + } + size_t name_len = name.size(); // phy device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%d", phy_device_id); - return FAILED; - } + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, + GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); + return FAILED); reporter_data.deviceId = phy_device_id; reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); @@ -1690,7 +1780,6 @@ for (int32_t i = 0; i < task_num; i++) { auto task = task_list_[i]; auto fusion_op_info = task->GetFusionOpInfo(); - // when type is RT_MODEL_TASK_KERNEL, ctx is not null if (fusion_op_info != nullptr) { uint32_t op_num = fusion_op_info->original_op_names.size(); @@ -1809,15 +1898,9 @@ } Status
DavinciModel::SinkTimeProfile(const InputData &current_data) { - // not support non-sink model - GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); - // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - GELOGI("Profiling report is nullptr!"); - return SUCCESS; - } + GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); Msprof::Engine::ReporterData reporter_data{}; // report model data tag name @@ -1832,15 +1915,19 @@ Status DavinciModel::SinkTimeProfile(const InputData &current_data) { // device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%d", phy_device_id); - return FAILED; - } + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, + GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); + return FAILED); reporter_data.deviceId = phy_device_id; // Model Header - string name = this->Name(); - int32_t name_len = name.size(); + string name; + if (!om_name_.empty()) { + name = om_name_; + } else { + name = name_; + } + size_t name_len = name.size(); reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", @@ -1918,77 +2005,62 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] sink_op Sink Op +/// @param [in] data_id: the index of output_data +/// @param [in/out] output_data: real user output_data +/// @param [in] kind: the kind of rtMemcpy /// @return Status result /// @author /// -Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { - Status ret = SUCCESS; +Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind) { if (output_op_list_.empty()) { - ret = SyncVarData(); - } else { - output_data.index = data_id; - output_data.model_id = model_id_; - GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_data_info_.size(), INTERNAL_ERROR, - "output buffer size[%zu] not equal output_size_list[%zu] size!", output_data.blobs.size(), - output_data_info_.size()); - - // index of data in output_data - uint32_t output_data_index = 0; - for (auto &op_desc : output_op_list_) { - ret = CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); - GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy output data to model ret failed, index:%u, model id:%u", - output_data.index, output_data.model_id); - } + Status ret = SyncVarData(); + DumpOpInputOutput(); + return ret; } - (void)DumpOpInputOutput(); // dump, not care result.
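The replacement CopyOutputData in this hunk takes the rtMemcpy kind as a parameter and walks output_data_info_ (tensor index -> {size, device address}) instead of iterating NetOutput ops. Its control flow, condensed as a summary rather than the exact code:

    for (const auto &output : output_data_info_) {
      // Device-to-device copies skip tensors already fed by zero copy.
      if (kind == RT_MEMCPY_DEVICE_TO_DEVICE && copy_only_addrs_.count(output.second.second) == 0) {
        continue;
      }
      // Otherwise validate the user buffer length against the recorded
      // tensor size, then rtMemcpy(buffer.data, ..., kind).
    }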
- return ret; -} - -Status DavinciModel::CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index) { - Output model_output(op_desc, this); - - GE_CHK_BOOL_RET_STATUS(model_output.Init() == SUCCESS, PARAM_INVALID, "make shared model_output failed"); + output_data.index = data_id; + output_data.model_id = model_id_; + if (output_data.blobs.size() != output_data_info_.size()) { + GELOGE(FAILED, "Output data buffer num=%zu not equal model data num=%zu", output_data.blobs.size(), + output_data_info_.size()); + return FAILED; + } - vector v_output_size; - vector v_output_data_addr; - model_output.GetOutputData(v_output_data_addr, v_output_size); + std::vector &blobs = output_data.blobs; + for (const auto &output : output_data_info_) { + if (output.first >= blobs.size()) { + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), + input_data_info_.size(), output.first, output.second.first); + return FAILED; + } - // for all output tensor, copy output data from op to designated position - for (size_t i = 0; i < v_output_size.size(); ++i) { - GE_CHK_BOOL_RET_STATUS(data_index < blobs.size(), PARAM_INVALID, - "The blobs size:%zu, data_op size:%zu, curr output size:%zu", blobs.size(), - data_op_list_.size(), v_output_size.size()); + if ((kind == RT_MEMCPY_DEVICE_TO_DEVICE) && (copy_only_addrs_.count(output.second.second) == 0)) { + continue; // Skip: Feed by zero copy. + } - DataBuffer &data_buf = blobs[data_index]; - data_index++; + DataBuffer &buffer = blobs[output.first]; + uint64_t mem_size = static_cast(output.second.first); + if ((buffer.length == 0) || (mem_size == 0)) { + GELOGI("Length of data is zero, No need copy. output tensor index=%u", output.first); + continue; + } - uint32_t size = data_buf.length; - GE_CHK_BOOL_RET_STATUS(size <= v_output_size[i], PARAM_INVALID, - "Model output data size(%u) does not match required size(%u).", v_output_size[i], - data_buf.length); + if (buffer.length < mem_size) { + GELOGE(FAILED, "Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); + return FAILED; + } else if (buffer.length > mem_size) { + GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); + } - GELOGI( - "CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] dst[%p] src[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, v_output_data_addr[i], data_buf.length, - v_output_size[i]); - GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%lu] datasize[%u]", + runtime_param_.graph_id, output.first, output.second.second, mem_size, buffer.length); + GE_CHK_RT_RET(rtMemcpy(buffer.data, buffer.length, output.second.second, mem_size, kind)); } + DumpOpInputOutput(); return SUCCESS; } -Status DavinciModel::SyncDataAndDump() { - Status ret = SUCCESS; - if (output_op_list_.empty()) { - ret = SyncVarData(); - } - - (void)DumpOpInputOutput(); // dump, not care result. 
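The rewritten CopyOutputData above collapses the old Output/CopyOutputDataToUser machinery into one loop over `output_data_info_`. Its per-tensor policy, distilled under assumed semantics (the enum and function below are illustrative, not patch code): on device-to-device runs, addresses already fed by zero copy are skipped; zero-length tensors are skipped; an undersized user buffer fails the call; an oversized one is accepted with a warning, and only the model's tensor size is copied.

```cpp
#include <cstdint>

enum class OutputCopyAction { kSkipZeroCopy, kSkipEmpty, kFail, kCopyWithWarning, kCopy };

// is_zero_copy_fed: the destination address is absent from copy_only_addrs_,
// i.e. some task already wrote the user buffer directly.
OutputCopyAction DecideOutputCopy(bool device_to_device, bool is_zero_copy_fed,
                                  uint64_t tensor_size, uint64_t buffer_len) {
  if (device_to_device && is_zero_copy_fed) {
    return OutputCopyAction::kSkipZeroCopy;
  }
  if (tensor_size == 0 || buffer_len == 0) {
    return OutputCopyAction::kSkipEmpty;
  }
  if (buffer_len < tensor_size) {
    return OutputCopyAction::kFail;  // user buffer cannot hold the tensor
  }
  return buffer_len > tensor_size ? OutputCopyAction::kCopyWithWarning : OutputCopyAction::kCopy;
}
```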
- return ret; -} - Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, std::vector &outputs) { GE_CHECK_NOTNULL(op_desc); @@ -2020,13 +2092,13 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data GELOGE(GE_GRAPH_MALLOC_FAILED, "Malloc buffer failed."); return GE_GRAPH_MALLOC_FAILED; } - output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); + output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); ge::OutputTensorInfo output; output.dims = shape_info_vec[i]; output.data = std::move(data_buf); output.length = out_buffer_size_vec[i]; outputs.emplace_back(std::move(output)); - GELOGI("Output index:%zu, data_length:%u.", i, output.length); + GELOGI("Output index:%zu, data_length:%lu.", i, output.length); } return SUCCESS; } @@ -2035,7 +2107,10 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] sink_op Sink Op +/// @param [in] data_id: the index of output_data +/// @param [in] rslt_flg: result flag +/// @param [in] seq_end_flag: sequence end flag +/// @param [out] output_data: real user output_data /// @return Status result /// @author /// @@ -2066,20 +2141,17 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b // copy output data from op to designated position for (auto &op_desc : output_op_list_) { - Output model_output(op_desc, this); - if (model_output.Init() != SUCCESS || GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { + if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { return INTERNAL_ERROR; } + data_index += op_desc->GetInputsSize(); + } - Status ret = model_output.CopyResult(*output_data, data_index, data_index, false); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "CopyResult failed, op name: %s", op_desc->GetName().c_str()); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); - return INTERNAL_ERROR; - } + if (CopyOutputData(data_id, *output_data, RT_MEMCPY_DEVICE_TO_HOST) != SUCCESS) { + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); + return INTERNAL_ERROR; } - GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); if (seq_end_flag) { GELOGW("End of sequence, model id: %u", model_id_); GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnCompute Done failed."); @@ -2092,6 +2164,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b /// /// @ingroup ge /// @brief return not output to upper layer for cloud case +/// @param [in] data_id /// @return Status result /// Status DavinciModel::ReturnNoOutput(uint32_t data_id) { @@ -2103,7 +2176,7 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { op_desc->GetName().c_str()); } - GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); + DumpOpInputOutput(); GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); std::vector outputs; GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); @@ -2113,41 +2186,40 @@ Status DavinciModel::ReturnNoOutput(uint32_t 
data_id) { /// /// @ingroup ge /// @brief dump all op input and output information -/// @param [in] op_list model_id -/// @return Status result +/// @return void /// -Status DavinciModel::DumpOpInputOutput() { +void DavinciModel::DumpOpInputOutput() { + char *ge_dump_env = std::getenv("DUMP_OP"); + int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; + if (dump_op_switch == 0) { + GELOGI("need to set DUMP_OP for dump op input and output"); + return; + } + if (op_list_.empty()) { - GELOGW("op_list is empty."); - return FAILED; + GELOGW("op list is empty"); + return; } - char *ge_dump_env = getenv("DUMP_OP"); - int dump_op_switch = - (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; // 10 for decimal number - if (dump_op_switch != 0) { - int64_t cnt = 1; - for (auto it : op_list_) { - if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { - GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld.", maxDumpOpNum_); - return SUCCESS; - } - Status ret = DumpSingleOpInputOutput(it.second); - cnt++; - if (ret != SUCCESS) { - GELOGE(FAILED, "dump single op failed, model_id: %u", model_id_); - return FAILED; - } + + int64_t cnt = 1; + for (auto it : op_list_) { + if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { + GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld", maxDumpOpNum_); + return; + } + + cnt++; + if (DumpSingleOpInputOutput(it.second) != SUCCESS) { + GELOGW("dump single op failed, model_id: %u", model_id_); + return; } - } else { - GELOGW("need to set DUMP_OP for dump op input and output."); } - return SUCCESS; } /// /// @ingroup ge /// @brief dump single op input and output information -/// @param [in] dump_op model_id +/// @param [in] op_def: the op_desc which will be dump /// @return Status result /// Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { @@ -2163,7 +2235,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } } const vector input_size_vec = ModelUtils::GetInputSize(op_def); - const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def, false); + const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], input size[%zu], input memory type size[%zu]", op_def->GetName().c_str(), @@ -2186,7 +2258,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } const vector output_size_vec = ModelUtils::GetOutputSize(op_def); - const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def, false); + const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def); v_memory_type.clear(); has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], output size[%zu], output memory type size[%zu]", op_def->GetName().c_str(), @@ -2256,7 +2328,7 @@ void *DavinciModel::Run(DavinciModel *model) { ret != SUCCESS, (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); continue, "Copy input data to model failed."); // [No need to check value] - GE_TIMESTAMP_END(Model_SyncVarData, "Model Run SyncVarData"); + GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(Model_SyncVarData, "Model Run SyncVarData")); GELOGI("Copy 
input data, model id:%u", model_id); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_PRE_PROC_START)); @@ -2302,7 +2374,7 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtModelExecute end"); - GE_TIMESTAMP_END(rtModelExecute, "GraphExcute::rtModelExecute"); + GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); GE_TIMESTAMP_START(rtStreamSynchronize); GELOGI("rtStreamSynchronize start."); @@ -2317,7 +2389,8 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtStreamSynchronize end."); - GE_TIMESTAMP_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"); + GE_IF_BOOL_EXEC(model->is_first_execute_, + GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_INFER_END)); } @@ -2328,11 +2401,13 @@ void *DavinciModel::Run(DavinciModel *model) { (void)model->ReturnResult(current_data.index, rslt_flg, false, data_wrapper->GetOutput())) // copy output data from device to host for variable graph GE_IF_BOOL_EXEC(model->output_op_list_.empty(), (void)model->ReturnNoOutput(current_data.index)); - GE_TIMESTAMP_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost"); + GE_IF_BOOL_EXEC(model->is_first_execute_, + GE_TIMESTAMP_EVENT_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost")); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_AFTER_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), (void)model->SinkTimeProfile(current_data)); model->iterator_count_++; + model->is_first_execute_ = false; GELOGI("run iterator count is %lu", model->iterator_count_); } @@ -2385,7 +2460,7 @@ Status DavinciModel::ModelRunStart() { is_inner_model_stream_ = true; string opt = "0"; - (void)ge::GetContext().GetOption("ge.maxDumpOpNum", opt); // option may not be set up, no need to check value + (void)ge::GetContext().GetOption(OPTION_GE_MAX_DUMP_OP_NUM, opt); // option may not be set up, no need to check value int64_t maxDumpOpNum = std::strtol(opt.c_str(), nullptr, kDecimal); maxDumpOpNum_ = maxDumpOpNum; @@ -2428,7 +2503,18 @@ void DavinciModel::UnbindTaskSinkStream() { // destroy stream that is bound with rt_model GE_LOGW_IF(rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed.") } - return; + + if (is_pure_head_stream_ && rt_head_stream_ != nullptr) { + GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_head_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); + GE_LOGW_IF(rtStreamDestroy(rt_head_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); + rt_head_stream_ = nullptr; + } + + if (rt_entry_stream_ != nullptr) { + GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_entry_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); + GE_LOGW_IF(rtStreamDestroy(rt_entry_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); + rt_entry_stream_ = nullptr; + } } Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs) { @@ -2437,6 +2523,9 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const GELOGE(FAILED, "input data addr %u is not equal to input op number %u.", 
inputs.size(), data_op_list_.size()); return FAILED; } + // remove zero copy addr in last iteration + knonw_input_data_info_.clear(); + knonw_output_data_info_.clear(); for (size_t i = 0; i < data_op_list_.size(); ++i) { const vector addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, data_op_list_[i]); knonw_input_data_info_[addr_list[kDataIndex]] = inputs[i]; @@ -2518,7 +2607,9 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { for (int i = 0; i < model_task_def.task_size(); ++i) { // dynamic shape will create task_list_ before const domi::TaskDef &task = model_task_def.task(i); - task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); + if (this->task_list_[i] == nullptr) { + task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); + } GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->Init(task, this); if (ret != SUCCESS) { @@ -2532,13 +2623,14 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { Status DavinciModel::MallocKnownArgs() { GELOGI("DavinciModel::MallocKnownArgs in"); - if (model_task_def_->task_size() == 0) { + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); + if (model_task_def->task_size() == 0) { GELOGW("DavinciModel::MallocKnownArgs davincimodel has no task info."); return SUCCESS; } - task_list_.resize(model_task_def_->task_size()); - for (int32_t i = 0; i < model_task_def_->task_size(); ++i) { - const domi::TaskDef &taskdef = model_task_def_->task(i); + task_list_.resize(model_task_def->task_size()); + for (int32_t i = 0; i < model_task_def->task_size(); ++i) { + const domi::TaskDef &taskdef = model_task_def->task(i); task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(taskdef.type())); GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->CalculateArgs(taskdef, this); @@ -2559,7 +2651,19 @@ Status DavinciModel::MallocKnownArgs() { GELOGE(RT_FAILED, "Call rtMallocHost failed, ret: 0x%X", rt_ret); return RT_FAILED; } - GELOGI("DavinciModel::MallocKnownArgs success, total args size %u.", total_args_size_); + + // malloc fixed addr memory, eg: rts op + if (total_fixed_addr_size_ != 0) { + GELOGI("Begin to allocate fixed addr."); + rt_ret = rtMalloc(&fixed_addrs_, total_fixed_addr_size_, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); + return RT_FAILED; + } + } + + GELOGI("DavinciModel::MallocKnownArgs success, total args size %u. total fixed addr size %ld", total_args_size_, + total_fixed_addr_size_); return SUCCESS; } @@ -2575,26 +2679,28 @@ Status DavinciModel::DistributeTask() { task_desc_info_.clear(); bool flag = GetL1FusionEnableOption(); - char *skt_enable_env = getenv("SKT_ENABLE"); - int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; + char *skt_enable_env = std::getenv("SKT_ENABLE"); + int64_t env_flag = (skt_enable_env != nullptr) ? 
std::strtol(skt_enable_env, nullptr, kDecimal) : 0; if (env_flag != 0) { flag = true; } + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { auto &task = task_list_.at(task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); // for data dump if (reinterpret_cast(task->GetDumpArgs()) != nullptr) { - auto op_index = std::max(model_task_def_->task(task_index).kernel().context().op_index(), - model_task_def_->task(task_index).kernel_ex().op_index()); + auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), + model_task_def->task(task_index).kernel_ex().op_index()); OpDescPtr op = GetOpByIndex(op_index); if (op == nullptr) { GELOGE(PARAM_INVALID, "Op index %u is null, op list size %zu.", op_index, op_list_.size()); return PARAM_INVALID; } - if (PropertiesManager::Instance().IsLayerNeedDump(name_, om_name_, op->GetName())) { + bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); + if (call_dump) { SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); } } @@ -2609,8 +2715,13 @@ Status DavinciModel::DistributeTask() { // else task index is found in op_name_map_ TaskDescInfo task_desc_info; string op_name = op_name_map_[task_index]; + if (!om_name_.empty()) { + task_desc_info.model_name = om_name_; + } else { + task_desc_info.model_name = name_; + } task_desc_info.op_name = op_name; - task_desc_info.block_dim = model_task_def_->task(task_index).kernel().block_dim(); + task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); task_desc_info.task_id = task->GetTaskID(); task_desc_info.stream_id = task->GetStreamId(); task_desc_info_.emplace_back(task_desc_info); @@ -2631,7 +2742,7 @@ Status DavinciModel::DistributeTask() { } void DavinciModel::SetEndGraphId(uint32_t task_id, uint32_t stream_id) { - auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); + auto all_dump_model = GetDumpProperties().GetAllDumpModel(); bool findByOmName = all_dump_model.find(om_name_) != all_dump_model.end(); bool findByModelName = all_dump_model.find(name_) != all_dump_model.end(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || findByOmName || findByModelName) { @@ -2669,20 +2780,35 @@ void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs continue; } - (void)output_outside_addrs_.emplace(std::pair>(addr, {})); + DisableZeroCopy(addr); // Data to NetOutput directly. + output_outside_addrs_.emplace(std::pair>(addr, {})); GELOGI("SetOutputOutsideAddr success."); } } /// /// @ingroup ge +/// @brief Set copy only for No task feed NetOutput address. +/// @return None. +/// +void DavinciModel::SetCopyOnlyOutput() { + for (const auto &addrs : output_outside_addrs_) { + const auto &used_list = addrs.second; + if (used_list.empty()) { // No task feed Output addr, Need copy directly. + copy_only_addrs_.insert(addrs.first); + } + } +} + +/// +/// @ingroup ge /// @brief Set disabled input zero copy addr. /// @param [in] const void *addr: address of task /// @return None. 
/// void DavinciModel::DisableZeroCopy(const void *addr) { - auto it = input_outside_addrs_.find(addr); - if (it == input_outside_addrs_.end()) { + if ((input_outside_addrs_.find(addr) == input_outside_addrs_.end()) && + (output_outside_addrs_.find(addr) == output_outside_addrs_.end())) { return; } @@ -2696,7 +2822,10 @@ void DavinciModel::DisableZeroCopy(const void *addr) { /// @brief Save outside address used info for ZeroCopy. /// @param [in] const OpDescPtr &op_desc: current op desc /// @param [in] const std::vector &outside_addrs: address of task -/// @param [in] const char *args_offset: arguments address save the address. +/// @param [in] const void *info: task args +/// @param [in] const char *args: task args +/// @param [in] size_t size: size of task args +/// @param [in] size_t offset: offset of task args /// @return None. /// void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, @@ -2772,7 +2901,7 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 if (input_size > op_size) { GELOGW( - "Input size [%u] is bigger than om size need [%u]," + "Input size [%u] is bigger than om size need [%u], " "MAY cause inference result ERROR, please check model input", input_size, op_size); } @@ -2866,7 +2995,7 @@ Status DavinciModel::UpdateIoTaskArgs(const map> return FAILED; } - GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %u.", data.first, data.second.second, + GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %lu.", data.first, data.second.second, data.second.first, buffer.data, buffer.length); if (!CheckInputAndModelSize(buffer.length, size, is_dynamic)) { GELOGE(FAILED, "Check input size and model size failed"); @@ -3134,6 +3263,24 @@ Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); } + batch_info_.clear(); + uint32_t batch_num = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); + return FAILED; + } + + for (uint32_t i = 0; i < batch_num; i++) { + std::vector batch_shape; + const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); + batch_info_.clear(); + return FAILED; + } + batch_info_.emplace_back(batch_shape); + } + return SUCCESS; } @@ -3152,20 +3299,6 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } -void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { - if (!is_dynamic_batch) { - zero_copy_batch_label_addrs_.clear(); - } - - for (const auto &addrs : output_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { - output_zero_copy = false; - break; - } - } -} - /// /// @ingroup ge /// @brief Init model stream for NN model. 
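The InitStreamSwitchN addition above caches dynamic-batch shapes at load time rather than re-reading attributes on every execution. The layout it assumes: one integer attribute (ATTR_NAME_BATCH_NUM) carries the batch count, and batch i stores its shape list under ATTR_NAME_PRED_VALUE suffixed with "_i"; parsing is all-or-nothing, clearing batch_info_ on any miss. A self-contained sketch of that loop, with a plain map standing in for ge::AttrUtils:

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Stand-in for the attribute store; the real code queries AttrUtils::GetListInt.
using AttrStore = std::map<std::string, std::vector<int64_t>>;

bool ParseBatchInfo(const AttrStore &attrs, uint32_t batch_num,
                    std::vector<std::vector<int64_t>> &batch_info) {
  batch_info.clear();
  for (uint32_t i = 0; i < batch_num; ++i) {
    // "_pred_value_" is an illustrative key; the patch builds
    // ATTR_NAME_PRED_VALUE + "_" + std::to_string(i).
    const auto it = attrs.find("_pred_value_" + std::to_string(i));
    if (it == attrs.end()) {
      batch_info.clear();  // all-or-nothing, mirroring the patch
      return false;
    }
    batch_info.emplace_back(it->second);
  }
  return true;
}
```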
@@ -3213,20 +3346,12 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); - bool input_use_zero_copy = true; - bool output_use_zero_copy = true; - bool is_dynamic_batch = input_data.is_dynamic_batch; - InitZeroCopyUtil(is_dynamic_batch, input_use_zero_copy, output_use_zero_copy); - - // Empty task, Just copy input to output, need direct copy. - if (task_list_.empty() && (input_use_zero_copy || output_use_zero_copy)) { - GELOGE(FAILED, "Empty task, Just copy input to output, need direct copy."); - return FAILED; + if (!input_data.is_dynamic_batch) { + zero_copy_batch_label_addrs_.clear(); } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_START)); - Status ret = - input_use_zero_copy ? CopyModelData(input_data, output_data, is_dynamic_batch) : CopyInputData(input_data, true); + Status ret = CopyModelData(input_data, output_data, input_data.is_dynamic_batch); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy input data to model failed."); GELOGI("current_data.index=%u", input_data.index); @@ -3243,7 +3368,7 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa if (!is_async_mode_) { GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); - ret = output_use_zero_copy ? SyncDataAndDump() : CopyOutputData(input_data.index, output_data); + ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); } @@ -3254,11 +3379,60 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa return SUCCESS; } +// Add active entry stream for special env. +Status DavinciModel::AddHeadStream() { + if (active_stream_list_.empty()) { + GELOGE(INTERNAL_ERROR, "Active stream is empty, stream list size: %zu, stream indication size: %zu.", + stream_list_.size(), active_stream_indication_.size()); + return INTERNAL_ERROR; + } + + if (active_stream_list_.size() == 1) { + GELOGI("Just one active stream, take as head stream."); + rt_head_stream_ = active_stream_list_[0]; + is_pure_head_stream_ = false; + } else { + // Create stream which rt_model_handel running on, this is S0, TS stream. + GELOGI("Multiple active stream: %zu, create head stream.", active_stream_list_.size()); + GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_head_stream_, priority_, RT_STREAM_PERSISTENT)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_head_stream_, RT_INVALID_FLAG)); // Not active. + is_pure_head_stream_ = true; + + for (auto s : active_stream_list_) { + std::shared_ptr active_entry = MakeShared(rt_head_stream_); + if (active_entry == nullptr) { + GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); + return FAILED; + } + + if (active_entry->Init(s) != SUCCESS) { + return FAILED; + } + + cpu_task_list_.emplace_back(active_entry); + } + } + + // Create entry stream active head stream. AICPU stream. 
+ GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_entry_stream_, priority_, RT_STREAM_AICPU)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_entry_stream_, RT_HEAD_STREAM)); + return SUCCESS; +} + +Status DavinciModel::InitEntryTask() { + if (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD) { + GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); + return CpuActiveStream(); + } else { + return LoadWithQueue(); + } +} + uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { uint8_t *mem_base = nullptr; const string purpose("feature map,used for op input and output."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { - data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); + data_size = static_cast(VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()); string memory_key = std::to_string(0) + "_f"; mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); } else { @@ -3343,12 +3517,14 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) return SUCCESS; } -void DavinciModel::SetDataDumperArgs() { +void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { GELOGI("set data dumper args, name: %s, id: %u.", name_.c_str(), model_id_); data_dumper_.SetModelName(name_); data_dumper_.SetModelId(model_id_); data_dumper_.SetMemory(runtime_param_); data_dumper_.SetOmName(om_name_); + data_dumper_.SetComputeGraph(compute_graph); + data_dumper_.SetRefInfo(saved_task_addrs_); int32_t device_id = 0; rtError_t rt_ret = rtGetDevice(&device_id); @@ -3404,18 +3580,9 @@ void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { } } -Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { - return TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_); -} - -Status DavinciModel::GetComputeGraphInfo(std::vector &compute_graph_desc_info) { +Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info) { GELOGI("GetComputeGraphInfo start."); - if (compute_graph_ == nullptr) { - GELOGE(FAILED, "compute_graph_ is nullptr"); - return FAILED; - } - - for (auto &node : compute_graph_->GetAllNodes()) { + for (auto &node : graph->GetAllNodes()) { ComputeGraphDescInfo compute_graph_info; auto op_desc = node->GetOpDesc(); if (op_desc == nullptr) { @@ -3426,6 +3593,11 @@ Status DavinciModel::GetComputeGraphInfo(std::vector &comp auto op_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, op_mode) && op_mode == static_cast(domi::ImplyType::TVM)) { + if (!om_name_.empty()) { + compute_graph_info.model_name = om_name_; + } else { + compute_graph_info.model_name = name_; + } compute_graph_info.op_name = op_desc->GetName(); compute_graph_info.op_type = op_desc->GetType(); @@ -3443,12 +3615,18 @@ Status DavinciModel::GetComputeGraphInfo(std::vector &comp compute_graph_info.output_data_type.emplace_back(output_desc.GetDataType()); } - compute_graph_desc_info.emplace_back(compute_graph_info); + graph_desc_info.emplace_back(compute_graph_info); } } GELOGI("GetComputeGraphInfo end."); return SUCCESS; } +void DavinciModel::SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size) { + if (tensor_name_to_fixed_addr_size_.find(tensor_name) == tensor_name_to_fixed_addr_size_.end()) { + tensor_name_to_fixed_addr_size_[tensor_name] = total_fixed_addr_size_; + total_fixed_addr_size_ += fix_addr_size; + } +} Status DavinciModel::GetOrigInputInfo(uint32_t index, 
OriginInputInfo &orig_input_info) { GE_CHK_BOOL_RET_STATUS(index < data_op_list_.size(), PARAM_INVALID, "Index %u is invalid.", index); @@ -3537,4 +3715,23 @@ Status DavinciModel::GetAllAippInputOutputDims(uint32_t index, std::vectorHasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR) && op_desc->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX)) { + string tensor_name; + (void)AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); + int64_t index = -1; + (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, index); + if (index >= 0) { + tensor_name_to_peer_output_index_[tensor_name] = index; + } + } +} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index 3254a23b..0f0b1e5c 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -29,6 +29,7 @@ #include "common/helper/om_file_helper.h" #include "common/opskernel/ge_task_info.h" #include "common/types.h" +#include "common/properties_manager.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/data_dumper.h" @@ -47,6 +48,10 @@ #include "task_info/task_info.h" namespace ge { +// op debug need 2048 bits buffer +const size_t kOpDebugMemorySize = 2048UL; +const size_t kDebugP2pSize = 8UL; + typedef enum tagModelProcStage { MODEL_LOAD_START = 1, MODEL_LOAD_END, @@ -171,13 +176,6 @@ class DavinciModel { // get session id uint64_t SessionId() const { return runtime_param_.session_id; } - vector GetOpDesc() { - vector opDescVector; - GE_IF_BOOL_EXEC(AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), - GELOGI("get opDesc of opDescVector")); - return opDescVector; - } - // get model priority int32_t Priority() const { return priority_; } @@ -248,15 +246,9 @@ class DavinciModel { /// Format GetFormat(); - rtModel_t GetRtModelHandle() { - rtModel_t res = rt_model_handle_; - return res; - } + rtModel_t GetRtModelHandle() const { return rt_model_handle_; } - rtStream_t GetRtModelStream() { - rtModel_t res = rt_model_stream_; - return res; - } + rtStream_t GetRtModelStream() const { return rt_model_stream_; } uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } @@ -295,7 +287,7 @@ class DavinciModel { /// @param [out] batch_info /// @return execute result /// - Status GetDynamicBatchInfo(std::vector> &batch_info); + Status GetDynamicBatchInfo(std::vector> &batch_info) const; void GetCurShape(std::vector &batch_info); @@ -344,10 +336,9 @@ class DavinciModel { /// /// @ingroup ge /// @brief dump all op input and output information - /// @param [in] op_list model_id - /// @return Status + /// @return void /// - Status DumpOpInputOutput(); + void DumpOpInputOutput(); /// /// @ingroup ge @@ -403,7 +394,9 @@ class DavinciModel { /// uint32_t GetDeviceId() const { return device_id_; } - GeModelPtr GetGeModel() { return ge_model_; } + bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } + + Status UpdateSessionId(uint64_t session_id); const RuntimeParam &GetRuntimeParam() { return runtime_param_; } @@ -463,6 +456,19 @@ class DavinciModel { void *cur_args = static_cast(args_) + offset; return cur_args; } + void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); + int64_t GetFixedAddrsSize(string tensor_name); + void *GetCurrentFixedAddr(int64_t offset) const { + void *cur_addr = static_cast(fixed_addrs_) + offset; + return cur_addr; + } + + 
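The fixed-address members declared here behave as a small bump arena: SetTotalFixedAddrsSize (implemented in the davinci_model.cc hunk above) records each tensor name's offset once, at the current end, and grows the running total; MallocKnownArgs then backs the whole arena with a single rtMalloc of total_fixed_addr_size_; GetCurrentFixedAddr hands out base-plus-offset slices. A condensed model of that bookkeeping:

```cpp
#include <cstdint>
#include <map>
#include <string>

// Minimal model of the tensor_name_to_fixed_addr_size_ / total_fixed_addr_size_
// pair; offsets are assigned bump-allocator style, once per tensor name.
struct FixedAddrArena {
  std::map<std::string, int64_t> offset_by_name;
  int64_t total_size = 0;

  void Reserve(const std::string &tensor_name, int64_t size) {
    if (offset_by_name.find(tensor_name) == offset_by_name.end()) {
      offset_by_name[tensor_name] = total_size;  // slice starts at current end
      total_size += size;                        // arena grows once per name
    }
  }
};
```

One device allocation of total_size then serves every registered slice, which keeps the known-node argument setup to a single extra rtMalloc.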
uint32_t GetFixedAddrOutputIndex(string tensor_name) { + if (tensor_name_to_peer_output_index_.find(tensor_name) != tensor_name_to_peer_output_index_.end()) { + return tensor_name_to_peer_output_index_[tensor_name]; + } + return UINT32_MAX; + } void SetKnownNode(bool known_node) { known_node_ = known_node; } bool IsKnownNode() { return known_node_; } Status MallocKnownArgs(); @@ -473,9 +479,13 @@ class DavinciModel { Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info); Status GetAllAippInputOutputDims(uint32_t index, std::vector &input_dims, std::vector &output_dims); + void SetModelDescVersion(bool is_new_model_desc) { is_new_model_desc_ = is_new_model_desc; } // om file name void SetOmName(string om_name) { om_name_ = om_name; } + void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); } + const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); } + private: // memory address of weights uint8_t *weights_mem_base_; @@ -492,8 +502,6 @@ class DavinciModel { struct timeInfo time_info_; int32_t dataInputTid; - void InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy); - /// /// @ingroup ge /// @brief Save Batch label Info. @@ -531,6 +539,13 @@ class DavinciModel { /// /// @ingroup ge + /// @brief Set copy only for No task feed NetOutput address. + /// @return None. + /// + void SetCopyOnlyOutput(); + + /// + /// @ingroup ge /// @brief Copy Input/Output to model for direct use. /// @param [in] const InputData &input_data: user input data info. /// @param [in/out] OutputData &output_data: user output data info. @@ -554,16 +569,14 @@ class DavinciModel { Status CopyInputData(const InputData &input_data, bool device_data = false); - Status CopyOutputData(uint32_t data_id, OutputData &output_data); - - Status CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index); + Status CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind); Status SyncVarData(); - Status SyncDataAndDump(); - Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize); + void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input); + Status GetInputDescInfo(vector &input_desc, std::vector &formats); Status InitTaskInfo(domi::ModelTaskDef &modelTaskInfo); @@ -586,7 +599,12 @@ class DavinciModel { bool IsAicpuKernelConnectSpecifiedLayer(); - Status MarkSpecifiedAicpuKernel(); + /// + /// @ingroup ge + /// @brief Reduce memory usage after task sink. + /// @return: void + /// + void Shrink(); /// /// @ingroup ge @@ -722,10 +740,9 @@ class DavinciModel { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. - /// @param [in] streams: streams will active by S0. /// @return: 0 for success / others for fail /// - Status CpuActiveStream(const std::vector &stream_list); + Status CpuActiveStream(); /// /// @ingroup ge @@ -743,6 +760,9 @@ class DavinciModel { /// Status CpuModelRepeat(); + Status InitEntryTask(); + Status AddHeadStream(); + /// /// @ingroup ge /// @brief set ts device. 
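CpuActiveStream drops its stream-list parameter because activation is now routed through the head stream: InitEntryTask, declared here and implemented in the davinci_model.cc hunk earlier, either reuses a lone active stream as the head or creates a dedicated head stream and queues one CpuTaskActiveEntry per original active stream. Reduced to its decision, the branch looks roughly like this (illustrative enum, not patch code):

```cpp
#include <cstddef>

enum class HeadStreamPlan { kError, kReuseSingleActive, kCreateDedicated };

HeadStreamPlan ChooseHeadStreamPlan(size_t active_stream_count) {
  if (active_stream_count == 0) {
    return HeadStreamPlan::kError;  // nothing to activate
  }
  if (active_stream_count == 1) {
    return HeadStreamPlan::kReuseSingleActive;  // is_pure_head_stream_ = false
  }
  // Bind a new persistent S0 stream and queue a CpuTaskActiveEntry per stream.
  return HeadStreamPlan::kCreateDedicated;
}
```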
@@ -750,6 +770,10 @@ class DavinciModel { /// Status SetTSDevice(); + Status OpDebugRegister(); + + void OpDebugUnRegister(); + void CheckHasHcomOp(); Status DoTaskSink(); @@ -757,17 +781,17 @@ class DavinciModel { void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result); Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); - Status CopyVarData(ComputeGraphPtr &graph); // get desc info of graph for profiling - Status GetComputeGraphInfo(vector &compute_graph_desc_info); + Status GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info); - void SetDataDumperArgs(); + void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, std::vector &outputs); void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info); + void GetFixedAddrAttr(const OpDescPtr &op_desc); bool is_model_has_inited_; uint32_t model_id_; @@ -780,6 +804,9 @@ class DavinciModel { uint32_t version_; GeModelPtr ge_model_; + bool need_destroy_aicpu_kernel_{false}; + vector out_node_name_; + map op_list_; // data op_desc @@ -840,6 +867,11 @@ class DavinciModel { bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. + bool is_pure_head_stream_{false}; + rtStream_t rt_head_stream_{nullptr}; + rtStream_t rt_entry_stream_{nullptr}; + rtAicpuDeployType_t deploy_type_{AICPU_DEPLOY_RESERVED}; + // ACL queue schedule, save queue ids for Init. std::vector cpu_task_list_; std::vector input_queue_ids_; // input queue ids created by caller. @@ -861,8 +893,6 @@ class DavinciModel { std::vector active_stream_list_; std::set active_stream_indication_; - std::shared_ptr model_task_def_; - std::set aicpu_streams_; std::set hcom_streams_; RuntimeParam runtime_param_; @@ -874,22 +904,40 @@ class DavinciModel { // for profiling task and graph info std::map op_name_map_; std::vector task_desc_info_; - ComputeGraphPtr compute_graph_; int64_t maxDumpOpNum_; // for data dump DataDumper data_dumper_; uint64_t iterator_count_; bool is_l1_fusion_enable_; + std::map saved_task_addrs_; bool known_node_ = false; uint32_t total_args_size_ = 0; void *args_ = nullptr; void *args_host_ = nullptr; + void *fixed_addrs_ = nullptr; + int64_t total_fixed_addr_size_ = 0; std::map knonw_input_data_info_; std::map knonw_output_data_info_; + vector> batch_info_; + vector batch_size_; + // key: input tensor name, generally rts op; + // value: the fixed addr of input anchor, same as the peer output anchor addr of the peer op + std::map tensor_name_to_fixed_addr_size_; + + // key: input tensor name, generally rts op; value: the peer output anchor of the peer op + std::map tensor_name_to_peer_output_index_; + // if model is first execute + bool is_first_execute_; + // for op debug + std::mutex debug_reg_mutex_; + bool is_op_debug_reg_ = false; + void *op_debug_addr_ = nullptr; + void *p2p_debug_addr_ = nullptr; + bool is_new_model_desc_{false}; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index 384e203b..04c836dd 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -22,8 +22,9 @@ #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" -#include 
"graph/debug/ge_attr_define.h" #include "framework/common/util.h" +#include "graph/common/ge_call_wrapper.h" +#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model_parser.h" #include "model/ge_root_model.h" @@ -33,9 +34,10 @@ thread_local uint32_t device_count = 0; namespace { const int kCmdParSize = 2; const int kDumpCmdPairSize = 2; -const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; } // namespace +DumpProperties ModelManager::dump_properties_; + std::shared_ptr ModelManager::GetInstance() { static const std::shared_ptr instance_ptr = shared_ptr(new (std::nothrow) ModelManager(), ModelManager::FinalizeForPtr); @@ -272,6 +274,10 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetId(model_id); davinci_model->SetDeviceId(GetContext().DeviceId()); + const DumpProperties &dump_properties = PropertiesManager::Instance().GetDumpProperties(GetContext().SessionId()); + davinci_model->SetDumpProperties(dump_properties); + dump_properties_ = dump_properties; + auto root_graph = ge_root_model->GetRootGraph(); GE_CHECK_NOTNULL(root_graph); string root_model_name = root_graph->GetName(); @@ -296,9 +302,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile() != SUCCESS) { - GELOGW("Sink model profile failed."); - } } } while (0); @@ -611,10 +614,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { GELOGE(PARAM_INVALID, "parser dump model failed"); return FAILED; } - GELOGI("dump status = %s.", dump_model.c_str()); + GELOGI("dump model = %s.", dump_model.c_str()); if (dump_status == "off" || dump_status == "OFF") { - PropertiesManager::Instance().DeleteDumpPropertyValue(dump_model); + dump_properties_.DeletePropertyValue(dump_model); return SUCCESS; } @@ -631,9 +634,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { return FAILED; } if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; + dump_path = dump_path + "/"; } - GELOGI("dump status = %s.", dump_path.c_str()); + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("dump path = %s.", dump_path.c_str()); ret = ParserPara(command, DUMP_MODE, dump_mode); if (ret != SUCCESS) { @@ -642,20 +646,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { } GELOGI("dump mode = %s", dump_mode.c_str()); - auto iter_dump_mode = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_MODE); - if (iter_dump_mode != command.cmd_params.end()) { - ++iter_dump_mode; - if (iter_dump_mode == command.cmd_params.end()) { - GELOGE(PARAM_INVALID, "Invalid access."); - return PARAM_INVALID; - } - dump_mode = *iter_dump_mode; - GELOGI("dump mode = %s", dump_mode.c_str()); - } + dump_properties_.AddPropertyValue(dump_model, dump_layers); + dump_properties_.SetDumpPath(dump_path); + dump_properties_.SetDumpMode(dump_mode); - PropertiesManager::Instance().AddDumpPropertyValue(dump_model, dump_layers); - PropertiesManager::Instance().SetDumpOutputPath(dump_path); - PropertiesManager::Instance().SetDumpMode(dump_mode); return SUCCESS; } @@ -685,11 +679,14 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector 
&input_desc, vector &output_desc, - std::vector &inputFormats, std::vector &outputFormats) { + std::vector &inputFormats, std::vector &outputFormats, + bool new_model_desc) { std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetInputOutputDescInfo Failed, Invalid Model ID %u !", model_id); + davinci_model->SetModelDescVersion(new_model_desc); + return davinci_model->GetInputOutputDescInfo(input_desc, output_desc, inputFormats, outputFormats); } @@ -768,17 +765,6 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { return SUCCESS; } -Status ModelManager::UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id) { - GeModelPtr ge_model_current = davinci_model->GetGeModel(); - GE_CHECK_NOTNULL(ge_model_current); - if (!ge::AttrUtils::SetInt(ge_model_current, ge::MODEL_ATTR_SESSION_ID, static_cast(session_id))) { - GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); - } - - GELOGD("Update session id: %lu.", session_id); - return SUCCESS; -} - Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr listener, void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { GE_CHK_BOOL_RET_STATUS(model.key.empty() || access(model.key.c_str(), F_OK) == 0, PARAM_INVALID, @@ -821,6 +807,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model } davinci_model->SetDeviceId(device_id); davinci_model->SetOmName(model.om_name); + davinci_model->SetDumpProperties(dump_properties_); /// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. /// These session_ids come from the same model, so the values of session_id are the same. 
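As the retained comment notes, every offline load now mints a fresh session id via GenSessionId and pushes it into the model through the relocated DavinciModel::UpdateSessionId. GenSessionId's body is not part of this patch; judging only from the members visible in the model_manager.h hunk below (session_id_create_mutex_, session_id_bias_, sess_ids_), a plausible shape is a mutex-guarded counter with a bias, as in this speculative sketch:

```cpp
#include <cstdint>
#include <mutex>
#include <set>

// Speculative reconstruction only: the real GenSessionId is not shown in
// this patch and may use the bias and bookkeeping differently.
struct SessionIdGen {
  std::mutex mu;              // session_id_create_mutex_
  uint64_t bias = 0;          // session_id_bias_
  uint64_t next = 0;
  std::set<uint64_t> issued;  // sess_ids_

  uint64_t Generate() {
    std::lock_guard<std::mutex> lock(mu);
    const uint64_t id = bias + (++next);  // unique, monotonically increasing
    issued.insert(id);
    return id;
  }
};
```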
@@ -828,7 +815,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Generate session_id for infer failed."); - ret = UpdateSessionId(davinci_model, new_session_id); + ret = davinci_model->UpdateSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Update session_id for infer failed."); ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); @@ -843,9 +830,6 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile() != SUCCESS) { - GELOGW("Sink model profile failed."); - } } GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); @@ -895,7 +879,7 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); - ret = UpdateSessionId(davinci_model, new_session_id); + ret = davinci_model->UpdateSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); GenModelId(&model_id); @@ -906,6 +890,8 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d return ret; } + davinci_model->SetDumpProperties(dump_properties_); + ret = davinci_model->Init(); if (ret != SUCCESS) { GELOGE(ret, "init model failed."); @@ -932,12 +918,8 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); - GeModelPtr ge_model_current = davinci_model->GetGeModel(); - bool need_destroy_aicpu_kernel = false; - bool result = ge::AttrUtils::GetBool(ge_model_current, kNeedDestroySpecifiedAicpuKernel, need_destroy_aicpu_kernel); - if (result && need_destroy_aicpu_kernel) { - GELOGI("Get attr %s successfully, start to destroy specified aicpu kernel.", kNeedDestroySpecifiedAicpuKernel); - + if (davinci_model->NeedDestroyAicpuKernel()) { + GELOGI("Start to destroy specified aicpu kernel."); // Zero copy is enabled by default, no need to judge. uint64_t session_id_davinci = davinci_model->GetSessionId(); uint32_t model_id_davinci = davinci_model->GetModelId(); @@ -1047,4 +1029,19 @@ Status ModelManager::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index return davinci_model->GetAllAippInputOutputDims(index, input_dims, output_dims); } +bool ModelManager::IsDynamicShape(uint32_t model_id) { + auto model = GetHybridModel(model_id); + return model != nullptr; +} + +ge::Status ModelManager::SyncExecuteModel(uint32_t model_id, const vector &inputs, + vector &outputs) { + auto model = GetHybridModel(model_id); + if (model == nullptr) { + GELOGE(FAILED, "Hybrid model not found. 
model id = %u.", model_id); + return FAILED; + } + + return model->Execute(inputs, outputs); +} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index 9a94e5c9..2ba23d7c 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -31,6 +31,7 @@ #include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" +#include "common/properties_manager.h" #include "common/types.h" #include "ge/ge_api_types.h" #include "graph/ge_context.h" @@ -141,6 +142,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data); + ge::Status SyncExecuteModel(uint32_t model_id, const std::vector &inputs, std::vector &outputs); + /// /// @ingroup domi_ome /// @brief model stop @@ -178,7 +181,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status GetInputOutputDescInfo(const uint32_t model_id, std::vector &input_desc, std::vector &output_desc, std::vector &inputFormats, - std::vector &outputFormats); + std::vector &outputFormats, bool new_model_desc = false); /// /// @ingroup ge /// @brief Get dynamic batch_info @@ -249,6 +252,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, std::vector &output_dims); + bool IsDynamicShape(uint32_t model_id); + private: /// /// @ingroup domi_ome @@ -276,7 +281,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status DeleteModel(uint32_t id); void GenModelId(uint32_t *id); - ge::Status UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id); std::map> model_map_; std::map> hybrid_model_map_; @@ -287,6 +291,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::mutex session_id_create_mutex_; uint64_t session_id_bias_; std::set sess_ids_; + + static DumpProperties dump_properties_; }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index a807f2a3..bd684b9d 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -31,7 +31,7 @@ namespace ge { /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get input size. 
/// @return vector /// @@ -43,22 +43,26 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + + int64_t tensor_size = 0; if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights size to input - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); if (tensor_size) { v_input_size.push_back(tensor_size); } continue; } - int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - continue;); + continue); v_input_size.push_back(tensor_size); } @@ -67,7 +71,7 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get output size. /// @return vector /// @@ -82,11 +86,17 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { return v_output_size;); for (size_t i = 0; i < outputs_size; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(op_desc->GetOutputDesc(i), tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); - continue;); + continue); v_output_size.push_back(tensor_size); } @@ -95,7 +105,7 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get workspace size. /// @return vector /// @@ -118,7 +128,7 @@ vector ModelUtils::GetWorkspaceSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get weight size. /// @return vector /// @@ -142,8 +152,14 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; - (void)TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size); + (void)TensorUtils::GetSize(*tensor_desc, tensor_size); v_weight_size.push_back(tensor_size); } } @@ -152,7 +168,7 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get weights. 
/// @return vector /// @@ -176,9 +192,14 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + ConstGeTensorPtr weight = nullptr; - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); - if (AttrUtils::GetTensor(tensor_desc, ATTR_NAME_WEIGHTS, weight)) { + if (AttrUtils::GetTensor(*tensor_desc, ATTR_NAME_WEIGHTS, weight)) { v_weights.push_back(weight); } } @@ -188,7 +209,7 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -205,20 +226,25 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { continue; } + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + uint32_t dim_cnt = 0; - const auto &descriptor = op_desc->GetInputDesc(i); - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = descriptor.GetFormat(); + uint32_t tmp_fmt = tensor_desc->GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = descriptor.GetDataType(); + uint32_t tmp_type = tensor_desc->GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? tensor_desc->GetShape().GetDim(j) : 1); } v_input_descs.push_back(tmp); @@ -228,7 +254,7 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get AiCpuOp Output descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -240,20 +266,25 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { // init op output opTensor_t struct const size_t output_num = op_desc->GetOutputsSize(); for (size_t i = 0; i < output_num; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + uint32_t dim_cnt = 0; - const auto &descriptor = op_desc->GetOutputDesc(i); - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = descriptor.GetFormat(); + uint32_t tmp_fmt = tensor_desc->GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = descriptor.GetDataType(); + uint32_t tmp_type = tensor_desc->GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? 
tensor_desc->GetShape().GetDim(j) : 1); } v_output_descs.push_back(tmp); @@ -263,44 +294,14 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get input data address. /// @return vector /// -vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_input_data_addr; // init as:buf_base + op_def_->input(i)); GE_CHECK_NOTNULL_EXEC(op_desc, return v_input_data_addr); uint64_t session_id = model_param.session_id; - uint8_t *mem_base = model_param.mem_base; - uint8_t *var_base = model_param.var_base; - uint8_t *weight_base = model_param.weight_base; - const uint64_t logic_mem_base = 0; - uint64_t logic_weight_base = 0; - uint64_t logic_var_base = model_param.logic_var_base; - uint64_t mem_size = model_param.mem_size; - uint64_t weight_size = model_param.weight_size; - uint64_t var_size = model_param.var_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_input_data_addr; - } - - status = ConvertVirtualAddressToPhysical(weight_base, weight_size, weight_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for weight_base failed."); - return v_input_data_addr; - } - - status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); - return v_input_data_addr; - } - } const size_t inputs_size = op_desc->GetInputsSize(); const vector v_input_offset = op_desc->GetInputOffset(); @@ -319,13 +320,18 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights address to input - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); if (tensor_size) { int64_t data_offset = 0; - GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); - uint8_t *weight_addr = static_cast(weight_base + data_offset - logic_weight_base); + GE_CHK_STATUS(TensorUtils::GetDataOffset(*tensor_desc, data_offset)); + uint8_t *weight_addr = model_param.weight_base + data_offset; v_input_data_addr.push_back(weight_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, weight_addr); @@ -340,17 +346,13 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co int64_t input_offset = v_input_offset[non_const_index]; non_const_index++; - GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), - uint8_t *variable_addr = var_base + input_offset - logic_var_base; + GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), + uint8_t *variable_addr = 
model_param.var_base + input_offset - model_param.logic_var_base; v_input_data_addr.push_back(variable_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue;); + continue); - bool input_tensor = false; - GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, - GELOGW("get size from TensorDesc failed, op: %s, input index: %zu", op_desc->GetName().c_str(), i); - continue;); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -358,7 +360,7 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co mem_addr = reinterpret_cast(reinterpret_cast(input_offset)); v_input_data_addr.push_back(mem_addr); } else { - mem_addr = static_cast(mem_base + input_offset - logic_mem_base); + mem_addr = model_param.mem_base + input_offset; v_input_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetInputDataAddrs graph_%u type[F] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, @@ -369,41 +371,20 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get output data address. /// @return vector /// -vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_output_data_addr; // init as:buf_base + op_def_->output(i) GE_CHECK_NOTNULL_EXEC(op_desc, return v_output_data_addr); uint64_t session_id = model_param.session_id; - uint8_t *mem_base = model_param.mem_base; - uint8_t *var_base = model_param.var_base; - const uint64_t logic_mem_base = 0; - uint64_t logic_var_base = model_param.logic_var_base; - uint64_t mem_size = model_param.mem_size; - uint64_t var_size = model_param.var_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_output_data_addr; - } - - status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); - return v_output_data_addr; - } - } const size_t outputs_size = op_desc->GetOutputsSize(); const vector v_output_offset = op_desc->GetOutputOffset(); GE_IF_BOOL_EXEC(v_output_offset.size() != outputs_size, GELOGW("Output param invalid: output_offset=%zu, outputs=%zu.", v_output_offset.size(), outputs_size); - return v_output_data_addr;); + return v_output_data_addr); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); if (has_mem_type_attr && (v_memory_type.size() != outputs_size)) { @@ -413,12 +394,12 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C return v_output_data_addr; } for (size_t i = 0; i < outputs_size; ++i) { - GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), - uint8_t *variable_addr = static_cast(var_base + v_output_offset[i] - logic_var_base); + GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), + uint8_t *variable_addr = model_param.var_base + v_output_offset[i] - model_param.logic_var_base; v_output_data_addr.push_back(variable_addr); 
GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[V] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue;); + continue); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -426,7 +407,7 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C mem_addr = reinterpret_cast(reinterpret_cast(v_output_offset[i])); v_output_data_addr.push_back(mem_addr); } else { - mem_addr = static_cast(mem_base + v_output_offset[i] - logic_mem_base); + mem_addr = static_cast(model_param.mem_base + v_output_offset[i]); v_output_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[F] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, @@ -436,24 +417,13 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get workspace data address. /// @return vector /// -vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_workspace_data_addr; GE_CHECK_NOTNULL_EXEC(op_desc, return v_workspace_data_addr); - uint8_t *mem_base = model_param.mem_base; - uint64_t mem_size = model_param.mem_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_workspace_data_addr; - } - } const vector v_workspace_offset = op_desc->GetWorkspace(); const vector v_workspace_bytes = op_desc->GetWorkspaceBytes(); @@ -466,13 +436,13 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { - v_workspace_data_addr.push_back(reinterpret_cast(v_workspace_offset[i])); + v_workspace_data_addr.push_back(reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); GELOGI("Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); } else { int64_t workspace_offset = v_workspace_offset[i]; int64_t workspace_bytes = v_workspace_bytes[i]; - uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : mem_base + workspace_offset; + uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : model_param.mem_base + workspace_offset; v_workspace_data_addr.push_back(mem_addr); GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); @@ -482,21 +452,32 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param return v_workspace_data_addr; } -Status ModelUtils::ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, - uint8_t *&physical_address) { - // Indicates whether use physical address. 
- const char *use_physical_address = std::getenv("GE_USE_PHYSICAL_ADDRESS"); - if (use_physical_address == nullptr || virtual_address == 0 || size == 0) { - return SUCCESS; - } - - rtError_t ret = rtKernelConfigTransArg(virtual_address, size, 0, reinterpret_cast(&physical_address)); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rtKernelConfigTransArg failed, ret: 0x%X", ret); - return RT_FAILED; +/// +/// @ingroup ge +/// @brief Get runtime memory address. +/// @return Status +/// +Status ModelUtils::GetRtAddress(const RuntimeParam ¶m, uintptr_t logic_addr, uint8_t *&mem_addr) { + uint8_t *runtime_base_addr = nullptr; + if ((param.logic_mem_base <= logic_addr) && (logic_addr < param.logic_mem_base + param.mem_size)) { + runtime_base_addr = param.mem_base - param.logic_mem_base; + GELOGI("The logic addr:0x%lx is data address, base:0x%lx, size:%lu", logic_addr, param.logic_mem_base, + param.mem_size); + } else if ((param.logic_weight_base <= logic_addr) && (logic_addr < param.logic_weight_base + param.weight_size)) { + runtime_base_addr = param.weight_base - param.logic_weight_base; + GELOGI("The logic addr:0x%lx is weight address, base:0x%lx, size:%lu", logic_addr, param.logic_weight_base, + param.weight_size); + } else if ((param.logic_var_base <= logic_addr) && (logic_addr < param.logic_var_base + param.var_size)) { + runtime_base_addr = param.var_base - param.logic_var_base; + GELOGI("The logic addr:0x%lx is variable address, base:0x%lx, size:%lu", logic_addr, param.logic_var_base, + param.var_size); + } else if (logic_addr != 0) { + mem_addr = nullptr; + GELOGE(PARAM_INVALID, "The logic addr:0x%lx is abnormal", logic_addr); + return PARAM_INVALID; } - GELOGD("virtual_address=%p, physical_address=%p", virtual_address, physical_address); + mem_addr = runtime_base_addr + logic_addr; return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h index d6afd5c8..8474a987 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.h +++ b/src/ge/graph/load/new_model_manager/model_utils.h @@ -34,78 +34,79 @@ class ModelUtils { ~ModelUtils() = default; /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get input size. /// @return vector /// static vector GetInputSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get output size. /// @return vector /// static vector GetOutputSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get workspace size. /// @return vector /// static vector GetWorkspaceSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get weight size. /// @return vector /// static vector GetWeightSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get weights. /// @return vector /// static vector GetWeights(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetInputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get AiCpuOp Output descriptor. /// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetOutputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get input data address. 
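GetRtAddress above folds the removed per-base bookkeeping into a single classifier: a logic offset is tested against the data, weight, and variable ranges and rebased onto the matching runtime pointer, and any non-zero address outside all three ranges is rejected. A standalone sketch of that mapping, with RtParam as a simplified stand-in for ge::RuntimeParam and a bool in place of ge::Status:

#include <cstdint>
#include <cstdio>

struct RtParam {
  uint8_t *mem_base;    uint64_t logic_mem_base;    uint64_t mem_size;
  uint8_t *weight_base; uint64_t logic_weight_base; uint64_t weight_size;
  uint8_t *var_base;    uint64_t logic_var_base;    uint64_t var_size;
};

// Returns false for a non-zero logic address outside all three regions.
// The base-minus-logic rebasing trick mirrors the original code.
bool GetRtAddress(const RtParam &p, uintptr_t logic_addr, uint8_t *&mem_addr) {
  uint8_t *base = nullptr;
  if (p.logic_mem_base <= logic_addr && logic_addr < p.logic_mem_base + p.mem_size) {
    base = p.mem_base - p.logic_mem_base;        // feature-map region
  } else if (p.logic_weight_base <= logic_addr &&
             logic_addr < p.logic_weight_base + p.weight_size) {
    base = p.weight_base - p.logic_weight_base;  // weight region
  } else if (p.logic_var_base <= logic_addr &&
             logic_addr < p.logic_var_base + p.var_size) {
    base = p.var_base - p.logic_var_base;        // variable region
  } else if (logic_addr != 0) {
    mem_addr = nullptr;                          // abnormal logic address
    return false;
  }
  mem_addr = base + logic_addr;  // logic 0 with no region stays nullptr
  return true;
}

int main() {
  static uint8_t device_mem[64] = {};
  RtParam p{device_mem, 0x1000, sizeof(device_mem), nullptr, 0, 0, nullptr, 0, 0};
  uint8_t *addr = nullptr;
  if (GetRtAddress(p, 0x1010, addr)) {
    std::printf("runtime offset: %td\n", addr - device_mem);  // prints 16
  }
  return 0;
}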
/// @return vector /// - static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get output data address. /// @return vector /// - static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get workspace data address. /// @return vector /// - static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); - static ge::Status ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, - uint8_t *&physical_address); + /// + /// @ingroup ge + /// @brief Get memory runtime base. + /// @return Status + /// + static Status GetRtAddress(const RuntimeParam &model_param, uintptr_t logic_addr, uint8_t *&mem_addr); }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc index 077ae827..920b52e6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc @@ -45,7 +45,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin Status EndGraphTaskInfo::Distribute() { GELOGI("EndGraphTaskInfo Distribute Start."); GE_CHECK_NOTNULL(davinci_model_); - auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); + auto all_dump_model = davinci_model_->GetDumpProperties().GetAllDumpModel(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(davinci_model_->Name()) != all_dump_model.end() || all_dump_model.find(davinci_model_->OmName()) != all_dump_model.end()) { @@ -80,5 +80,4 @@ Status EndGraphTaskInfo::Distribute() { } REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_END_GRAPH, EndGraphTaskInfo); - } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h index 49bef082..82e228e6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h @@ -22,7 +22,7 @@ namespace ge { class EndGraphTaskInfo : public TaskInfo { public: - EndGraphTaskInfo() : model_(0) {} + EndGraphTaskInfo() {} ~EndGraphTaskInfo() override { model_ = nullptr; } @@ -35,10 +35,10 @@ class EndGraphTaskInfo : public TaskInfo { uint32_t GetStreamId() override { return stream_id_; } private: - rtModel_t model_; - DavinciModel *davinci_model_; - uint32_t task_id_; - uint32_t stream_id_; + rtModel_t model_{nullptr}; + DavinciModel *davinci_model_{nullptr}; + uint32_t task_id_{0}; + uint32_t stream_id_{0}; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index 0ee9727a..2a79997f 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ 
b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -42,6 +42,7 @@ HcclTaskInfo::~HcclTaskInfo() { davinci_model_ = nullptr; ops_kernel_store_ = nullptr; max_node_of_hccl_stream_ = 0; + args_ = nullptr; } Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("HcclTaskInfo Init Start."); @@ -60,52 +61,59 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); // Get HCCL op - OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); + op_desc_ = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc_); // Create the kernel hccl infos - CreateKernelHcclInfo(op_desc); + CreateKernelHcclInfo(op_desc_); // Initialize the hccl_type of all kernel hccl info HcomOmeUtil::GetHcclType(task_def, kernel_hccl_infos_); // Only in Horovod scenario should get the inputName and GeShape - ret = HcomOmeUtil::GetHorovodInputs(op_desc, kernel_hccl_infos_); + ret = HcomOmeUtil::GetHorovodInputs(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHorovodInputs fail! domi error: %u", ret); return FAILED; } - Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc, kernel_hccl_infos_); + Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); return FAILED; } - dmrt = HcomOmeUtil::GetHcclCount(op_desc, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetHcclCount(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); return FAILED; } // Only HCOMBROADCAST and HVDCALLBACKBROADCAST need to get the rootId - dmrt = HcomOmeUtil::GetAllRootId(op_desc, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetAllRootId(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: Get rootId fail! 
domi error: %u", dmrt); return FAILED; } - ret = SetAddrs(op_desc, kernel_hccl_infos_); + + // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl + ret = SetFollowStream(op_desc_, davinci_model); if (ret != SUCCESS) { - GELOGE(ret, "Setaddrs Fail."); + GELOGE(ret, "SetStream Fail."); return ret; } - // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace - ret = SetWorkspace(op_desc, kernel_hccl_infos_); + + if (davinci_model_->IsKnownNode()) { + args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); + GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); + } + + ret = SetAddrs(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { - GELOGE(ret, "SetWorkspace Fail."); + GELOGE(ret, "SetAddrs Fail."); return ret; } - // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl - ret = SetFollowStream(op_desc, davinci_model); + // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace + ret = SetWorkspace(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { - GELOGE(ret, "SetStream Fail."); + GELOGE(ret, "SetWorkspace Fail."); return ret; } @@ -209,40 +217,83 @@ Status HcclTaskInfo::Distribute() { GELOGI("HcclTaskInfo Distribute Success."); return SUCCESS; } + +Status HcclTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto hccl_def = task_def.kernel_hccl(); + uint32_t op_index = hccl_def.op_index(); + GELOGI("HcclTaskInfo CalculateArgs, op_index is: %u", op_index); + // Get HCCL op + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size.
Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + // Only need the number of addr to allocate args memory + auto input_size = op_desc->GetInputsSize(); + auto output_size = op_desc->GetOutputsSize(); + auto workspace_size = op_desc->GetWorkspaceBytes().size(); + uint32_t args_size = sizeof(void *) * (input_size + output_size + workspace_size); + args_offset_ = davinci_model->GetTotalArgsSize(); + davinci_model->SetTotalArgsSize(args_size); + GELOGI("Calculate hccl task args , args_size %u, args_offset %u", args_size, args_offset_); + return SUCCESS; +} + +Status HcclTaskInfo::UpdateArgs() { + GELOGI("HcclTaskInfo::UpdateArgs in."); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + input_data_addrs_ = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); + output_data_addrs_ = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); + workspace_data_addrs_ = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); + + vector io_addrs; + io_addrs.insert(io_addrs.end(), input_data_addrs_.begin(), input_data_addrs_.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs_.begin(), output_data_addrs_.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs_.begin(), workspace_data_addrs_.end()); + + GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), + "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); + + GELOGI("HcclTaskInfo::UpdateArgs success."); + return SUCCESS; +} + Status HcclTaskInfo::SetAddrs(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - if (HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { - GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); - return PARAM_INVALID; - } + GE_CHK_STATUS_RET(HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos), + "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); GELOGI("Set hccl task input output address, node[%s}, type[%s] kernel_hccl_infos.size[%zu].", op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); if (op_desc->GetType() == HVDWAIT) { return SUCCESS; } - domi::Status dmrt; + hcclRedOp_t op_type = HCCL_REP_OP_SUM; GE_CHECK_NOTNULL(davinci_model_); GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - - auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + if (!davinci_model_->IsKnownNode()) { + input_data_addrs_ = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + output_data_addrs_ = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + } + void *input_data_addr = nullptr; + void *output_data_addr = nullptr; // initialize every kernel_hccl_info inputDataAddr for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { std::string hccl_type = kernel_hccl_infos[i].hccl_type; - void *input_data_addr = input_data_addr_list.empty() ? 
nullptr : input_data_addr_list[i]; + if (davinci_model_->IsKnownNode()) { + input_data_addr = reinterpret_cast(reinterpret_cast(args_) + i); + output_data_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + i); + GELOGI("Hccl task info known input addr %p, output addr %p.", input_data_addr, output_data_addr); + } else { + input_data_addr = input_data_addrs_.empty() ? nullptr : input_data_addrs_[i]; + output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; + } kernel_hccl_infos[i].inputDataAddr = input_data_addr; - - void *output_data_addr = output_data_addr_list.empty() ? nullptr : output_data_addr_list[i]; if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { kernel_hccl_infos[i].outputDataAddr = output_data_addr; } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { - dmrt = HcomOmeUtil::GetHcclOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); - return FAILED; - } + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), + "davinci_model: GetHcomOperationType fail!"); kernel_hccl_infos[i].outputDataAddr = output_data_addr; kernel_hccl_infos[i].opType = op_type; } @@ -310,6 +361,7 @@ void HcclTaskInfo::CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc) { Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(davinci_model_); GELOGI("SetWorkspace Node[%s] opType[%s] set workspace.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); uint64_t workspace_mem_size = 0; void *workspace_addr = nullptr; @@ -319,11 +371,12 @@ Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); if (workspace_mem_size_tmp != 0) { workspace_mem_size = workspace_mem_size_tmp; - vector workspace_data_addrs = - ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - if (!workspace_data_addrs.empty()) { - GELOGI("Get workSpaceAddr"); - workspace_addr = workspace_data_addrs[0]; + if (davinci_model_->IsKnownNode()) { + workspace_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + + op_desc->GetOutputsSize()); + } else { + workspace_data_addrs_ = ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + workspace_addr = workspace_data_addrs_.empty() ? 
nullptr : workspace_data_addrs_[0]; } } } diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index bb0a88de..cc3109f4 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -34,7 +34,10 @@ class HcclTaskInfo : public TaskInfo { hccl_stream_list_(), ops_kernel_store_(nullptr), private_def_(nullptr), - private_def_len_(0) {} + private_def_len_(0), + op_desc_(nullptr), + args_(nullptr), + args_offset_(0) {} ~HcclTaskInfo() override; @@ -44,6 +47,10 @@ class HcclTaskInfo : public TaskInfo { uint32_t GetTaskID() override { return id_; } + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + Status UpdateArgs() override; + private: ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr &op); @@ -72,6 +79,12 @@ class HcclTaskInfo : public TaskInfo { static std::mutex hccl_follow_stream_mutex_; static uint32_t max_node_of_hccl_stream_; vector kernel_hccl_infos_; + vector input_data_addrs_; + vector output_data_addrs_; + vector workspace_data_addrs_; + OpDescPtr op_desc_; + void *args_; + uint32_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index 79971529..a241e129 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -79,6 +79,9 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin return FAILED;) } + GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, ext_info_addr_=%p", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), ext_info.size(), ext_info_addr_); + // 2.1 get loop cond variable for tensor array write uint64_t step_id_addr = 0; OpDescPtr step_id_node = davinci_model_->GetVariableOp(NODE_NAME_GLOBAL_STEP); @@ -97,6 +100,11 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, GELOGE(FAILED, "CreateAicpuKernel error."); return FAILED;) + // 2.3 Create session + GE_CHECK_NOTNULL(ModelManager::GetInstance()); + GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, + GELOGE(FAILED, "CreateAicpuSession error. 
session id: %lu", session_id); + return FAILED;) kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); if (davinci_model_->IsKnownNode()) { @@ -153,8 +161,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); return FAILED;) - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = input_output_addr_; } @@ -167,12 +175,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info.size(); fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_); - // 4. Create session - GE_CHECK_NOTNULL(ModelManager::GetInstance()); - GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, - GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); - return FAILED;) - // 5. Return result + // 4. Return result rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) @@ -180,12 +183,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); + davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); GELOGI("KernelExTaskInfo Init Success. 
session id: %lu", session_id); return SUCCESS; @@ -207,19 +205,55 @@ Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciMod uint32_t mem_size = sizeof(uint64_t) * mem_length; davinci_model->SetTotalArgsSize(mem_size); GELOGI("kernel task name %s, args_size %u, args_offset %u", op_desc->GetName().c_str(), mem_size, args_offset_); + + // alloc fixed addr + string peer_input_name; + if (AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { + uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > outputs_size) { + GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", outputs_size, output_index); + return FAILED; + } + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); + auto tensor_desc = op_desc->GetOutputDesc(output_index); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); + GELOGI("Calculate kernel ex task args, tensor size is %ld, fixed addr offset %ld", tensor_size, + fixed_addr_offset_); + } return SUCCESS; } Status KernelExTaskInfo::UpdateArgs() { GELOGI("KernelExTaskInfo::UpdateArgs in."); const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); - vector io_addrs; vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); - - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - + vector io_addrs; + if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + } else { + string peer_input_name; + if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { + uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > output_data_addrs.size()) { + GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", + output_data_addrs.size(), output_index); + return FAILED; + } + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + for (size_t i = 0; i < output_data_addrs.size(); ++i) { + if (i == output_index) { + void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); + io_addrs.emplace_back(fixed_addr); + continue; + } + io_addrs.emplace_back(output_data_addrs[i]); + } + } + } GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); @@ -231,7 +265,7 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const const OpDescPtr &op_desc) { // Userspace copy need virtual address.
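The UpdateArgs change above splices one reserved fixed address into the io_addrs list: the output slot whose index GetFixedAddrOutputIndex resolves from the ATTR_DYNAMIC_SHAPE_FIXED_ADDR peer-input name is replaced, and every other address passes through unchanged. A minimal sketch of just that splice; the index and pointers below are illustrative, not real model addresses:

#include <cstddef>
#include <vector>

std::vector<void *> BuildIoAddrs(const std::vector<void *> &inputs,
                                 const std::vector<void *> &outputs,
                                 size_t output_index, void *fixed_addr) {
  std::vector<void *> io_addrs(inputs.begin(), inputs.end());
  for (size_t i = 0; i < outputs.size(); ++i) {
    // The reserved fixed address replaces exactly one output slot.
    io_addrs.push_back(i == output_index ? fixed_addr : outputs[i]);
  }
  return io_addrs;
}

int main() {
  int a = 0, b = 0, c = 0, fixed = 0;
  std::vector<void *> ins{&a}, outs{&b, &c};
  auto io = BuildIoAddrs(ins, outs, 1, &fixed);  // io = {&a, &b, &fixed}
  return io.size() == 3 ? 0 : 1;
}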
const vector workspace_data_sizes = ModelUtils::GetWorkspaceSize(op_desc); - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc, false); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); if (workspace_data_addrs.empty() || workspace_data_sizes.empty()) { GELOGE(FAILED, "Node:%s invalid workspace, addrs is %zu, size is %zu.", op_desc->GetName().c_str(), workspace_data_addrs.size(), workspace_data_sizes.size()); diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index ff8f3119..b26a95ac 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -54,6 +54,7 @@ class KernelExTaskInfo : public TaskInfo { auto ret = reinterpret_cast(dump_args_); return ret; } + bool CallSaveDumpInfo() override { return true; }; private: Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); @@ -69,6 +70,7 @@ class KernelExTaskInfo : public TaskInfo { void *dump_args_; OpDescPtr op_desc_ = nullptr; uint32_t args_offset_ = 0; + int64_t fixed_addr_offset_ = 0; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 7ef65555..12fe0206 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -47,16 +47,16 @@ const uint32_t kAddrLen = sizeof(void *); namespace ge { KernelTaskInfo::SuperKernelTaskInfo KernelTaskInfo::skt_info_ = { - 0, 0, 0, 0, nullptr, nullptr, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; + 0, 0, 0, 0, nullptr, nullptr, {}, {}, {}, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci model is null!"); return PARAM_INVALID; } davinci_model_ = davinci_model; is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); - GELOGD("KernelTaskInfo Init Start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); + GELOGD("KernelTaskInfo init start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); Status ret = SetStream(task_def.stream_id(), davinci_model_->GetStreamList()); if (ret != SUCCESS) { @@ -73,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get opdesc op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); if (op_desc_ == nullptr) { - GELOGE(INTERNAL_ERROR, "Get op_desc failed, index is out of range!"); + GELOGE(INTERNAL_ERROR, "Get op desc failed, index is out of range!"); return INTERNAL_ERROR; } (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); @@ -138,14 +138,21 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci ret = InitCceTask(kernel_def); } - GELOGD("KernelTaskInfo Init finish, result=%u.", ret); + GELOGD("KernelTaskInfo init finish, result=%u.", ret); return ret; } Status KernelTaskInfo::SaveSKTDumpInfo() { GE_CHECK_NOTNULL(davinci_model_); - davinci_model_->SaveDumpTask(skt_info_.last_task_id, 
skt_info_.last_stream_id, skt_info_.last_op, - skt_info_.last_dump_args); + if (skt_dump_flag_ == RT_KERNEL_DEFAULT) { + GELOGD("No need to save skt dump info"); + return SUCCESS; + } + // All ops in the super kernel share one task id and stream id + for (size_t i = 0; i < skt_info_.op_desc_list.size(); i++) { + davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.op_desc_list[i], + skt_info_.dump_args_list[i]); + } return SUCCESS; } @@ -187,6 +194,9 @@ Status KernelTaskInfo::SKTFinalize() { GELOGI("SuperKernel Distribute [skt_id:%u]", skt_id_); skt_info_.kernel_list.clear(); skt_info_.arg_list.clear(); + skt_info_.dump_flag_list.clear(); + skt_info_.op_desc_list.clear(); + skt_info_.dump_args_list.clear(); skt_info_.last_stream = nullptr; skt_info_.last_block_dim = 0; skt_info_.last_sm_desc = sm_desc_; @@ -197,6 +207,15 @@ return SUCCESS; } +uint32_t KernelTaskInfo::GetDumpFlag() { + for (auto flag : skt_info_.dump_flag_list) { + if (flag == RT_KERNEL_DUMPFLAG) { + return RT_KERNEL_DUMPFLAG; + } + } + return RT_KERNEL_DEFAULT; +} + Status KernelTaskInfo::SuperKernelLaunch() { if (skt_info_.kernel_list.empty()) { GELOGI("SuperKernelLaunch: Skt_kernel_list has no task, just return"); return SUCCESS; } @@ -206,7 +225,7 @@ auto &skt_kernel_list = skt_info_.kernel_list; auto &skt_arg_list = skt_info_.arg_list; GELOGI("SuperKernelLaunch: Skt_kernel_list size[%d] skt_arg_list[%d]", skt_kernel_list.size(), skt_arg_list.size()); - if (skt_kernel_list.size() == kSKTSingleSize) { + if (skt_kernel_list.size() == kSKTSingleSize && skt_arg_list.size() == kSKTSingleSize) { rt_ret = rtKernelLaunchWithFlag(skt_info_.kernel_list[0], static_cast(skt_info_.last_block_dim), skt_info_.arg_list[0], skt_info_.last_args_size, static_cast(skt_info_.last_sm_desc), skt_info_.last_stream, skt_info_.last_dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "SuperKernelLaunch: Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } + call_save_dump_ = true; GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } @@ -226,18 +246,22 @@ return RT_FAILED; } // Call the fuse API - skt::SuperKernel *superKernel = nullptr; + std::unique_ptr superKernel = nullptr; if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); return RT_FAILED; } // Launch a super kernel - if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { + skt_dump_flag_ = GetDumpFlag(); + if (superKernel->Launch(skt_info_.last_stream, skt_dump_flag_) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); return RT_FAILED; } GELOGI("SuperKernelLaunch: success[skt_kernel_list size[%zu] skt_arg_list[%zu]]", skt_kernel_list.size(), skt_arg_list.size()); + // record skt addr for release + superkernel_dev_nav_table_ = superKernel->GetNavTablePtr(); + superkernel_device_args_addr_ = superKernel->GetDeviceArgsPtr(); GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } @@ -250,6 +274,9 @@ Status KernelTaskInfo::SaveSuperKernelInfo() { skt_info_.last_args_size = args_size_; skt_info_.last_sm_desc = sm_desc_; skt_info_.last_dump_flag = dump_flag_; + skt_info_.dump_flag_list.push_back(dump_flag_); + skt_info_.op_desc_list.push_back(op_desc_); + skt_info_.dump_args_list.push_back(reinterpret_cast(dump_args_));
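GetDumpFlag above reduces the per-op dump requests collected in dump_flag_list to a single launch flag, so the fused super kernel is launched with dumping enabled whenever any member op asked for it. A standalone sketch of that aggregation, with kDumpFlag and kDefaultFlag standing in for RT_KERNEL_DUMPFLAG and RT_KERNEL_DEFAULT:

#include <algorithm>
#include <cstdint>
#include <vector>

constexpr uint32_t kDefaultFlag = 0;  // stand-in for RT_KERNEL_DEFAULT
constexpr uint32_t kDumpFlag = 1;     // stand-in for RT_KERNEL_DUMPFLAG

// One dump request anywhere in the batch flags the whole launch.
uint32_t GetBatchDumpFlag(const std::vector<uint32_t> &dump_flag_list) {
  const bool any = std::any_of(dump_flag_list.begin(), dump_flag_list.end(),
                               [](uint32_t f) { return f == kDumpFlag; });
  return any ? kDumpFlag : kDefaultFlag;
}

int main() {
  return GetBatchDumpFlag({kDefaultFlag, kDumpFlag, kDefaultFlag}) == kDumpFlag ? 0 : 1;
}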
skt_info_.last_group_key = group_key_; skt_info_.last_dump_args = reinterpret_cast(dump_args_); skt_info_.last_op = op_desc_; @@ -328,6 +355,7 @@ Status KernelTaskInfo::SuperKernelDistribute() { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return FAILED; } + call_save_dump_ = true; UpdateTaskId(); GELOGI("Current Common Task Distribute [taskid:%u]", task_id_); } else { @@ -356,6 +384,7 @@ Status KernelTaskInfo::Distribute() { rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), reinterpret_cast(kernel_name_.c_str()), 1, args_, args_size_, nullptr, stream_, dump_flag_); + call_save_dump_ = true; } else { /* default: not skt launch */ GELOGI( @@ -369,6 +398,7 @@ Status KernelTaskInfo::Distribute() { // call rtKernelLaunch for current task rt_ret = rtKernelLaunchWithFlag(stub_func_, block_dim_, args_, args_size_, static_cast(sm_desc_), stream_, dump_flag_); + call_save_dump_ = true; } } if (rt_ret != RT_ERROR_NONE) { @@ -392,9 +422,31 @@ Status KernelTaskInfo::UpdateArgs() { vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); vector io_addrs; - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + } else { + string peer_input_name; + if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { + uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > output_data_addrs.size()) { + GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", + output_data_addrs.size(), output_index); + return FAILED; + } + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + for (size_t i = 0; i < output_data_addrs.size(); ++i) { + if (i == output_index) { + void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); + io_addrs.emplace_back(fixed_addr); + continue; + } + io_addrs.emplace_back(output_data_addrs[i]); + } + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + } + } GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); @@ -408,6 +460,8 @@ Status KernelTaskInfo::Release() { return SUCCESS; } FreeRtMem(&args_); + FreeRtMem(&superkernel_device_args_addr_); + FreeRtMem(&superkernel_dev_nav_table_); FreeRtMem(&flowtable_); FreeRtMem(&custom_info_.input_descs); FreeRtMem(&custom_info_.input_addrs); @@ -472,6 +526,29 @@ Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel args_offset_ = davinci_model->GetTotalArgsSize(); davinci_model->SetTotalArgsSize(args_size); GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_); + + // get opcontext stored in model + const domi::KernelContext &context = kernel_def.context(); + // get opdesc + op_desc_ = davinci_model->GetOpByIndex(context.op_index()); + GE_CHECK_NOTNULL(op_desc_); + // alloc fixed addr + string peer_input_name; + if 
(AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { + uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > op_desc_->GetOutputsSize()) { + GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", op_desc_->GetOutputsSize(), + output_index); + return FAILED; + } + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); + auto tensor_desc = op_desc_->GetOutputDesc(output_index); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); + GELOGI("Calculate kernel task args, tensor size is %ld, fixed addr offset %ld", tensor_size, + fixed_addr_offset_); + } return SUCCESS; } @@ -549,8 +626,8 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne return FAILED; } - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + offset; } @@ -561,10 +638,8 @@ } vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_info.data(), args_, args_size_, offset); GELOGD("Do InitTVMTask end"); @@ -602,7 +677,6 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel const std::vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), ModelUtils::GetOutputDescs(op_desc)); - if (ret != SUCCESS) { GELOGE(ret, "StoreInputOutputTensor Failed"); return ret; @@ -667,11 +741,9 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel return RT_FAILED; } - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, input_data_addrs.data(), custom_info_.input_addrs, - virtual_in_addrs.size() * kAddrLen, 0); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, output_data_addrs.data(), custom_info_.output_addrs, + davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, input_data_addrs.data(), custom_info_.input_addrs, + input_data_addrs.size() * kAddrLen, 0); + davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, output_data_addrs.data(), custom_info_.output_addrs, output_data_addrs.size() * kAddrLen, 0); return SUCCESS; } @@ -801,6 +873,9 @@ Status
KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size()); return init_ret; } + GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_); + aicpu_param_head->extInfoAddr = reinterpret_cast(aicpu_ext_info_addr_); aicpu_param_head->extInfoLength = reinterpret_cast(ext_info.size()); @@ -819,19 +894,13 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k return RT_FAILED; } - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_addr.get(), args_, args_size_, - sizeof(aicpu::AicpuParamHead)); + davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index 41ed5728..04cd6312 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -61,6 +61,8 @@ class KernelTaskInfo : public TaskInfo { sm_desc_ = nullptr; flowtable_ = nullptr; args_ = nullptr; + superkernel_device_args_addr_ = nullptr; + superkernel_dev_nav_table_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -88,6 +90,8 @@ class KernelTaskInfo : public TaskInfo { uint32_t GetSktTaskID() override { return skt_id_; } + bool CallSaveDumpInfo() override { return call_save_dump_; }; + cce::ccOpContext ctx_; FusionOpInfo fusion_op_info_; @@ -130,6 +134,7 @@ class KernelTaskInfo : public TaskInfo { void UpdateSKTTaskId(); Status SKTFinalize(); Status SuperKernelLaunch(); + uint32_t GetDumpFlag(); Status SaveSuperKernelInfo(); bool IsMarkedLastNode(); bool IsMarkedFirstNode(); @@ -153,6 +158,8 @@ class KernelTaskInfo : public TaskInfo { OpDescPtr op_desc_; DavinciModel *davinci_model_; uint32_t args_offset_ = 0; + int64_t fixed_addr_offset_ = 0; + bool call_save_dump_ = false; // aicpu ext_info device mem void *aicpu_ext_info_addr_ = nullptr; @@ -164,6 +171,9 @@ class KernelTaskInfo : public TaskInfo { bool is_n_batch_spilt_; int64_t group_key_; bool has_group_key_; + uint32_t skt_dump_flag_ = RT_KERNEL_DEFAULT; + void *superkernel_device_args_addr_ = nullptr; + void *superkernel_dev_nav_table_ = nullptr; struct AICPUCustomInfo { void *input_descs = nullptr; @@ -183,6 +193,9 @@ class KernelTaskInfo : public TaskInfo { void *last_sm_desc; std::vector kernel_list; std::vector arg_list; + std::vector dump_flag_list; + 
std::vector op_desc_list; + std::vector dump_args_list; uint32_t last_dump_flag; int64_t last_group_key; uintptr_t last_dump_args; diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc index 818307eb..162cf00d 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc @@ -16,8 +16,8 @@ #include "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h" -#include "graph/load/new_model_manager/davinci_model.h" #include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/davinci_model.h" namespace ge { constexpr uint8_t kLabelSwitchIndexNum = 1; @@ -59,7 +59,13 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo op_desc->GetName().c_str(), input_data_addr.size(), kLabelSwitchIndexNum); return INTERNAL_ERROR; } - index_value_ = input_data_addr[0]; + + if (davinci_model->IsKnownNode()) { + index_value_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_); + } else { + index_value_ = input_data_addr[0]; + } + davinci_model->DisableZeroCopy(index_value_); std::vector label_idx_list; @@ -124,5 +130,28 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { return SUCCESS; } +Status LabelSwitchByIndexTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto label_switch = task_def.label_switch_by_index(); + uint32_t op_index = label_switch.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != kLabelSwitchIndexNum) { + GELOGE(FAILED, "Label switch op only has one data input.
Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + string input_tensor_name = op_desc->GetInputNameByIndex(0); + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); + auto tensor_desc = op_desc->GetInputDesc(0); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate label switch by index task args, tensor_size %ld, fixed_addr_offset %ld", tensor_size, + fixed_addr_offset_); + return SUCCESS; +} + REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, LabelSwitchByIndexTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h index 1a644736..4cb39c95 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h @@ -22,7 +22,8 @@ namespace ge { class LabelSwitchByIndexTaskInfo : public TaskInfo { public: - LabelSwitchByIndexTaskInfo() : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0) {} + LabelSwitchByIndexTaskInfo() + : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} ~LabelSwitchByIndexTaskInfo() override; @@ -30,13 +31,15 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: void *index_value_; // switch index input. uint32_t branch_max_; // max branch count. void *args_; // label info memory. uint32_t args_size_; // label info length.
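CalculateArgs above reserves fixed-address space through the GetFixedAddrsSize/SetTotalFixedAddrsSize pair: the pool's current total becomes this tensor's offset, the pool then grows by the tensor's size, and repeated reservations under the same name keep the first offset. A standalone sketch of that bookkeeping, where FixedAddrPool is a hypothetical stand-in for the DavinciModel accounting:

#include <cstdint>
#include <map>
#include <string>

class FixedAddrPool {
 public:
  // Offset where the named tensor lives (or would live) inside the pool.
  int64_t GetFixedAddrsSize(const std::string &name) const {
    auto it = offsets_.find(name);
    return it == offsets_.end() ? total_size_ : it->second;
  }
  // Reserve space once per tensor name; later calls are no-ops.
  void SetTotalFixedAddrsSize(const std::string &name, int64_t tensor_size) {
    if (offsets_.emplace(name, total_size_).second) {
      total_size_ += tensor_size;
    }
  }

 private:
  std::map<std::string, int64_t> offsets_;
  int64_t total_size_ = 0;
};

int main() {
  FixedAddrPool pool;
  // Mirror the call order in CalculateArgs: read the offset, then reserve.
  int64_t off = pool.GetFixedAddrsSize("switch_node:0");  // hypothetical name
  pool.SetTotalFixedAddrsSize("switch_node:0", 512);
  return (off == 0 && pool.GetFixedAddrsSize("switch_node:0") == 0) ? 0 : 1;
}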
- std::vector label_list_; + int64_t fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc index e9d99189..af32b44f 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc @@ -21,9 +21,9 @@ namespace ge { Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); + GELOGI("MemcpyAddrAsyncTaskInfo Init Start"); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci_model is null"); return PARAM_INVALID; } @@ -32,45 +32,27 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel return ret; } - auto memcpy_async_def = task_def.memcpy_async(); - uint32_t op_index = memcpy_async_def.op_index(); - OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + const auto &memcpy_async = task_def.memcpy_async(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(memcpy_async.op_index()); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "Init MemcpyAddrAsyncTaskInfo error, index is out of range!"); + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range", memcpy_async.op_index()); return INTERNAL_ERROR; } - uint64_t logic_dst = memcpy_async_def.dst(); - uint64_t logic_src = memcpy_async_def.src(); - - dst_max_ = memcpy_async_def.dst_max(); - - uint64_t update_base_addr = 0; - ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); if (ret != SUCCESS) { return ret; } - src_ = reinterpret_cast(update_base_addr + logic_src); - if (src_ == nullptr) { - GELOGE(PARAM_INVALID, "src_ is null!"); - return PARAM_INVALID; - } - uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); - uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); - dst_ = reinterpret_cast(reinterpret_cast(mem_base + (logic_dst - logic_mem_base))); - if (dst_ == nullptr) { - GELOGE(PARAM_INVALID, "dst_ is null!"); - return PARAM_INVALID; + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; } vector io_addrs; io_addrs.emplace_back(src_); io_addrs.emplace_back(dst_); - count_ = memcpy_async_def.count(); - kind_ = memcpy_async_def.kind(); - // malloc args memory size_t args_size = sizeof(void *) * io_addrs.size(); rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); @@ -88,20 +70,18 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel return RT_FAILED; } - // Just dest addr need zero copy. 
- davinci_model->SetZeroCopyAddr(op_desc, {dst_}, io_addrs.data(), args_, args_size, sizeof(void *)); - - GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_, args_, - reinterpret_cast(reinterpret_cast(args_) + args_size)); + count_ = memcpy_async.count(); + kind_ = memcpy_async.kind(); + dst_max_ = memcpy_async.dst_max(); + GELOGI("InitMemcpyAddrAsyncTaskInfo, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu, args:%p, size:%zu", + memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_, args_, args_size); + davinci_model->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), args_, args_size, 0); return SUCCESS; } Status MemcpyAddrAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); - GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); + GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start, dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_) + sizeof(void *)), dst_max_, args_, count_, static_cast(kind_), stream_); @@ -113,39 +93,5 @@ Status MemcpyAddrAsyncTaskInfo::Distribute() { return SUCCESS; } -Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, - uint64_t &base_addr) { - GE_CHECK_NOTNULL(davinci_model); - uint64_t data_base_addr = - static_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = static_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = - static_cast(reinterpret_cast(davinci_model->VarMemBase())) - davinci_model->GetRtVarAddr(); - - uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); - uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); - uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); - uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); - uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); - uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); - - if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { - base_addr = data_base_addr; - GELOGI("The update_addr is data address."); - } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { - base_addr = weight_base_addr; - GELOGI("The update_addr is weight address."); - } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { - base_addr = var_base_addr; - GELOGI("The update_addr is variable address."); - } else if (update_addr != 0) { - base_addr = 0; - GELOGE(PARAM_INVALID, "The update_addr is abnormal."); - return PARAM_INVALID; - } - return SUCCESS; -} - REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h index 9252e43a..f8bf8a90 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h @@ -16,6 +16,7 @@ #ifndef 
GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ + #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -32,9 +33,8 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); } + args_ = nullptr; } - - args_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -42,11 +42,9 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { Status Distribute() override; private: - Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); - - void *dst_; + uint8_t *dst_; uint64_t dst_max_; - void *src_; + uint8_t *src_; void *args_; uint64_t count_; uint32_t kind_; diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index 82eabe69..c2b56436 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -21,9 +21,9 @@ namespace ge { Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAsyncTaskInfo Init Start."); + GELOGI("MemcpyAsyncTaskInfo Init Start"); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci_model is null"); return PARAM_INVALID; } @@ -32,35 +32,38 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da return ret; } - auto memcpy_async_def = task_def.memcpy_async(); - uint64_t logic_dst = memcpy_async_def.dst(); - uint64_t logic_src = memcpy_async_def.src(); - - dst_max_ = memcpy_async_def.dst_max(); - - uint64_t update_base_addr = 0; - ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); + memcpy_async = task_def.memcpy_async(); + count_ = memcpy_async.count(); + kind_ = memcpy_async.kind(); + dst_max_ = memcpy_async.dst_max(); + if (davinci_model->IsKnownNode()) { + src_ = reinterpret_cast(davinci_model_->GetCurrentArgsAddr(args_offset_)); + dst_ = reinterpret_cast(reinterpret_cast(src_) + sizeof(void *)); + // for zero copy + kind_ = RT_MEMCPY_ADDR_DEVICE_TO_DEVICE; + GELOGI("MemcpyAsyncTaskInfo src_ %p, dst_ %p, args_offset %u.", src_, dst_, args_offset_); + return SUCCESS; + } + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); if (ret != SUCCESS) { return ret; } - src_ = reinterpret_cast(update_base_addr + logic_src); - davinci_model->DisableZeroCopy(src_); - uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); - uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); - dst_ = reinterpret_cast(mem_base + (logic_dst - logic_mem_base)); + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; + } - count_ = memcpy_async_def.count(); - kind_ = memcpy_async_def.kind(); - GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_); + GELOGI("MemcpyAsyncTaskInfo Init Success, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu", + memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_); + davinci_model->DisableZeroCopy(src_); + 
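Both copies of GetUpdateBaseAddr are deleted in favor of ModelUtils::GetRtAddress, whose implementation is not part of this patch. Judging from the call sites, its contract is: classify a logical address into the data, weight, or variable region and rebase it onto that region's device allocation. A sketch of that contract only, with illustrative types (Region and LogicToRt are assumptions, not GE APIs):

#include <cstddef>
#include <cstdint>

// One model memory region: a logical range plus its device allocation.
struct Region {
  uint64_t logic_base;   // logical start of the region
  uint64_t size;         // region length in bytes
  uint8_t *device_base;  // actual device allocation
};

bool LogicToRt(const Region *regions, size_t region_num, uint64_t logic_addr, uint8_t *&rt_addr) {
  if (logic_addr == 0) {  // 0 meant "no address" in the old range checks
    rt_addr = nullptr;
    return true;
  }
  for (size_t i = 0; i < region_num; ++i) {
    const Region &r = regions[i];
    if (logic_addr >= r.logic_base && logic_addr < r.logic_base + r.size) {
      rt_addr = r.device_base + (logic_addr - r.logic_base);  // rebase onto device memory
      return true;
    }
  }
  return false;  // outside every known region: invalid parameter
}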
davinci_model->DisableZeroCopy(dst_); return SUCCESS; } Status MemcpyAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAsyncTaskInfo Distribute Start. dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); + GELOGI("MemcpyAsyncTaskInfo Distribute Start. dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); rtError_t rt_ret = rtMemcpyAsync(dst_, dst_max_, src_, count_, static_cast(kind_), stream_); if (rt_ret != RT_ERROR_NONE) { @@ -68,40 +71,41 @@ Status MemcpyAsyncTaskInfo::Distribute() { return RT_FAILED; } - GELOGI("MemcpyAsyncTaskInfo Distribute Success."); + GELOGI("MemcpyAsyncTaskInfo Distribute Success"); return SUCCESS; } -Status MemcpyAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr) { - GE_CHECK_NOTNULL(davinci_model); - uint64_t data_base_addr = - reinterpret_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->VarMemBase())) - - davinci_model->GetRtVarAddr(); - - uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); - uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); - uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); - uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); - uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); - uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); - - if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { - base_addr = data_base_addr; - GELOGI("The update_addr is data address."); - } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { - base_addr = weight_base_addr; - GELOGI("The update_addr is weight address."); - } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { - base_addr = var_base_addr; - GELOGI("The update_addr is variable address."); - } else if (update_addr != 0) { - base_addr = 0; - GELOGE(PARAM_INVALID, "The update_addr is abnormal."); - return PARAM_INVALID; +Status MemcpyAsyncTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + // the args buffer stores two addresses: src and dst + uint32_t args_size = sizeof(void *) * 2; + args_offset_ = davinci_model->GetTotalArgsSize(); + davinci_model->SetTotalArgsSize(args_size); + davinci_model_ = davinci_model; + GELOGI("MemcpyAsyncTaskInfo kernel args_size %u, args_offset %u", args_size, args_offset_); + return SUCCESS; +} + +Status MemcpyAsyncTaskInfo::UpdateArgs() { + GELOGI("MemcpyAsyncTaskInfo::UpdateArgs in."); + GE_CHECK_NOTNULL(davinci_model_); + Status ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.src(), src_); + if (ret != SUCCESS) { + return ret; + } + + ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; } + + vector io_addrs; + io_addrs.emplace_back(reinterpret_cast(src_)); + io_addrs.emplace_back(reinterpret_cast(dst_)); + + GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), + "Update known-node zero copy addr for memcpy async failed."); + + GELOGI("MemcpyAsyncTaskInfo::UpdateArgs success."); return SUCCESS; }
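CalculateArgs and UpdateArgs above split known-node handling into two phases: reserve 2 * sizeof(void *) in the model-wide args area before loading, then patch the refreshed src/dst into that slice at execution time. A compact sketch of the reserve-then-patch protocol, with ArgsBuffer as a hypothetical stand-in for DavinciModel's args accounting:

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical model-wide args area: tasks claim ranges up front, then patch
// their own range once the real addresses are known.
class ArgsBuffer {
 public:
  uint32_t Reserve(uint32_t bytes) {  // CalculateArgs time
    uint32_t offset = static_cast<uint32_t>(buf_.size());
    buf_.resize(buf_.size() + bytes);
    return offset;
  }
  void Patch(uint32_t offset, const std::vector<uintptr_t> &addrs) {  // UpdateArgs time
    std::memcpy(buf_.data() + offset, addrs.data(), addrs.size() * sizeof(uintptr_t));
  }

 private:
  std::vector<uint8_t> buf_;
};

int main() {
  ArgsBuffer args;
  uint32_t args_offset = args.Reserve(sizeof(void *) * 2);  // src + dst slots
  args.Patch(args_offset, {0x1000u, 0x2000u});              // resolved at run time
  return 0;
}

diff --git 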
a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h index 02872f34..c3daa862 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h @@ -16,6 +16,7 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ + #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -32,14 +33,19 @@ class MemcpyAsyncTaskInfo : public TaskInfo { Status Distribute() override; - private: - Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); + Status UpdateArgs() override; - void *dst_; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + private: + uint8_t *dst_; uint64_t dst_max_; - void *src_; + uint8_t *src_; uint64_t count_; uint32_t kind_; + DavinciModel *davinci_model_ = nullptr; + uint32_t args_offset_ = 0; + domi::MemcpyAsyncDef memcpy_async; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc index a1d2f143..0ebaf573 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc @@ -42,16 +42,11 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d auto stream_switch_def = task_def.stream_switch(); uint32_t op_index = stream_switch_def.op_index(); - // get StreamSwitch op OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (!input_data_addr.empty() && input_data_addr.size() >= STREAM_SWITCH_INPUT_NUM) { - input_ptr_ = input_data_addr[0]; - value_ptr_ = input_data_addr[1]; - } - + SetInputAndValuePtr(davinci_model, input_data_addr); uint32_t cond = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, cond)) { GELOGE(INTERNAL_ERROR, "StreamSwitchOp get attr STREAM_SWITCH_COND fail."); @@ -115,6 +110,42 @@ Status StreamSwitchTaskInfo::Distribute() { GELOGI("StreamSwitchTaskInfo Distribute Success. cond:%d, stream:%p, datatype:%d.", cond_, true_stream_, data_type_); return SUCCESS; } +Status StreamSwitchTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto stream_switch_def = task_def.stream_switch(); + uint32_t op_index = stream_switch_def.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != STREAM_SWITCH_INPUT_NUM) { + GELOGE(FAILED, "Stream switch op requires two data inputs. 
Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) { + string input_tensor_name = op_desc->GetInputNameByIndex(i); + int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name); + fixed_addr_offset_.emplace_back(fixed_addr_offset); + auto tensor_desc = op_desc->GetInputDesc(i); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr[%u] offset %ld", tensor_size, i, + fixed_addr_offset); + } + return SUCCESS; +} +void StreamSwitchTaskInfo::SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs) { + if (davinci_model->IsKnownNode() && fixed_addr_offset_.size() == STREAM_SWITCH_INPUT_NUM) { + input_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[0]); + value_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[1]); + } else { + if (!input_data_addrs.empty() && input_data_addrs.size() >= STREAM_SWITCH_INPUT_NUM) { + input_ptr_ = input_data_addrs[0]; + value_ptr_ = input_data_addrs[1]; + } + } +} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH, StreamSwitchTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h index 07509c7c..e6e8339a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h @@ -39,13 +39,18 @@ class StreamSwitchTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: + void SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs); void *input_ptr_; rtCondition_t cond_; void *value_ptr_; rtStream_t true_stream_; uint32_t true_stream_id_; rtSwitchDataType_t data_type_; + static const uint32_t kInputNum = 2; + vector fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc index 29b107bd..01371af7 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc @@ -24,18 +24,15 @@ namespace { const uint32_t kDynamicBtachParamNum = 1; const uint32_t kDynamicResolutionParamNum = 2; +const uint8_t kStreamSwitchnInputNum = 1; } // namespace namespace ge { Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("StreamSwitchNTaskInfo Init Start."); - if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); - return PARAM_INVALID; - } + GE_CHECK_NOTNULL(davinci_model); - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); - if (ret != SUCCESS) { + if (SetStream(task_def.stream_id(), davinci_model->GetStreamList()) != SUCCESS) { return FAILED; } @@ -75,14 +72,16 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * GELOGE(FAILED, "Get true stream ptr of switchN op failed."); return FAILED; } - - // set input_ptr_ - auto input_data_addr = 
ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (input_data_addr.empty()) { - GELOGE(FAILED, "Input data addr is nullptr."); - return FAILED; + if (davinci_model->IsKnownNode()) { + input_ptr_ = davinci_model->GetCurrentFixedAddr(args_offset_); + } else { + auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + if (input_data_addr.empty()) { + GELOGE(FAILED, "Input data addr is nullptr."); + return FAILED; + } + input_ptr_ = input_data_addr[0]; } - input_ptr_ = input_data_addr[0]; davinci_model->DisableZeroCopy(input_ptr_); GELOGI("StreamSwitchNTaskInfo Init Success, inputSize:%u, elementSize:%d, trueStreamID:%ld.", input_size_, element_size_, op_desc->GetStreamId()); @@ -140,5 +139,26 @@ Status StreamSwitchNTaskInfo::GetTrueStreamPtr(const OpDescPtr &op_desc, Davinci return SUCCESS; } +Status StreamSwitchNTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto stream_switchn_def = task_def.stream_switch_n(); + uint32_t op_index = stream_switchn_def.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != kStreamSwitchnInputNum) { + GELOGE(FAILED, "Stream switchn op only has one data input. Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + string input_tensor_name = op_desc->GetInputNameByIndex(0); + args_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); + auto tensor_desc = op_desc->GetInputDesc(0); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate stream switchn task args, tensor_size %ld, args_offset %ld", tensor_size, args_offset_); + return SUCCESS; +} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH_N, StreamSwitchNTaskInfo); } // namespace ge \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h index d1002da7..1a96243a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h @@ -29,7 +29,8 @@ class StreamSwitchNTaskInfo : public TaskInfo { value_ptr_(nullptr), true_stream_ptr_(nullptr), element_size_(0), - data_type_(RT_SWITCH_INT64) {} + data_type_(RT_SWITCH_INT64), + args_offset_(0) {} ~StreamSwitchNTaskInfo() override {} @@ -37,6 +38,8 @@ class StreamSwitchNTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: Status GetTrueStreamPtr(const OpDescPtr &op_desc, DavinciModel *davinci_model); void *input_ptr_; @@ -47,6 +50,7 @@ class StreamSwitchNTaskInfo : public TaskInfo { rtSwitchDataType_t data_type_; vector true_stream_list_; vector value_list_; + int64_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCHN_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h index 1c31acd1..b7e76af0 100644 --- 
a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h @@ -34,22 +34,13 @@ class SuperKernel { public: SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} - ~SuperKernel() { - // free memory when all releasing - if (device_args_addr_ != nullptr) { - GE_CHK_RT(rtFree(device_args_addr_)); - GELOGI("SKT: super_kernel args addr free."); - } - if (dev_nav_table_ != nullptr) { - GE_CHK_RT(rtFree(dev_nav_table_)); - GELOGI("SKT: super_kernel args addr free."); - } - } + ~SuperKernel() = default; Status Launch(rtStream_t stream, uint32_t dump_flag); const void *GetFuncStub() const { return func_stub_; } - const void *GetNavTablePtr() const { return dev_nav_table_; } uint64_t GetNavTableSize() const { return nav_table_size_; } uint32_t GetBlockDim() const { return block_dim_; } + void *GetNavTablePtr() const { return dev_nav_table_; } + void *GetDeviceArgsPtr() const { return device_args_addr_; } }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index d2ad474a..397c7d98 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -42,21 +42,10 @@ Status SuperKernelFactory::Init() { rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); return FAILED;) - if (this->use_physical_address_ != nullptr) { - void *skt_func = nullptr; - rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD( - "SKT: fuseKernels super_kernel_template subFunc %p, device func " - "address %p, device physic PC %p", - this->func_stub_, this->func_ptr_, skt_func); - } else { - GELOGD( - "SKT: fuseKernels super_kernel_template subFunc %p, device func " - "address %p", - this->func_stub_, this->func_ptr_); - } + GELOGD( + "SKT: fuseKernels super_kernel_template subFunc %p, device func " + "address %p", + this->func_stub_, this->func_ptr_); } is_init_ = true; @@ -71,7 +60,8 @@ Status SuperKernelFactory::Uninitialize() { } Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list, - const std::vector &args_addr_list, uint32_t block_dim, SuperKernel *&h) { + const std::vector &args_addr_list, uint32_t block_dim, + std::unique_ptr &h) { // Iterate through the ops to be fused // Each subkernel to be fused contains 2 fields: fn address offset, args // address. @@ -101,70 +91,28 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list rtError_t rt_ret; void *hbm_nav_table_addr = nullptr; - if (this->use_physical_address_ != nullptr) { - for (unsigned i = 0; i < stub_func_list.size(); i++) { - void *sub_device_func = nullptr; - rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. 
error: 0x%X", rt_ret); - return FAILED;) - void *sub_device_func_pys = nullptr; - void *args_addr_pys = nullptr; - rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD( - "SKT: fuseKernels subFunc %p, device func address %p, device " - "physic func address %p", - stub_func_list[i], sub_device_func, sub_device_func_pys); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func_pys)) / 4; - GELOGD("SKT: CALL offset %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_pys)); - - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); - } - - void *hbm_nav_table_addr_pys = nullptr; - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - - GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); - // Create the necessary metadata for the super kernel - h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); - } else { - for (unsigned i = 0; i < stub_func_list.size(); i++) { - void *sub_device_func = nullptr; - rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; - GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); - } - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. 
error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - // Create the necessary metadata for the super kernel - h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); + for (unsigned i = 0; i < stub_func_list.size(); i++) { + void *sub_device_func = nullptr; + rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + return FAILED;) + GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); } + rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) + rt_ret = + rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) + // Create the necessary metadata for the super kernel + h = + std::unique_ptr(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); return SUCCESS; } } // namespace skt diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h index d8b7ff26..7db44eec 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h @@ -29,7 +29,6 @@ class SuperKernelFactory { void *func_ptr_ = nullptr; void *handle_ = nullptr; std::string sk_stub_name_ = "_Z21super_kernel_templatePmm"; - const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); bool is_init_ = false; SuperKernelFactory(){}; ~SuperKernelFactory() { @@ -48,7 +47,7 @@ class SuperKernelFactory { Status Init(); Status Uninitialize(); Status FuseKernels(const std::vector &stub_func_list, const std::vector &args_addr_list, - uint32_t block_dim, SuperKernel *&h); + uint32_t block_dim, std::unique_ptr &h); }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.h b/src/ge/graph/load/new_model_manager/task_info/task_info.h index 5d2c89eb..f69511e6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info.h @@ -72,6 +72,8 @@ class TaskInfo { virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } + virtual bool CallSaveDumpInfo() { return false; } + virtual uint32_t GetStreamId() { return 0xFFFFFFFF; } virtual uintptr_t GetDumpArgs() { return 0; } diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h index b6954016..5b220960 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h @@ -86,5 +86,5 @@ class 
TaskInfoFactory { return ptr; \ } \ TaskInfoFactory::Registerar g_##type##_Task_Info_Creator(type, Creator_##type##_Task_Info); -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_FACTORY_H_ diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc index 42734a87..be75322d 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -129,12 +129,6 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, const DataBuffer &data, } auto dst_addr = static_cast(data.data); - auto dst_size = static_cast(data.length); - if (ModelUtils::ConvertVirtualAddressToPhysical(dst_addr, dst_size, dst_addr) != SUCCESS) { - GELOGE(FAILED, "[ZCPY] Convert virtual address to physical for dst_addr failed."); - return FAILED; - } - GELOGI("[ZCPY] %s update task, args: %p, size: %zu, offset: %zu, addr: 0x%lx, length: %u", name_.c_str(), args_addr_, args_size_, offset, addr, data.length); *(uintptr_t *)(args_info + offset) = reinterpret_cast(dst_addr); diff --git a/src/ge/graph/load/output/output.cc b/src/ge/graph/load/output/output.cc deleted file mode 100644 index d922ce7c..00000000 --- a/src/ge/graph/load/output/output.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/load/output/output.h" - -#include - -#include "common/properties_manager.h" -#include "graph/load/new_model_manager/davinci_model.h" -#include "graph/manager/graph_var_manager.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -namespace ge { -Output::Output(const OpDescPtr &op_desc, DavinciModel *model) - : base_(nullptr), - var_base_(nullptr), - logic_base_(0), - logic_var_base_(0), - model_(model), - op_desc_(op_desc), - input_num_(0) {} - -Output::~Output() { - var_base_ = nullptr; - base_ = nullptr; - model_ = nullptr; -} - -/// -/// @ingroup domi -/// @brief Initialize input/output params -/// @return Status -/// -Status Output::Init() { - if (op_desc_ == nullptr || model_ == nullptr) { - GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr."); - return INTERNAL_ERROR; - } - - base_ = model_->MemBase(); - var_base_ = model_->VarMemBase(); - logic_base_ = model_->GetRtBaseAddr(); - logic_var_base_ = model_->GetRtVarAddr(); - - input_num_ = op_desc_->GetInputsSize(); - v_input_size_.clear(); - v_input_data_addr_.clear(); - - auto input_vector = op_desc_->GetInputOffset(); - if (input_num_ != input_vector.size()) { - GELOGE(INTERNAL_ERROR, "input desc size: %zu != input offset size: %zu.", input_num_, input_vector.size()); - return INTERNAL_ERROR; - } - - for (size_t i = 0; i < input_num_; i++) { - int64_t tensor_size = 0; - auto input_desc = op_desc_->GetInputDescPtr(i); - GE_CHECK_NOTNULL(input_desc); - Status ret = TensorUtils::GetSize(*input_desc, tensor_size); - if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "Get size from TensorDesc failed, op : %s, input index : %zu", op_desc_->GetName().c_str(), i); - return ret; - } - v_input_size_.push_back(tensor_size); - - if (VarManager::Instance(model_->SessionId())->IsVarAddr(input_vector[i])) { - v_input_data_addr_.push_back(static_cast(var_base_ + input_vector[i] - logic_var_base_)); - } else { - v_input_data_addr_.push_back(static_cast(base_ + input_vector[i])); - } - } - - GELOGI("Init output:%lu, %lu, %lu", input_num_, v_input_size_.size(), v_input_data_addr_.size()); - - return SUCCESS; -} - -/// -/// @ingroup domi -/// @brief Copy Op Output to user space. -/// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. -/// @return Status -/// -Status Output::CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) { - uint32_t data_count = 0; - if (input_num_ > rslt.blobs.size() - data_begin) { - GELOGE(FAILED, "Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); - return FAILED; - } else if (input_num_ < rslt.blobs.size() - data_begin) { - GELOGW("Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); - } - - for (size_t i = 0; i < input_num_; i++) { - DataBuffer data_buf = rslt.blobs[data_begin + data_count]; - Status ret = SetDataBuf(data_buf, data_count, i, support_mem_share); - if (ret != SUCCESS) { - GELOGE(ret, "Copy data to host error. index: %zu", i); - return ret; - } - data_index = data_begin + data_count; - } - - return SUCCESS; -} - -Status Output::SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share) { - if (data_buf.length == 0) { - ++data_count; - GELOGD("Length of data_buffer is zero, No need to copy. 
output op : %s, output tensor index : %zu!", - op_desc_->GetName().c_str(), i); - return SUCCESS; - } - - auto tensor_desc = op_desc_->GetInputDescPtr(static_cast(i)); - if (tensor_desc == nullptr) { - GELOGE(FAILED, "tensor_desc is null"); - return FAILED; - } - - if (data_buf.isDataSupportMemShare && support_mem_share) { - GELOGI("No need to copy input data, user's output data buffer can be shared."); - } else { - // Copy result to Databuf - int64_t size = v_input_size_[i]; - GELOGI("Tensor data size before: %ld", size); - - graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(*tensor_desc, size); - if (graph_status != ge::GRAPH_SUCCESS) { - GELOGE(graph_status, "GetTensorSizeInBytes failed!"); - return FAILED; - } - - if (data_buf.length < size) { - GELOGE(FAILED, "Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - return FAILED; - } else if (data_buf.length > size) { - GELOGW("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - } - - rtError_t rt_ret = rtMemcpy(data_buf.data, size, v_input_data_addr_[i], size, RT_MEMCPY_DEVICE_TO_HOST); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtmemcpy error"); - return FAILED; - } - GELOGI("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - } - - ++data_count; - GELOGD("Successfully copy the output tensor memory to buffer, output op : %s, output tensor index : %zu!", - op_desc_->GetName().c_str(), i); - - return SUCCESS; -} - -void Output::GetOutputData(vector &v_data_addr, vector &v_data_size) { - for (size_t i = 0; i < input_num_; ++i) { - v_data_addr.push_back(v_input_data_addr_[i]); - v_data_size.push_back(v_input_size_[i]); - } -} -} // namespace ge diff --git a/src/ge/graph/load/output/output.h b/src/ge/graph/load/output/output.h deleted file mode 100644 index d93b8de9..00000000 --- a/src/ge/graph/load/output/output.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ -#define GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ - -#include -#include - -#include "common/debug/log.h" -#include "common/op/attr_value_util.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" -#include "common/ge_types.h" -#include "graph/load/new_model_manager/davinci_model.h" -#include "graph/op_desc.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -using std::string; -using std::vector; - -// The base class for all op -class Output { - public: - Output(const OpDescPtr &op_desc, DavinciModel *model); - virtual ~Output(); - - /// - /// @ingroup domi - /// @brief Initialize input/output params - /// @return Status - /// - virtual Status Init(); - - /// - /// @ingroup domi - /// @brief Copy Op Output to user space. - /// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. 
- /// @return Status - /// - virtual Status CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); - - /// - /// @ingroup domi - /// @brief Trans Output data to fp16 - /// @return Status - /// - Status SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share); - - /// - /// @ingroup domi - /// @brief Get Output data and size. - /// @return void - /// - void GetOutputData(vector &v_data_addr, vector &v_data_size); - - // Copy assignment operator and copy constructor are deleted - Output &operator=(const Output &output) = delete; - Output(const Output &output) = delete; - - protected: - // Model's base address - uint8_t *base_; - uint8_t *var_base_; - uint64_t logic_base_; - uint64_t logic_var_base_; - // The DavinciModel which ops belong to - DavinciModel *model_; - - ConstOpDescPtr op_desc_; - - // Input descriptions - size_t input_num_; - vector v_input_data_addr_; // init as:buf_base + op_def_->input(i)); - vector v_input_size_; -}; -} // namespace ge - -#endif // GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ diff --git a/src/ge/graph/manager/graph_caching_allocator.cc b/src/ge/graph/manager/graph_caching_allocator.cc index 5df6769b..cbeafa3f 100644 --- a/src/ge/graph/manager/graph_caching_allocator.cc +++ b/src/ge/graph/manager/graph_caching_allocator.cc @@ -34,9 +34,6 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, 26 * kGByteSize}; static bool BlockComparator(const Block *left, const Block *right) { - if (left->device_id != right->device_id) { - return left->device_id < right->device_id; - } if (left->size != right->size) { return left->size < right->size; } @@ -267,20 +264,20 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { return ge::FAILED; } } - if (AddToBlockBin(memory_addr, memory_size) != ge::SUCCESS) { + if (AddToBlockBin(memory_addr, memory_size, device_id) != ge::SUCCESS) { (void)memory_allocator_->FreeMemory(memory_addr); return ge::FAILED; } return ge::SUCCESS; } -Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size) { +Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id) { BlockBin *bin = GetBlockBin(size); if (bin == nullptr) { GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); return ge::FAILED; } - Block *block = new (std::nothrow) Block(0, size, bin, nullptr); + Block *block = new (std::nothrow) Block(device_id, size, bin, nullptr); if (block == nullptr) { GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); return ge::FAILED; @@ -339,5 +336,4 @@ void CachingAllocator::FreeBlockBins() { } } } - } // namespace ge diff --git a/src/ge/graph/manager/graph_caching_allocator.h b/src/ge/graph/manager/graph_caching_allocator.h index 75864ce7..94a5066a 100644 --- a/src/ge/graph/manager/graph_caching_allocator.h +++ b/src/ge/graph/manager/graph_caching_allocator.h @@ -32,7 +32,6 @@ #include "runtime/mem.h" namespace ge { - constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr size_t kKByteSize = 1024; @@ -69,6 +68,10 @@ class CachingAllocator { public: explicit CachingAllocator(rtMemType_t memory_type); + CachingAllocator(const CachingAllocator &) = delete; + + CachingAllocator &operator=(const CachingAllocator &) = delete; + virtual ~CachingAllocator() = default; /// @@ -137,9 +140,10 @@ class CachingAllocator { /// @brief add memory to right bin based on size 
/// @param [in] memory ptr /// @param [in] memory size + /// @param [in] device_id device id /// @return Status result of function /// - Status AddToBlockBin(uint8_t *ptr, size_t size); + Status AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id); /// /// @ingroup ge_graph @@ -206,7 +210,5 @@ class CachingAllocator { // block bins by different block size BlockBin *free_block_bins_[kNumBins]; }; - -}; // namespace ge - +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index dd4855b6..bfd09c72 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -57,7 +57,6 @@ #include "graph/passes/flow_ctrl_pass.h" #include "graph/passes/hccl_group_pass.h" #include "graph/passes/hccl_memcpy_pass.h" -#include "graph/passes/identify_reference_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" @@ -74,7 +73,9 @@ #include "graph/passes/switch_data_edges_bypass.h" #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/switch_op_pass.h" +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "graph/passes/switch_to_stream_switch_pass.h" +#include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" @@ -83,6 +84,7 @@ #include "graph/passes/transpose_transdata_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" +#include "graph/passes/ref_identity_delete_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" #include "graph/utils/tensor_adapter.h" @@ -347,12 +349,13 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr return SUCCESS; } -#define GM_RUN_AND_DUMP(name, func, ...) \ +#define GM_RUN_AND_DUMP_PERF(name, func, ...) 
\ do { \ - GE_RUN(GraphManager, func, __VA_ARGS__); \ + GE_RUN_PERF(GraphManager, func, __VA_ARGS__); \ GE_DUMP(compute_graph, "PreRunAfter" name); \ GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ } while (0) + Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id) { GE_CHECK_NOTNULL(graph_node); @@ -365,30 +368,30 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetName().c_str()); GE_DUMP(compute_graph, "PreRunBegin"); - GM_RUN_AND_DUMP("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); - GM_RUN_AND_DUMP("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); - GM_RUN_AND_DUMP("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, - session_id); - GM_RUN_AND_DUMP("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); + GM_RUN_AND_DUMP_PERF("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); + GM_RUN_AND_DUMP_PERF("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); + GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, + session_id); + GM_RUN_AND_DUMP_PERF("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); - GM_RUN_AND_DUMP("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); - GM_RUN_AND_DUMP("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); + GM_RUN_AND_DUMP_PERF("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); + GM_RUN_AND_DUMP_PERF("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); GE_RUN(GraphManager, graph_preparer_.RecordAIPPInfo, compute_graph); if (IsTailingOptimization()) { - GM_RUN_AND_DUMP("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); + GM_RUN_AND_DUMP_PERF("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); } - GM_RUN_AND_DUMP("Optimize1", OptimizeStage1, compute_graph); - GM_RUN_AND_DUMP("InferShape2", compute_graph->InferShapeInNeed); + GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); + GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); if (unknown_shape_skip != nullptr) { PassManager graph_pass; GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); } - - GM_RUN_AND_DUMP("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); - GM_RUN_AND_DUMP("Optimize2", OptimizeStage2, compute_graph); - GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_root_model, session_id); + GE_CHK_STATUS_RET(graph_optimize_.IdentifyReference(compute_graph), "Identify reference failed."); + GM_RUN_AND_DUMP_PERF("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); + GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); + GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); // when set incre build, save om model and var manager GeModelPtr ge_model = nullptr; @@ -397,7 +400,7 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorSetRunFlag(false); @@ -634,7 
+637,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorgraph_run_async_listener_); Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); - GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraphAsync"); + GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); graph_node->SetRunFlag(false); @@ -2309,21 +2331,21 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra GELOGE(FAILED, "failed get dynamic shape partitioned flag on partitioned graph."); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); + GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); GE_TIMESTAMP_START(GraphPartition); ret = graph_partitioner_.Partition(compute_graph, GraphPartitioner::kPartitioning); if (ret != SUCCESS) { GELOGE(ret, "Graph partition Failed"); return ret; } - GE_TIMESTAMP_END(GraphPartition, "OptimizeSubgraph::Partition1"); + GE_TIMESTAMP_EVENT_END(GraphPartition, "OptimizeSubgraph::Partition1"); GE_TIMESTAMP_START(SetSubgraph); ret = SetSubgraph(session_id, compute_graph); if (ret != SUCCESS) { GELOGE(ret, "Graph set subgraph Failed"); return ret; } - GE_TIMESTAMP_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); + GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); ComputeGraphPtr merged_compute_graph = nullptr; std::vector merged_sub_graph_list; @@ -2342,7 +2364,7 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra sub_graph->SetSessionID(session_id); sub_graph->SetGraphID(graph_node->GetGraphId()); } - GE_TIMESTAMP_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); + GE_TIMESTAMP_EVENT_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); GE_DUMP(merged_compute_graph, "mergedComputeGraph"); compute_graph = merged_compute_graph; if (!AttrUtils::SetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { @@ -2368,8 +2390,7 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp } bool is_always_dump = false; - PropertiesManager &properties_manager = PropertiesManager::Instance(); - if (!properties_manager.GetDumpOutputPath().empty()) { + if (!PropertiesManager::Instance().GetDumpProperties(session_id).GetDumpPath().empty()) { is_always_dump = true; } diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index 8ab28316..fd9542e8 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -327,6 +327,6 @@ class GraphManager { std::mutex run_mutex_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_H_ diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h index 7bf82897..e4eeded3 100644 --- a/src/ge/graph/manager/graph_mem_allocator.h +++ b/src/ge/graph/manager/graph_mem_allocator.h @@ -190,6 +190,6 @@ class MemManager { std::map caching_allocator_map_; std::recursive_mutex allocator_mutex_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MEM_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index 2982eb89..7ca0224b 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -91,7 +91,7 @@ ge::Status 
VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen std::string var_key = VarKey(var_name, tensor_desc); GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str()); if (var_addr_mgr_map_.count(var_key) == 0) { - uint64_t logic_address = VarManager::Instance(0)->GetVarMemLogicBase() + + uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() + reinterpret_cast(reinterpret_cast(address)); GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(), TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), @@ -274,7 +274,7 @@ MemResource::MemResource() : total_size_(0), var_mem_size_(0) {} Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) { size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; uint64_t real_size = size; - total_size_ = VarManager::Instance(0)->GetVarMemMaxSize(); + total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize(); if (total_size_ < var_mem_size_) { GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_); return PARAM_INVALID; @@ -684,7 +684,8 @@ uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_ty if (mem_base == nullptr) { return nullptr; } - uint8_t *mem_addr = logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(0)->GetVarMemLogicBase(); + uint8_t *mem_addr = + logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase(); return mem_addr; } diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index be839eee..2142d906 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -309,5 +309,5 @@ class VarManagerPool { std::mutex var_manager_mutex_; map var_manager_map_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_VAR_MANAGER_H_ diff --git a/src/ge/graph/manager/model_manager/event_manager.h b/src/ge/graph/manager/model_manager/event_manager.h index 1d57dd52..a20afead 100644 --- a/src/ge/graph/manager/model_manager/event_manager.h +++ b/src/ge/graph/manager/model_manager/event_manager.h @@ -17,7 +17,6 @@ #ifndef GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ #define GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ - #include #include "common/fmk_error_codes.h" @@ -94,5 +93,5 @@ class EventManager { bool inited_; uint32_t current_idx_; }; // EventManager -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ diff --git a/src/ge/graph/manager/trans_var_data_utils.cc b/src/ge/graph/manager/trans_var_data_utils.cc index e8444c53..3f346c91 100644 --- a/src/ge/graph/manager/trans_var_data_utils.cc +++ b/src/ge/graph/manager/trans_var_data_utils.cc @@ -397,10 +397,11 @@ Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeT uint8_t *src_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); - uint8_t *mem_addr = src_addr - - static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = + src_addr - + static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + + static_cast( + 
reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(host_addr), src_tensor_size)); GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); @@ -413,10 +414,11 @@ Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8 const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { uint8_t *dst_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); - uint8_t *mem_addr = dst_addr - - static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = + dst_addr - + static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + + static_cast( + reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index 4f6fe591..5f31c982 100644 --- a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -24,7 +24,6 @@ #include "graph/utils/type_utils.h" namespace ge { - Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); @@ -101,6 +100,12 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); + // dynamic shape hccl op get size from output tensor desc + if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { + GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); + GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), + "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); + } GE_IF_BOOL_EXEC( op_desc->GetType() == HCOMREDUCESCATTER, int32_t rank_size = 0; @@ -114,6 +119,8 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType total_size = total_size + block_size; continue;); int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); + GELOGD("hcom util node %s inputsize %ld, shapesize %ld, datasize %d.", op_desc->GetName().c_str(), input_size, + shape_size, size); GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), "Product of shape size and size beyond INT64_MAX"); GE_IF_BOOL_EXEC(is_allgather, block_size = shape_size * size;); diff --git a/src/ge/graph/manager/util/hcom_util.h b/src/ge/graph/manager/util/hcom_util.h index 40aac3e5..e31e3ef0 100644 --- a/src/ge/graph/manager/util/hcom_util.h +++ b/src/ge/graph/manager/util/hcom_util.h @@ -144,8 +144,6 @@ class HcomOmeUtil { /// static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos); - - private: /// /// @ingroup domi_ome /// @brief GetHcomCount @@ -154,6 +152,8 @@ class HcomOmeUtil { /// static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, int &count); + + private: /// /// @ingroup domi_ome /// @brief 
GetHorovodCount diff --git a/src/ge/graph/manager/util/rt_context_util.cc b/src/ge/graph/manager/util/rt_context_util.cc index 05120f6a..e6344539 100644 --- a/src/ge/graph/manager/util/rt_context_util.cc +++ b/src/ge/graph/manager/util/rt_context_util.cc @@ -19,13 +19,30 @@ #include "framework/common/debug/ge_log.h" namespace ge { -void RtContextUtil::AddrtContext(rtContext_t context) { rtContexts_.emplace_back(context); } +void RtContextUtil::AddRtContext(uint64_t session_id, rtContext_t context) { + std::lock_guard<std::mutex> lock(ctx_mutex_); + rt_contexts_[session_id].emplace_back(context); +} + +void RtContextUtil::DestroyRtContexts(uint64_t session_id) { + std::lock_guard<std::mutex> lock(ctx_mutex_); + auto &contexts = rt_contexts_[session_id]; + DestroyRtContexts(session_id, contexts); +} + +void RtContextUtil::DestroyAllRtContexts() { + std::lock_guard<std::mutex> lock(ctx_mutex_); + for (auto &ctx_pair : rt_contexts_) { + DestroyRtContexts(ctx_pair.first, ctx_pair.second); + } + rt_contexts_.clear(); +} -void RtContextUtil::DestroyrtContexts() { - GELOGI("The size of runtime context handle is %zu.", rtContexts_.size()); - for (auto &rtContext : rtContexts_) { +void RtContextUtil::DestroyRtContexts(uint64_t session_id, std::vector<rtContext_t> &contexts) { + GELOGI("Runtime context handle number of session %lu is %zu.", session_id, contexts.size()); + for (auto &rtContext : contexts) { (void)rtCtxDestroy(rtContext); } - rtContexts_.clear(); + contexts.clear(); } } // namespace ge diff --git a/src/ge/graph/manager/util/rt_context_util.h b/src/ge/graph/manager/util/rt_context_util.h index 93db9882..58cc0803 100644 --- a/src/ge/graph/manager/util/rt_context_util.h +++ b/src/ge/graph/manager/util/rt_context_util.h @@ -18,6 +18,8 @@ #define GE_GRAPH_MANAGER_UTIL_RT_CONTEXT_UTIL_H_ #include <vector> +#include <map> +#include <mutex> #include "runtime/context.h" @@ -29,13 +31,14 @@ class RtContextUtil { return instance; } - void AddrtContext(rtContext_t context); + void AddRtContext(uint64_t session_id, rtContext_t context); const rtContext_t GetNormalModeContext() const { return before_prerun_ctx_; } void SetNormalModeContext(rtContext_t context) { before_prerun_ctx_ = context; } - void DestroyrtContexts(); + void DestroyRtContexts(uint64_t session_id); + void DestroyAllRtContexts(); RtContextUtil &operator=(const RtContextUtil &) = delete; RtContextUtil(const RtContextUtil &RtContextUtil) = delete; @@ -44,8 +47,12 @@ class RtContextUtil { RtContextUtil() = default; ~RtContextUtil() {} - std::vector<rtContext_t> rtContexts_; + void DestroyRtContexts(uint64_t session_id, std::vector<rtContext_t> &contexts); + + std::map<uint64_t, std::vector<rtContext_t>> rt_contexts_; rtContext_t before_prerun_ctx_ = nullptr; + + std::mutex ctx_mutex_; }; } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index b42c2e01..09acae33 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -299,4 +299,36 @@ void GraphOptimize::TranFrameOp(ComputeGraphPtr &compute_graph) { } } } + +Status GraphOptimize::IdentifyReference(ComputeGraphPtr &compute_graph) { + for (auto &node : compute_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto input_name_index = op_desc->GetAllInputName(); + bool is_ref = false; + for (const auto &name_index : input_name_index) { + const int out_index = op_desc->GetOutputIndexByName(name_index.first); + if (out_index != -1) { + auto input_desc = op_desc->GetInputDesc(name_index.second); + input_desc.SetRefPortByIndex({name_index.second}); +
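// (Editorial note) an input whose name also appears as an output name marks a reference
+ // (in-place) port: e.g. a hypothetical Assign-style op with an input and an output both named
+ // "ref" gets both tensor descs tagged via SetRefPortByIndex, and the node is then flagged
+ // with ATTR_NAME_REFERENCE below.
+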
op_desc->UpdateInputDesc(name_index.second, input_desc); + GELOGI("SetRefPort: set op[%s] input desc[%u-%s] ref.", op_desc->GetName().c_str(), name_index.second, + name_index.first.c_str()); + auto output_desc = op_desc->GetOutputDesc(static_cast<uint32_t>(out_index)); + output_desc.SetRefPortByIndex({name_index.second}); + op_desc->UpdateOutputDesc(static_cast<uint32_t>(out_index), output_desc); + GELOGI("SetRefPort: set op[%s] output desc[%u-%s] ref.", op_desc->GetName().c_str(), out_index, + name_index.first.c_str()); + is_ref = true; + } + } + if (is_ref) { + AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + GELOGI("param [node] %s is reference node, set attribute %s to be true.", node->GetName().c_str(), + ATTR_NAME_REFERENCE.c_str()); + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index 72709932..9741814d 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -67,6 +67,9 @@ class GraphOptimize { // handle summary node before preRun graph Status HandleSummaryOp(ComputeGraphPtr &compute_graph); + // Identify reference node before optimize subgraph + Status IdentifyReference(ComputeGraphPtr &compute_graph); + void TranFrameOp(ComputeGraphPtr &compute_graph); private: @@ -85,5 +88,5 @@ class GraphOptimize { std::map> summary_output_indexes_ = {}; std::string func_bin_path_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZE_H_ diff --git a/src/ge/graph/optimize/summary_optimize.cc b/src/ge/graph/optimize/summary_optimize.cc index 8b38d602..a8325da3 100644 --- a/src/ge/graph/optimize/summary_optimize.cc +++ b/src/ge/graph/optimize/summary_optimize.cc @@ -80,7 +80,8 @@ Status GraphOptimize::HandleSummaryOp(ComputeGraphPtr &compute_graph) { del_nodes.emplace_back(node_ptr); } } - summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes}); + GE_IF_BOOL_EXEC(!summary_output_indexes.empty(), + summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes})); // add output nodes for summary std::vector> out_nodes_info; diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc index 6a396eef..324129c4 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.cc +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -62,15 +62,16 @@ Status DynamicShapePartitioner::Partition() { } GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str()); - REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes."); + REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes, root graph name:%s.", + root_graph_->GetName().c_str()); if (unknown_shape_nodes_.empty()) { GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str()); REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), - "Failed set dynamic shape partitioned flag on root graph."); + "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); return SUCCESS; } REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true), - "Failed set dynamic shape partitioned flag on root graph."); + "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); DumpGraph("_Before_DSP"); auto status = PartitionImpl(); @@ -107,21 +108,21 @@ void 
DynamicShapePartitioner::PruneUniqueClusters() { } Status DynamicShapePartitioner::BuildPartitionFrame() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { REQUIRE_SUCCESS(cluster->BuildFrame(), "Failed build frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::CombinePartitionFrame() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { REQUIRE_SUCCESS(cluster->CombinePartitionFrame(), "Failed combine frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::BuildPartitionSubgraph() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { REQUIRE_SUCCESS(cluster->BuildPartitionSubgraph(), "Failed build subgraph of cluster[%lu].", cluster->Id()); } return SUCCESS; @@ -134,10 +135,10 @@ std::string DynamicShapePartitioner::DebugString() const { size_t netoutput = 0; std::stringstream ss; ss << "All unknown shape nodes:" << std::endl; - for (auto node : unknown_shape_nodes_) { + for (const auto &node : unknown_shape_nodes_) { ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; } - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { if (cluster->IsUnknownShape()) { unknown++; } else if (cluster->IsKnownShape()) { @@ -150,7 +151,7 @@ std::string DynamicShapePartitioner::DebugString() const { } ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known << ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { ss << " " << cluster->DebugString() << std::endl; } return ss.str(); @@ -158,13 +159,13 @@ std::string DynamicShapePartitioner::DebugString() const { void DynamicShapePartitioner::DumpGraph(const std::string &suffix) { GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); - for (auto sub_graph : root_graph_->GetAllSubgraphs()) { + for (const auto &sub_graph : root_graph_->GetAllSubgraphs()) { GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); } } void DynamicShapePartitioner::ClearResource() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { cluster->Clear(); } node_2_cluster_.clear(); @@ -175,8 +176,7 @@ void DynamicShapePartitioner::ClearResource() { } Status DynamicShapePartitioner::MarkUnknownShapeNodes() { - auto graph = root_graph_; - for (auto &node : graph->GetDirectNode()) { + for (auto &node : root_graph_->GetDirectNode()) { REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node), "Failed collect spread unknown shape nodes %s.", node->GetName().c_str()); } @@ -186,7 +186,7 @@ Status DynamicShapePartitioner::MarkUnknownShapeNodes() { Status DynamicShapePartitioner::InitClusters() { auto graph = root_graph_; size_t rank = 0; - for (const auto node : graph->GetDirectNode()) { + for (const auto &node : graph->GetDirectNode()) { Cluster::Type type = Cluster::DATA; if (node->GetType() == DATA) { type = Cluster::DATA; @@ -208,7 +208,7 @@ Status DynamicShapePartitioner::InitClusters() { cluster->AddInput(node_2_cluster_[parent]); } } - for (const auto node : graph->GetDirectNode()) { + for (const auto &node : graph->GetDirectNode()) { GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), node_2_cluster_[node]->DebugString().c_str()); } return SUCCESS; @@ -220,8 +220,8 @@ Status 
DynamicShapePartitioner::TopologicalSortClusters() { std::queue ready_clusters; std::unordered_map cluster_pending_count; std::unordered_set seen_clusters; - for (auto iter = node_2_cluster_.begin(); iter != node_2_cluster_.end(); iter++) { - auto cluster = iter->second; + for (auto &iter : node_2_cluster_) { + auto cluster = iter.second; if (seen_clusters.count(cluster) != 0) { continue; } @@ -242,7 +242,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { if (cluster->IsKnownShape()) { ordered_cluster_.push_back(cluster); } - for (auto out_cluster : cluster->Outputs()) { + for (const auto &out_cluster : cluster->Outputs()) { if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) { ready_clusters.push(out_cluster); } @@ -273,16 +273,16 @@ static std::string ToString(const std::vector &clusters) { Status DynamicShapePartitioner::MergeClusters() { // Merge unknown shape clusters - for (auto cluster : ordered_cluster_) { - for (auto in_cluster : cluster->Inputs()) { + for (const auto &cluster : ordered_cluster_) { + for (const auto &in_cluster : cluster->Inputs()) { if (!in_cluster->IsUnknownShape()) { continue; } auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), ToString(merged_clusters).c_str()); - for (auto merged_cluster : merged_clusters) { - for (auto node : merged_cluster->Nodes()) { + for (const auto &merged_cluster : merged_clusters) { + for (const auto &node : merged_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -291,7 +291,7 @@ Status DynamicShapePartitioner::MergeClusters() { REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); // Merge known shape clusters - for (auto cluster : ordered_cluster_) { + for (const auto &cluster : ordered_cluster_) { if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { auto in_cluster = *(cluster->Inputs().begin()); in_cluster->Merge(cluster); @@ -299,13 +299,13 @@ Status DynamicShapePartitioner::MergeClusters() { continue; } - for (auto in_cluster : cluster->Inputs()) { + for (const auto &in_cluster : cluster->Inputs()) { if (!in_cluster->IsKnownShape()) { continue; } if (cluster->TryMerge(in_cluster)) { GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id()); - for (auto node : in_cluster->Nodes()) { + for (const auto &node : in_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -333,7 +333,7 @@ Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(out_tensor)) { GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetOutDataAnchor(anchor_index); + auto anchor = node->GetOutDataAnchor(static_cast(anchor_index)); for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { if (peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -349,7 +349,7 @@ Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(in_tensor)) { GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetInDataAnchor(anchor_index); + auto anchor = node->GetInDataAnchor(static_cast(anchor_index)); const auto peer_anchor = anchor->GetPeerOutAnchor(); if 
(peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -453,15 +453,15 @@ std::string Cluster::DebugString() const { } ss << "[" << id_ << "](size:" << nodes_.size() << ")"; ss << "(" << min_ << "," << max_ << ")("; - for (auto cluster : in_clusters_) { + for (const auto &cluster : in_clusters_) { ss << cluster->id_ << ","; } ss << ")->("; - for (auto cluster : out_clusters_) { + for (const auto &cluster : out_clusters_) { ss << cluster->id_ << ","; } ss << ")|"; - for (auto node : nodes_) { + for (const auto &node : nodes_) { ss << (node->GetName() + "|"); } return ss.str(); @@ -507,12 +507,12 @@ void Cluster::Merge(ClusterPtr other) { in_clusters_.erase(other); out_clusters_.erase(other); auto in_clusters = other->in_clusters_; - for (auto cluster : in_clusters) { + for (const auto &cluster : in_clusters) { cluster->RemoveOutput(other); cluster->AddOutput(shared_from_this()); } auto out_clusters = other->out_clusters_; - for (auto cluster : out_clusters) { + for (const auto &cluster : out_clusters) { cluster->RemoveInput(other); cluster->AddInput(shared_from_this()); } @@ -529,7 +529,7 @@ bool Cluster::TryMerge(ClusterPtr other) { while (!forward_reached.empty()) { auto current_cluster = forward_reached.front(); forward_reached.pop(); - for (auto cluster : current_cluster->out_clusters_) { + for (const auto &cluster : current_cluster->out_clusters_) { if (cluster->max_ == max_ && current_cluster != other) { return false; } else if (cluster->min_ < max_) { @@ -557,7 +557,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!forward_reached_queue.empty()) { auto current_cluster = forward_reached_queue.front(); forward_reached_queue.pop(); - for (auto cluster : current_cluster->out_clusters_) { + for (const auto &cluster : current_cluster->out_clusters_) { if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) { forward_reached_clusters.insert(cluster); forward_reached_queue.push(cluster); @@ -567,7 +567,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!backward_reached_queue.empty()) { auto current_cluster = backward_reached_queue.front(); backward_reached_queue.pop(); - for (auto cluster : current_cluster->in_clusters_) { + for (const auto &cluster : current_cluster->in_clusters_) { if (cluster->max_ > other->min_ && cluster->max_ != other->max_ && backward_reached_clusters.count(cluster) == 0) { backward_reached_clusters.insert(cluster); @@ -578,7 +578,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { } } } - for (auto cluster : path_clusters) { + for (const auto &cluster : path_clusters) { Merge(cluster); } return path_clusters; @@ -598,11 +598,11 @@ void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) { }; InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) { - return partition_node_->GetInDataAnchor(inputs_index_[anchor]); + return partition_node_->GetInDataAnchor(static_cast(inputs_index_[anchor])); }; OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) { - return partition_node_->GetOutDataAnchor(outputs_index_[anchor]); + return partition_node_->GetOutDataAnchor(static_cast(outputs_index_[anchor])); }; InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); }; @@ -616,22 +616,25 @@ Status Cluster::BuildFrame() { auto node = nodes_.front(); auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != 
nullptr) { - for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { - auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; - GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor); + REQUIRE_GRAPH_SUCCESS( + GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor), + "Failed remove edge from node %s index %d to node %s index %d.", + peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(peer_out_control_anchor), + in_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(in_control_anchor)); control_inputs_.insert(src_cluster); src_cluster->control_outputs_.insert(peer_out_control_anchor); } } } if (IsData()) { - for (auto anchor : node->GetAllOutDataAnchors()) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { AddFrameOutput(anchor); } } else { - for (auto anchor : node->GetAllInDataAnchors()) { + for (const auto &anchor : node->GetAllInDataAnchors()) { AddFrameInput(anchor); } } @@ -660,7 +663,7 @@ Status Cluster::BuildPartitionFrame() { "Failed set shape flag."); REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node), "Failed remove root graph node."); REQUIRE_GRAPH_SUCCESS(node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph."); - for (auto anchor : node->GetAllInDataAnchors()) { + for (const auto &anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; // Skip overhang input. @@ -674,7 +677,7 @@ Status Cluster::BuildPartitionFrame() { } auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { - for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { if (peer_out_control_anchor == nullptr) { continue; } @@ -689,9 +692,9 @@ Status Cluster::BuildPartitionFrame() { } } } - for (auto anchor : node->GetAllOutDataAnchors()) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { auto peer_in_anchors = anchor->GetPeerInDataAnchors(); - for (auto peer_in_anchor : peer_in_anchors) { + for (const auto &peer_in_anchor : peer_in_anchors) { auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { AddFrameOutput(anchor); @@ -717,7 +720,7 @@ Status Cluster::BuildPartitionFrame() { } Status Cluster::CombinePartitionFrame() { - for (auto anchor : inputs_) { + for (const auto &anchor : inputs_) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()]; auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor); @@ -729,7 +732,7 @@ Status Cluster::CombinePartitionFrame() { src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(), dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx()); } - for (auto src_cluster : control_inputs_) { + for (const auto &src_cluster : control_inputs_) { auto src_anchor = src_cluster->GetFrameOutControlAnchor(); auto dst_anchor = GetFrameInControlAnchor(); REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "Failed add edge from %s:%d to %s:%d.", @@ -774,8 +777,8 @@ Status 
Cluster::BuildPartitionSubgraph() { REQUIRE_NOT_NULL(net_output_node, "Failed add netoutput node to subgraph."); REQUIRE_GRAPH_SUCCESS(net_output_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of netoutput node."); parent_node_index = 0; - for (auto anchor : outputs_) { - auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()); + for (const auto &anchor : outputs_) { + auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(anchor->GetIdx())); REQUIRE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), "Failed set parent_node_index on subgraph netoutput's input."); REQUIRE_GRAPH_SUCCESS(net_output_op->UpdateInputDesc(parent_node_index, output_desc), @@ -786,7 +789,7 @@ Status Cluster::BuildPartitionSubgraph() { anchor->GetIdx()); parent_node_index++; } - for (auto anchor : control_outputs_) { + for (const auto &anchor : control_outputs_) { REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()), "Faile add control edge from %s:%d to netoutput node.", anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx()); diff --git a/src/ge/graph/partition/engine_place.cc b/src/ge/graph/partition/engine_place.cc index 74da0326..2d1a7f13 100644 --- a/src/ge/graph/partition/engine_place.cc +++ b/src/ge/graph/partition/engine_place.cc @@ -38,6 +38,7 @@ Status EnginePlacer::Run() { return FAILED; } // Assign engine for each node in the graph + instance_ptr->DNNEngineManagerObj().InitPerformanceStaistic(); for (const auto &node_ptr : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node_ptr); GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); @@ -60,12 +61,15 @@ Status EnginePlacer::Run() { return FAILED; } } + for (auto &it : instance_ptr->DNNEngineManagerObj().GetCheckSupportCost()) { + GEEVENT("The time cost of %s::CheckSupported is [%lu] microseconds.", it.first.c_str(), it.second); + } GELOGI("Engine placer ends."); return SUCCESS; } Status EnginePlacer::AssignEngineAndLog(ge::ConstNodePtr node_ptr, const std::string &engine_name) { - if (node_ptr == nullptr || node_ptr->GetOpDesc() == nullptr) { + if ((node_ptr == nullptr) || (node_ptr->GetOpDesc() == nullptr)) { GELOGE(FAILED, "node_ptr is null."); return FAILED; } diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index 50cd7e81..907d672d 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -25,6 +25,7 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -231,33 +232,33 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co ComputeGraphPtr new_sub_graph = MakeShared<ComputeGraph>(original_compute_graph->GetName()); GE_CHECK_NOTNULL(new_sub_graph); output_merged_compute_graph = new_sub_graph; - GE_TIMESTAMP_START(MergeGraphRemoveNode); + GE_TIMESTAMP_START(MergeSubGraphRemoveNode); if (RemoveNodeAndEdgeBetweenEndPld(output_merged_compute_graph, sub_graph_list) != ge::SUCCESS) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: merging sub-graphs failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphRemoveNode, "GraphPartitioner::MergeGraphRemoveNodeAndEdge"); - GE_TIMESTAMP_START(MergeGraphTopologicalSorting); + GE_TIMESTAMP_END(MergeSubGraphRemoveNode, 
"GraphPartitioner::MergeGraphRemoveNodeAndEdge"); + GE_TIMESTAMP_START(MergeSubGraphTopologicalSorting); Status ret = output_merged_compute_graph->TopologicalSorting(); if (ret != SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: output_merged_compute_graph->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); + GE_TIMESTAMP_END(MergeSubGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); // flush all nodes' engine of merged graph - GE_TIMESTAMP_START(MergeGraphEnginePlacerRun); + GE_TIMESTAMP_START(MergeSubGraphEnginePlacerRun); graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); if (graph_info_.engine_placer_.Run() != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); + GE_TIMESTAMP_END(MergeSubGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); GELOGI("Graph merge ends."); return SUCCESS; } Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &pld_op_desc) { - if (dst_node == nullptr || pld_op_desc == nullptr || dst_node->GetOpDesc() == nullptr) { + if ((dst_node == nullptr) || (pld_op_desc == nullptr) || (dst_node->GetOpDesc() == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -275,7 +276,7 @@ Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_ } Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &end_op_desc) { - if (src_node == nullptr || end_op_desc == nullptr || src_node->GetOpDesc() == nullptr) { + if ((src_node == nullptr) || (end_op_desc == nullptr) || (src_node->GetOpDesc() == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -296,9 +297,9 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr const AnchorPtr &peer_in_anchor, const ge::ComputeGraphPtr &pld_graph, const ge::ComputeGraphPtr &end_graph) { - GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(peer_in_anchor); GE_CHECK_NOTNULL(pld_graph); + GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(end_graph); const auto &src_node = out_anchor->GetOwnerNode(); const auto &dst_node = peer_in_anchor->GetOwnerNode(); @@ -313,6 +314,7 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetInt peerIndex failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), GELOGW("SetStr parentOpType failed");) + GE_IF_BOOL_EXEC(!end_op_desc->SetExtAttr("parentNode", dst_node), GELOGW("SetEndExtAttr parentNode failed");) // replace input_desc of end with owner node's desc int output_index = ge::AnchorUtils::GetIdx(out_anchor); bool is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); @@ -361,6 +363,7 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetStr parentId failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "anchorIndex", AnchorUtils::GetIdx(out_anchor)), GELOGW("SetInt anchorIndex failed");) + GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node), GELOGW("SetPldExtAttr parentNode failed");) // do not care over flow graph_info_.num_of_pld_end_++; // replace output_desc of pld with input node's output desc @@ -395,14 +398,14 @@ graphStatus 
ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr return FAILED; } graph_info_.index_2_end_[graph_info_.num_of_pld_end_] = new_end_node; - graph_info_.end_2_pld_[new_end_node] = new_pld_node; graph_info_.pld_2_end_[new_pld_node] = new_end_node; + graph_info_.end_2_pld_[new_end_node] = new_pld_node; return SUCCESS; } Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_node, ge::ComputeGraphPtr src_graph, ge::ComputeGraphPtr dst_graph) { - if (input_node == nullptr || src_graph == nullptr || dst_graph == nullptr) { + if ((input_node == nullptr) || (src_graph == nullptr) || (dst_graph == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -442,7 +445,7 @@ Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_no Status ge::GraphPartitioner::PutInputNodesInSubGraph(const ge::ComputeGraphPtr &src_graph, const ge::ComputeGraphPtr &dst_graph) { - if (src_graph == nullptr || dst_graph == nullptr) { + if ((src_graph == nullptr) || (dst_graph == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -849,34 +852,34 @@ Status ge::GraphPartitioner::PartitionSubGraph(ge::ComputeGraphPtr compute_graph GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: subGraphPtr->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_START(GraphPartitionInitialize); + GE_TIMESTAMP_START(PartitionSubGraphInitialize); if (Initialize(compute_graph) != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: initialize failed"); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionInitialize, "GraphPartitioner::PartitionInitialize"); - GE_TIMESTAMP_START(GraphPartitionMarkClusters); + GE_TIMESTAMP_END(PartitionSubGraphInitialize, "GraphPartitioner::PartitionInitialize"); + GE_TIMESTAMP_START(PartitionSubGraphMarkClusters); MarkClusters(); - GE_TIMESTAMP_END(GraphPartitionMarkClusters, "GraphPartitioner::PartitionMarkClusters"); - GE_TIMESTAMP_START(GraphPartitionSplitSubGraphs); + GE_TIMESTAMP_END(PartitionSubGraphMarkClusters, "GraphPartitioner::PartitionMarkClusters"); + GE_TIMESTAMP_START(PartitionSubGraphSplitSubGraphs); if (SplitSubGraphs(compute_graph) != SUCCESS) { GELOGE(FAILED, "[GraphPartitioner]: SplitSubGraphs failed"); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); - GE_TIMESTAMP_START(GraphPartitionSortSubGraphs); + GE_TIMESTAMP_END(PartitionSubGraphSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); + GE_TIMESTAMP_START(PartitionSubGraphSortSubGraphs); if (SortSubGraphs(compute_graph) != ge::SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "Graph Partition SortSubGraphs failed."); return ge::FAILED; } - GE_TIMESTAMP_END(GraphPartitionSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); - GE_TIMESTAMP_START(GraphPartitionAddPartitionsToGraphNode); + GE_TIMESTAMP_END(PartitionSubGraphSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); + GE_TIMESTAMP_START(PartitionSubGraphAddPartitionsToGraphNode); vector output_subgraphs; if (AddPartitionsToGraphNode(output_subgraphs, compute_graph) != ge::SUCCESS) { GELOGE(GE_GRAPH_EMPTY_PARTITION, "Graph Partition AddPartitionsToGraphNode failed."); return ge::FAILED; } - GE_TIMESTAMP_END(GraphPartitionAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); + GE_TIMESTAMP_END(PartitionSubGraphAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); GELOGI("Graph Partition ends. 
Adding partitions to SubGraphInfo, got %zu sub graphs", output_subgraphs.size()); graph_info_.mode_ = kMerging; // do not care over flow @@ -923,7 +926,7 @@ Status ge::GraphPartitioner::AddPlaceHolderEnd(const AnchorPtr &out_anchor, cons Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_graph) { uint32_t rank = kRankOne; // rank 0 for data graph ComputeGraphPtr new_input_nodes_sub_graph = MakeShared("inputNodeGraph"); - if (new_input_nodes_sub_graph == nullptr || compute_graph == nullptr) { + if ((new_input_nodes_sub_graph == nullptr) || (compute_graph == nullptr)) { GELOGE(FAILED, "[GraphPartitioner]: new_input_nodes_sub_graph or compute_graph is null."); return FAILED; } @@ -965,7 +968,7 @@ Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_gr } AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, const NodePtr &end_node) { - if (src_anchor == nullptr || end_node == nullptr) { + if ((src_anchor == nullptr) || (end_node == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -979,7 +982,7 @@ AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, cons } AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const AnchorPtr &dst_anchor) { - if (pld_node == nullptr || dst_anchor == nullptr) { + if ((pld_node == nullptr) || (dst_anchor == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -992,16 +995,16 @@ AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const A return pld_out_anchor; } -void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &sub_graph_info) { - if (sub_graph_info == nullptr) { +void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &subgraph_info) { + if (subgraph_info == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return; } - auto sub_graph = sub_graph_info->GetSubGraph(); - GE_CHECK_NOTNULL_JUST_RETURN(sub_graph); + auto subgraph = subgraph_info->GetSubGraph(); + GE_CHECK_NOTNULL_JUST_RETURN(subgraph); NodetoNodeMap end_map; NodetoNodeMap pld_map; - for (const auto &node : sub_graph->GetDirectNode()) { + for (const auto &node : subgraph->GetDirectNode()) { if (node->GetType() == kEndType) { end_map[node] = graph_info_.end_2_pld_.at(node); } @@ -1009,8 +1012,8 @@ void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPt pld_map[node] = graph_info_.pld_2_end_.at(node); } } - sub_graph_info->SetEnd2PldMap(end_map); - sub_graph_info->SetPld2EndMap(pld_map); + subgraph_info->SetEnd2PldMap(end_map); + subgraph_info->SetPld2EndMap(pld_map); } const Graph2SubGraphInfoList &ge::GraphPartitioner::GetSubGraphMap() { return graph_2_subgraph_list_; } diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index 7d9b8dec..ae69fd93 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -22,16 +22,12 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "init/gelib.h" -namespace { -bool is_loop_graph = false; -} namespace ge { namespace { bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { @@ -44,7 +40,6 @@ bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { } // namespace Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { - 
GE_TIMESTAMP_START(AtomicAddrCleanPass); if (graph == nullptr) { GELOGE(PARAM_INVALID, "param [graph] must not be null."); return PARAM_INVALID; @@ -71,10 +66,10 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } atomic_node_vec.push_back(node); } - if (!is_loop_graph && node->GetType() == LOOPCOND) { + if (!is_loop_graph_ && node->GetType() == LOOPCOND) { // there is loop in this graph GELOGD("There is no loop node. It will insert clean node follow atomic node."); - is_loop_graph = true; + is_loop_graph_ = true; } } if (atomic_node_vec.empty()) { @@ -83,7 +78,7 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } // 2.Insert clean node and link to atomic node Status ret; - if (is_loop_graph) { + if (is_loop_graph_) { ret = HandleLoopGraph(graph, atomic_node_vec); if (ret != SUCCESS) { return ret; @@ -95,7 +90,6 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } } GELOGD("AtomicAddrCleanPass end."); - GE_TIMESTAMP_END(AtomicAddrCleanPass, "GraphManager::AtomicAddrCleanPass"); return SUCCESS; } @@ -172,12 +166,14 @@ NodePtr AtomicAddrCleanPass::InsertAtomicAddrCleanNode(ComputeGraphPtr &graph) { if (!session_graph_id.empty()) { (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); } + string node_name = op_desc->GetName(); // Only flush subgraph name - string node_name = (graph->GetParentGraph() != nullptr) - ? (graph->GetName() + "_" + op_desc->GetName() + session_graph_id) - : (op_desc->GetName() + session_graph_id); + if (graph->GetParentGraph() != nullptr) { + node_name = graph->GetName() + "_" + node_name; + } - op_desc->SetName(node_name); + string name = node_name + session_graph_id; + op_desc->SetName(name); GELOGI("Create cleanAddr op:%s.", op_desc->GetName().c_str()); // To avoid same name between graphs, set session graph id to this node NodePtr clean_addr_node = graph->AddNodeFront(op_desc); @@ -203,7 +199,7 @@ Status AtomicAddrCleanPass::LinkToAtomicNode(const NodePtr &atomic_node, NodePtr } GELOGD("Graph add cleanAddrNode op out ctrl edge, dst node: %s.", atomic_node->GetName().c_str()); std::string stream_label; - if (is_loop_graph && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + if (is_loop_graph_ && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { if (!AttrUtils::SetStr(atomic_clean_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { GELOGW("LinkToAtomicNode: SetStr failed"); return INTERNAL_ERROR; @@ -262,7 +258,7 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { return true; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status AtomicAddrCleanPass::ClearStatus() { diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.h b/src/ge/graph/passes/atomic_addr_clean_pass.h index d2d8f2ce..3640beef 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.h +++ b/src/ge/graph/passes/atomic_addr_clean_pass.h @@ -75,6 +75,7 @@ class AtomicAddrCleanPass : public GraphPass { bool IsAtomicOp(const NodePtr &node); vector hcom_node_vec_; + bool is_loop_graph_ = false; }; } // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.cc b/src/ge/graph/passes/attach_stream_label_pass.cc new file mode 100644 index 00000000..0c342d8c --- /dev/null +++ b/src/ge/graph/passes/attach_stream_label_pass.cc @@ -0,0 +1,319 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + 
* you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/attach_stream_label_pass.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" + +namespace ge { +Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { + GELOGD("AttachStreamLabelPass Enter."); + + FindNodes(graph); + for (const auto &node : need_label_nodes_) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { + GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); + } + } + GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); + + GELOGD("AttachStreamLabelPass Leave."); + return SUCCESS; +} + +/// +/// @brief Clear Status, used for subgraph pass +/// @return +/// +Status AttachStreamLabelPass::ClearStatus() { + stream_switch_nodes_.clear(); + need_label_nodes_.clear(); + enter_nodes_.clear(); + branch_head_nodes_.clear(); + return SUCCESS; +} + +/// +/// @brief Find StreamSwitch / StreamMerge / Enter node +/// @param [in] graph +/// @return void +/// +void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { + for (const NodePtr &node : graph->GetDirectNode()) { + const std::string &type = node->GetType(); + if (type == STREAMSWITCH) { + stream_switch_nodes_.emplace_back(node); + } else if (type == STREAMMERGE) { + if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + need_label_nodes_.emplace_back(node); + } + } else if ((type == ENTER) || (type == REFENTER)) { + enter_nodes_.emplace_back(node); + } + } + + for (const auto &node : stream_switch_nodes_) { + for (const auto &out_ctrl_node : node->GetOutControlNodes()) { + MarkHeadNodes(out_ctrl_node, node); + } + need_label_nodes_.emplace_back(node); + } +} + +/// +/// @brief Mark node as head_node of stream_switch +/// @param [in] node +/// @param [in] stream_switch +/// @return void +/// +void AttachStreamLabelPass::MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch) { + static const std::set bypass_type_set = {IDENTITY, IDENTITYN, CAST, TRANSDATA, + TRANSPOSE, TRANSPOSED, RESHAPE}; + std::stack nodes; + nodes.push(node); + std::set visited; + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited.count(cur_node) > 0) { + continue; + } + GELOGD("branch_head_node %s of stream_switch %s.", cur_node->GetName().c_str(), stream_switch->GetName().c_str()); + branch_head_nodes_[cur_node] = stream_switch; + if (bypass_type_set.count(cur_node->GetType()) > 0) { + for (const auto &out_node : cur_node->GetOutAllNodes()) { + nodes.push(out_node); + } + } + visited.insert(cur_node); + } +} + +/// +/// @brief update cond branch +/// @param [in] node +/// @return Status +/// +Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { + std::string stream_label; + std::unordered_set branch_nodes; + std::unordered_set visited; + std::stack nodes; + nodes.push(node); + + static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; + bool merge_flag = false; + 
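// (Editorial note) merge_flag / exit_flag / net_output_flag record whether this walk reaches a
+ // StreamMerge, an Exit/RefExit, or runs through to NETOUTPUT; when (merge_flag || exit_flag) &&
+ // net_output_flag holds after the walk, the branch rejoins the main stream and label attaching
+ // stops early (see the attach_flag check below).
+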
bool exit_flag = false; + bool net_output_flag = false; + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited.count(cur_node) > 0) { + continue; + } + if (AttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) { + GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); + return FAILED; + } + + const std::string &type = cur_node->GetType(); + for (const auto &out_node : cur_node->GetOutAllNodes()) { + const std::string &out_type = out_node->GetType(); + bool stop_flag = (end_type_set.count(out_type) > 0) || + ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || + (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); + if (!stop_flag) { + nodes.push(out_node); + GELOGD("Insert branch node %s.", out_node->GetName().c_str()); + branch_nodes.insert(out_node); + } + } + visited.insert(cur_node); + } + + if (node->GetType() == STREAMSWITCH) { + GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); + } + + bool attach_flag = (merge_flag || exit_flag) && net_output_flag; + if (attach_flag) { + GELOGI("No need to keep on attaching label."); + return SUCCESS; + } + + for (const NodePtr &tmp_node : branch_nodes) { + GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); + } + + return SUCCESS; +} + +/// +/// @brief attach flag +/// @param [in] node +/// @param [out] stream_label +/// @param [out] merge_flag +/// @param [out] exit_flag +/// @param [out] net_output_flag +/// @return Status +/// +Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, + bool &exit_flag, bool &net_output_flag) { + const std::string &type = node->GetType(); + if (type == STREAMSWITCH) { + if (node->GetInDataNodes().empty()) { + GELOGE(INTERNAL_ERROR, "node %s has no input_data_node.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + stream_label = node->GetInDataNodes().at(0)->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + bool value = false; + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); + stream_label += (value ? 
"_t" : "_f"); + } else if (type == STREAMMERGE) { + stream_label = node->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + merge_flag = true; + } else if ((type == EXIT) || (type == REFEXIT)) { + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + exit_flag = true; + } else if (type == NETOUTPUT) { + net_output_flag = true; + } + + return SUCCESS; +} + +/// +/// @brief Update stream_label start with enter nodes +/// @return Status +/// +Status AttachStreamLabelPass::UpdateEnterNode() { + std::unordered_map> enter_active_map; + for (const auto &enter_node : enter_nodes_) { + for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { + if (out_ctrl_node->GetType() != STREAMACTIVE) { + continue; + } + auto iter = enter_active_map.find(out_ctrl_node); + if (iter == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + iter->second.emplace_back(enter_node); + } + } + } + + for (const auto &pair : enter_active_map) { + if (SetEnterLabel(pair.second, pair.first) != SUCCESS) { + GELOGE(FAILED, "Set stream_label for enter_nodes failed."); + return FAILED; + } + + NodePtr active_node = pair.first; + GE_CHECK_NOTNULL(active_node); + std::vector active_label_list; + if (!AttrUtils::GetListStr(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || + (active_label_list.size() != 1) || active_label_list[0].empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST failed, node: %s.", active_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + std::stack enter_nodes; + for (const auto &enter_node : pair.second) { + enter_nodes.emplace(enter_node); + } + if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + GELOGE(FAILED, "Update stream_label for loop_branch failed."); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Set stream_label for enter_nodes +/// @param [in] enter_nodes +/// @param [in] active_node +/// @return Status +/// +Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node) { + std::string stream_label; + GE_CHECK_NOTNULL(active_node); + (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); + + bool same_flag = true; + for (const auto &enter_node : enter_nodes) { + std::string tmp_label; + (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, tmp_label); + if (tmp_label.empty() || (stream_label == tmp_label)) { + continue; + } + same_flag = false; + break; + } + + if (stream_label.empty()) { + if (same_flag) { + stream_label = active_node->GetName(); + } else { + GELOGW("stream_label of enter_active is empty while stream_label of some enter_node is not."); + return SUCCESS; + } + } + + for (const auto &enter_node : enter_nodes) { + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + } + GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); + return SUCCESS; +} + +/// +/// @brief Update stream_label for loop_branch +/// @param [in] enter_nodes +/// @param [in] stream_label +/// @return Status +/// +Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, + const std::string &stream_label) { + std::stack nodes(enter_nodes); + NodePtr cur_node = nullptr; + while (!nodes.empty()) { + cur_node = nodes.top(); + nodes.pop(); + for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { + OpDescPtr 
out_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_desc); + std::string out_type = out_desc->GetType(); + if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER)) { + continue; + } + GELOGD("Attach label %s to node: %s.", stream_label.c_str(), out_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "Set stream label failed."); + nodes.push(out_node); + } + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.h b/src/ge/graph/passes/attach_stream_label_pass.h new file mode 100644 index 00000000..743ce36e --- /dev/null +++ b/src/ge/graph/passes/attach_stream_label_pass.h @@ -0,0 +1,97 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ +#define GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ + +#include +#include "inc/graph_pass.h" + +namespace ge { +class AttachStreamLabelPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return + /// + Status ClearStatus() override; + + private: + /// + /// @brief Find StreamSwitch / StreamMerge / Enter node + /// @param [in] graph + /// @return void + /// + void FindNodes(const ComputeGraphPtr &graph); + + /// + /// @brief Mark node as head_node of stream_switch + /// @param [in] node + /// @param [in] stream_switch + /// @return void + /// + void MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch); + + /// + /// @brief update cond branch + /// @param [in] node + /// @return Status + /// + Status UpdateCondBranch(const NodePtr &node); + + /// + /// @brief attach flag + /// @param [in] node + /// @param [out] stream_label + /// @param [out] merge_flag + /// @param [out] exit_flag + /// @param [out] net_output_flag + /// @return Status + /// + static Status AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, + bool &net_output_flag); + + /// + /// @brief Update stream_label for loop_branch + /// @param [in] enter_nodes + /// @param [in] stream_label + /// @return Status + /// + static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); + + /// + /// @brief Update stream_label start with enter nodes + /// @return Status + /// + Status UpdateEnterNode(); + + /// + /// @brief Set stream_label for enter_nodes + /// @param [in] enter_nodes + /// @param [in] active_node + /// @return Status + /// + static Status SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node); + + std::vector stream_switch_nodes_; + std::vector need_label_nodes_; + std::vector enter_nodes_; + std::unordered_map branch_head_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index d18c4b4e..f7ff941c 100644 --- 
a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -69,7 +69,6 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op auto begin_out_desc = begin_op_desc->MutableOutputDesc(0); DataType begin_out_datatype = begin_out_desc->GetDataType(); - if (begin_out_datatype == end_out_datatype && (begin_out_datatype == DT_FLOAT16 || begin_out_datatype == DT_FLOAT)) { type = begin_out_datatype; return true; diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.cc b/src/ge/graph/passes/common_subexpression_elimination_pass.cc index a52535c1..18f2e857 100644 --- a/src/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/src/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -83,6 +83,7 @@ Status CommonSubexpressionEliminationPass::Run(ComputeGraphPtr graph) { continue; } auto key = GetCseKey(node); + GELOGD("The node %s cse key %s", node->GetName().c_str(), key.c_str()); auto iter = keys_to_node.find(key); if (iter == keys_to_node.end()) { keys_to_node[key] = node; diff --git a/src/ge/graph/passes/compile_nodes_pass.cc b/src/ge/graph/passes/compile_nodes_pass.cc index def7655e..330569a2 100644 --- a/src/ge/graph/passes/compile_nodes_pass.cc +++ b/src/ge/graph/passes/compile_nodes_pass.cc @@ -23,6 +23,7 @@ #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/op_desc.h" using domi::ImplyType; @@ -78,7 +79,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { return result; } GELOGI("[CompileNodesPass]: Optimize success."); - GE_TIMESTAMP_END(CompileNodesPass, "GraphManager::CompileNodesPass"); + GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); return GRAPH_SUCCESS; } @@ -101,7 +102,6 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: } } OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); - if (kernel_info == nullptr) { GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node->GetName().c_str()); return ge::GE_GRAPH_PARAM_NULLPTR; diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc index 651cf98b..2f3f9333 100644 --- a/src/ge/graph/passes/cond_pass.cc +++ b/src/ge/graph/passes/cond_pass.cc @@ -226,7 +226,7 @@ Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnc return FAILED; } - if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", cast_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; @@ -271,7 +271,7 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr } AddRePassNode(new_node); - if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), new_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; diff --git a/src/ge/graph/passes/cond_remove_pass.cc 
b/src/ge/graph/passes/cond_remove_pass.cc
index 8bc34fbc..1650be92 100644
--- a/src/ge/graph/passes/cond_remove_pass.cc
+++ b/src/ge/graph/passes/cond_remove_pass.cc
@@ -225,41 +225,40 @@ bool CondRemovePass::CheckIfCondConstInput(const OutDataAnchorPtr &cond_out_anch
 Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch) {
   // Add compute graph to new node
-  const auto &input_anchors = node->GetAllInAnchors();
-  const auto &output_anchors = node->GetAllOutAnchors();
+  const auto &input_desc_size = node->GetOpDesc()->GetInputsSize();
+  const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize();
   // Create subgraph opdesc & node
   auto partitioncall_opdesc =
-      CreateSubgraphOpDesc(save_branch->GetName(), input_anchors.size() - kConditionIndexNum, output_anchors.size());
+      CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
   auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc);
   // Link node's peerout anchors to new node's inanchors
-  for (const auto &input_anchor : input_anchors) {
+  for (const auto &input_anchor : node->GetAllInAnchors()) {
     for (const auto &peerout_anchor : input_anchor->GetPeerAnchors()) {
       if (GraphUtils::AddEdge(peerout_anchor, partitioncall_node->GetInAnchor(
                                                 input_anchor->GetIdx() - kConditionIndexNum)) != ge::GRAPH_SUCCESS) {
         GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d",
                peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(),
-               partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_anchors.size(),
-               output_anchors.size());
+               partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_desc_size, output_desc_size);
         return FAILED;
       }
     }
   }
   // Remove If / Case anchor and peer in anchor
   // Link new node's out anchors to node's peer inanchors
-  for (const auto &output_anchor : output_anchors) {
+  for (const auto &output_anchor : node->GetAllOutAnchors()) {
     for (const auto &peerin_anchor : output_anchor->GetPeerAnchors()) {
       if (GraphUtils::RemoveEdge(node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) {
         GELOGE(FAILED, "Remove edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d",
                node->GetName().c_str(), output_anchor->GetIdx(), peerin_anchor->GetOwnerNode()->GetName().c_str(),
-               peerin_anchor->GetIdx(), input_anchors.size(), output_anchors.size());
+               peerin_anchor->GetIdx(), input_desc_size, output_desc_size);
         return FAILED;
       }
       if (GraphUtils::AddEdge(partitioncall_node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) !=
           ge::GRAPH_SUCCESS) {
         GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d",
                partitioncall_node->GetName().c_str(), output_anchor->GetIdx(),
-               peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_anchors.size(),
-               output_anchors.size());
+               peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_desc_size,
+               output_desc_size);
         return FAILED;
       }
     }
diff --git a/src/ge/graph/passes/constant_folding_pass.cc b/src/ge/graph/passes/constant_folding_pass.cc
index 3ac7feb6..80bf7867 100644
--- a/src/ge/graph/passes/constant_folding_pass.cc
+++ b/src/ge/graph/passes/constant_folding_pass.cc
@@ -29,6 +29,18 @@
 #include "inc/kernel.h"

 namespace ge {
+const int64_t kStartCallNum = 1;
+
+const std::unordered_map<std::string, std::pair<uint64_t, uint64_t>>
+  &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const {
+  return statistic_of_ge_constant_folding_;
+}
+
+const std::unordered_map<std::string, std::pair<uint64_t, uint64_t>>
+  &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const {
+  return statistic_of_op_constant_folding_;
+}
+
 Status ConstantFoldingPass::Run(ge::NodePtr &node) {
   GE_CHECK_NOTNULL(node);
   GELOGD("Begin to run constant folding on node %s", node->GetName().c_str());
@@ -50,6 +62,8 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) {
   auto inputs = OpDescUtils::GetInputData(input_nodes);
   vector<GeTensorPtr> outputs;
+  // Statistic of ge constant folding kernel
+  uint64_t start_time = GetCurrentTimestap();
   auto ret = RunOpKernel(node, inputs, outputs);
   if (ret != SUCCESS) {
     auto op_kernel = folding_pass::GetKernelByType(node);
@@ -59,7 +73,18 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) {
       return SUCCESS;
     }
+    // Statistic of op and fe constant folding kernel
+    start_time = GetCurrentTimestap();
     ret = op_kernel->Compute(node_desc, inputs, outputs);
+    uint64_t cost_time = GetCurrentTimestap() - start_time;
+    if (statistic_of_ge_constant_folding_.find(node->GetType()) != statistic_of_ge_constant_folding_.end()) {
+      uint64_t &cnt = statistic_of_ge_constant_folding_[node->GetType()].first;
+      uint64_t &cur_cost_time = statistic_of_ge_constant_folding_[node->GetType()].second;
+      cnt++;
+      cur_cost_time += cost_time;
+    } else {
+      statistic_of_ge_constant_folding_[node->GetType()] = std::pair<uint64_t, uint64_t>(kStartCallNum, cost_time);
+    }
     if (ret != SUCCESS) {
       if (ret == NOT_CHANGED) {
         GELOGD("Node %s type %s, compute terminates and exits the constant folding.", node->GetName().c_str(),
@@ -70,6 +95,16 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) {
       return ret;
     }
     GELOGI("Node %s type %s, constant folding compute success.", node->GetName().c_str(), node->GetType().c_str());
+  } else {
+    if (statistic_of_op_constant_folding_.find(node->GetType()) != statistic_of_op_constant_folding_.end()) {
+      uint64_t &cnt = statistic_of_op_constant_folding_[node->GetType()].first;
+      uint64_t &cost_time = statistic_of_op_constant_folding_[node->GetType()].second;
+      cnt++;
+      cost_time += GetCurrentTimestap() - start_time;
+    } else {
+      statistic_of_op_constant_folding_[node->GetType()] =
+        std::pair<uint64_t, uint64_t>(kStartCallNum, GetCurrentTimestap() - start_time);
+    }
   }

   if (outputs.empty()) {
diff --git a/src/ge/graph/passes/constant_folding_pass.h b/src/ge/graph/passes/constant_folding_pass.h
index 1dcbcdc3..683b66f1 100644
--- a/src/ge/graph/passes/constant_folding_pass.h
+++ b/src/ge/graph/passes/constant_folding_pass.h
@@ -26,6 +26,12 @@ namespace ge {
 class ConstantFoldingPass : public FoldingPass {
  public:
   Status Run(ge::NodePtr &node) override;
+  const std::unordered_map<std::string, std::pair<uint64_t, uint64_t>> &GetGeConstantFoldingPerfStatistic() const;
+  const std::unordered_map<std::string, std::pair<uint64_t, uint64_t>> &GetOpConstantFoldingPerfStatistic() const;
+
+ private:
+  std::unordered_map<std::string, std::pair<uint64_t, uint64_t>> statistic_of_op_constant_folding_;
+  std::unordered_map<std::string, std::pair<uint64_t, uint64_t>> statistic_of_ge_constant_folding_;
 };
 } // namespace ge
diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc
index 77fcbd69..0c00d553 100644
--- a/src/ge/graph/passes/control_trigger_pass.cc
+++ b/src/ge/graph/passes/control_trigger_pass.cc
@@ -15,16 +15,9 @@
  */

 #include "graph/passes/control_trigger_pass.h"
-
 #include <stack>
-
 #include "common/ge/ge_util.h"
-#include "framework/common/debug/ge_log.h"
-#include "framework/common/debug/log.h"
-#include "framework/common/ge_inner_error_codes.h"
-#include "framework/common/types.h"
 #include "graph/common/omg_util.h"
-#include "graph/debug/ge_attr_define.h"
 #include "graph/utils/type_utils.h"
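A note on the constant-folding statistics introduced above: both maps follow one count-and-accumulate pattern, keyed by op type, with .first holding the call count and .second the cumulative cost measured from GetCurrentTimestap() deltas. A minimal self-contained sketch of the same bookkeeping follows; RecordFoldingCost, the demo op types, and the timings are illustrative stand-ins, not part of GE.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Same layout as statistic_of_ge_constant_folding_ above:
// op type -> (call count, cumulative cost).
using PerfStatistic = std::unordered_map<std::string, std::pair<uint64_t, uint64_t>>;

void RecordFoldingCost(PerfStatistic &stats, const std::string &op_type, uint64_t cost_time) {
  auto iter = stats.find(op_type);
  if (iter == stats.end()) {
    // First call for this op type seeds the counter, as kStartCallNum does above.
    stats.emplace(op_type, std::make_pair(uint64_t{1}, cost_time));
  } else {
    iter->second.first++;              // call count
    iter->second.second += cost_time;  // cumulative cost
  }
}

int main() {
  PerfStatistic stats;
  RecordFoldingCost(stats, "Add", 120);
  RecordFoldingCost(stats, "Add", 80);
  RecordFoldingCost(stats, "Mul", 40);
  for (const auto &item : stats) {
    std::cout << item.first << ": " << item.second.first << " calls, " << item.second.second << " total cost\n";
  }
  return 0;
}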
namespace ge { @@ -444,7 +437,7 @@ Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) { return SUCCESS; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status ControlTriggerPass::ClearStatus() { diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index 5325f56e..a9b3484b 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -28,6 +28,7 @@ namespace { const int32_t kAnchorSize = 1; const int kAnchorNum = 0; +const char *const kInputMutable = "_input_mutable"; } // namespace namespace ge { Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -35,7 +36,16 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - if (!NeedInsertMemcpyOp(op_desc)) { + + bool node_input_mutable = false; + if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { + continue; + } + + GE_IF_BOOL_EXEC(!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable), + GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); + return FAILED); + if (!node_input_mutable) { continue; } @@ -53,7 +63,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { NodePtr src_node = src_out_anchor->GetOwnerNode(); std::string src_type = src_node->GetType(); bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA) || (src_type == CONSTANT); - if (check_src_type && node->GetType() == HCOMALLREDUCE) { + if (check_src_type) { Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to modify the connection."); @@ -136,16 +146,6 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { } /// -/// @brief Check hcom op -/// @param [in] ge::ConstOpDescPtr op_desc -/// @return bool -/// -bool HcclMemcpyPass::NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const { - return (op_desc->GetType() == HCOMALLGATHER || op_desc->GetType() == HCOMALLREDUCE || - op_desc->GetType() == HCOMREDUCESCATTER); -} - -/// /// @brief Modify edge connection /// @param [in] ComputeGraphPtr graph /// @param [in] OutDataAnchorPtr src_out_anchor @@ -182,7 +182,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const return SUCCESS; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status HcclMemcpyPass::ClearStatus() { diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h index 9de96fbf..13863bd6 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.h +++ b/src/ge/graph/passes/hccl_memcpy_pass.h @@ -34,8 +34,6 @@ class HcclMemcpyPass : public GraphPass { std::string CheckDuplicateName(const std::string &node_name); - bool NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const; - Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor); diff --git a/src/ge/graph/passes/identify_reference_pass.cc b/src/ge/graph/passes/identify_reference_pass.cc deleted file mode 100644 index b4131287..00000000 --- a/src/ge/graph/passes/identify_reference_pass.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, 
Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/identify_reference_pass.h" - -#include -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -Status IdentifyReferencePass::Run(NodePtr &node) { - if (node == nullptr) { - GELOGE(PARAM_INVALID, "param [node] must not be null."); - return PARAM_INVALID; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(PARAM_INVALID, "OpDesc of param [node] must not be null."); - return PARAM_INVALID; - } - - auto input_names = op_desc->GetAllInputNames(); - auto outputs = op_desc->GetAllOutputName(); - for (auto &output : outputs) { - for (auto &input_name : input_names) { - if (input_name == output.first) { - bool is_ref = true; - if (AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref)) { - GELOGI("param [node] %s is reference node, set attribute %s to be true.", - node->GetName().c_str(), ATTR_NAME_REFERENCE.c_str()); - return SUCCESS; - } - } - } - } - - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/infershape_pass.cc b/src/ge/graph/passes/infershape_pass.cc index 18767cea..8b44d31b 100644 --- a/src/ge/graph/passes/infershape_pass.cc +++ b/src/ge/graph/passes/infershape_pass.cc @@ -15,7 +15,7 @@ */ #include "graph/passes/infershape_pass.h" - +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/shape_refiner.h" @@ -24,6 +24,8 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E35003", {"opname", "err_msg"}, + {node->GetName(), "check your model!"}); GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index e1d452b1..1d11004d 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -73,14 +73,14 @@ Status IteratorOpPass::Run(ge::ComputeGraphPtr graph) { GE_IF_BOOL_EXEC(status != SUCCESS, GELOGW("Fail to Get var_desc of NODE_NAME_FLOWCTRL_LOOP_PER_ITER failed."); continue); Status ret; - ret = SetRtContext(rtContext_t(), RT_CTX_NORMAL_MODE); + ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_NORMAL_MODE); // EOS will not be considered if ret is not SUCCESS. 
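An aside on the SetRtContext rework in this hunk: threading graph->GetSessionID() through to RtContextUtil::AddRtContext(session_id, rt_context) groups runtime contexts by session rather than collecting them in one global list, presumably so a session can release exactly the contexts it created. A small sketch of that kind of registry is below, under stated assumptions: RtContextRegistry and DestroyRtContexts are hypothetical illustrations, not GE API, and the opaque rtContext_t handle is stood in by void *.

#include <cstdint>
#include <map>
#include <mutex>
#include <vector>

using rtContext_t = void *;  // stand-in for the opaque runtime handle

class RtContextRegistry {
 public:
  // Shape mirrors RtContextUtil::AddRtContext(session_id, context) as used above.
  void AddRtContext(uint64_t session_id, rtContext_t context) {
    std::lock_guard<std::mutex> lock(mutex_);
    contexts_[session_id].push_back(context);
  }

  // Tear down everything a session created; real code would call the
  // runtime's context-destroy routine on each handle before erasing.
  void DestroyRtContexts(uint64_t session_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    contexts_.erase(session_id);
  }

 private:
  std::mutex mutex_;
  std::map<uint64_t, std::vector<rtContext_t>> contexts_;
};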
- GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_GEN_MODE failed."); continue); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_NORMAL_MODE failed."); continue); status = GetVariableValue(graph->GetSessionID(), ge_tensor_desc, NODE_NAME_FLOWCTRL_LOOP_PER_ITER, &loop_per_iter); - ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE); + ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_GEN_MODE); // The following process will be affected if ret is not SUCCESS. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Set rt context RT_CTX_GEN_MODE failed."); return ret); @@ -108,7 +108,7 @@ Status IteratorOpPass::GetVariableValue(uint64_t session_id, const ge::GeTensorD // base_addr uint8_t *var_mem_base = VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM); GE_CHECK_NOTNULL(var_mem_base); - // offset + // offset + logic_base uint8_t *dev_ptr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &dev_ptr), "Get variable %s address failed.", var_name.c_str()); @@ -279,11 +279,11 @@ ge::OpDescPtr IteratorOpPass::CreateMemcpyAsyncOp(const ge::NodePtr &pre_node) { return op_desc; } -Status IteratorOpPass::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { +Status IteratorOpPass::SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/iterator_op_pass.h b/src/ge/graph/passes/iterator_op_pass.h index e403020c..78b951e6 100644 --- a/src/ge/graph/passes/iterator_op_pass.h +++ b/src/ge/graph/passes/iterator_op_pass.h @@ -64,7 +64,7 @@ class IteratorOpPass : public GraphPass { /// ge::OpDescPtr CreateMemcpyAsyncOp(const ge::NodePtr &pre_node); - Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); + Status SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode); }; } // namespace ge #endif // GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc index ff150a54..63ca68a2 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -97,9 +97,16 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetOpDesc() == nullptr) || (node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { + continue; + } + auto in_data_nodes = node->GetInDataNodes(); if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { + continue; + } if (AreAllInputsConst(gen_mask) && nodes_set.count(gen_mask) == 0) { gen_mask_nodes.emplace_back(gen_mask); nodes_set.emplace(gen_mask); diff --git a/src/ge/graph/passes/mark_same_addr_pass.cc b/src/ge/graph/passes/mark_same_addr_pass.cc new file mode 100644 index 00000000..06d63393 --- /dev/null +++ b/src/ge/graph/passes/mark_same_addr_pass.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/passes/mark_same_addr_pass.h"
+
+namespace ge {
+bool MarkSameAddrPass::IsNextNodeExpected(const ge::NodePtr &cur_node, const vector<string> &next_nodes,
+                                          int &out_anchor_idx) {
+  for (auto out_anchor : cur_node->GetAllOutDataAnchors()) {
+    if (out_anchor == nullptr) {
+      continue;
+    }
+    for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) {
+      if (in_anchor == nullptr) {
+        continue;
+      }
+      auto dst_node = in_anchor->GetOwnerNode();
+      if (dst_node == nullptr) {
+        continue;
+      }
+      if (std::count(next_nodes.begin(), next_nodes.end(), dst_node->GetType()) > 0) {
+        out_anchor_idx = out_anchor->GetIdx();
+        GELOGD("Current node is %s, next node is %s.", cur_node->GetName().c_str(), dst_node->GetName().c_str());
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+Status MarkSameAddrPass::Run(ComputeGraphPtr graph) {
+  GELOGD("MarkSameAddrPass begin.");
+  GE_CHECK_NOTNULL(graph);
+  auto parent_node = graph->GetParentNode();
+  if (parent_node == nullptr) {
+    return SUCCESS;
+  }
+  auto parent_op_desc = parent_node->GetOpDesc();
+  GE_CHECK_NOTNULL(parent_op_desc);
+  if (!parent_op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) {
+    GELOGD("Graph[%s] do not have unknown shape attr. Parent node is %s", graph->GetName().c_str(),
+           parent_op_desc->GetName().c_str());
+    return SUCCESS;
+  }
+
+  bool is_unknown_shape = false;
+  (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape);
+  if (is_unknown_shape) {
+    GELOGD("Graph[%s] is unknown shape, do not need to set fixed addr attr. Parent node is %s",
+           graph->GetName().c_str(), parent_op_desc->GetName().c_str());
+    return SUCCESS;
+  }
+
+  int out_anchor_idx = 0;
+  for (const ge::NodePtr &node : graph->GetDirectNode()) {
+    auto op_desc = node->GetOpDesc();
+    GE_CHECK_NOTNULL(op_desc);
+    vector<string> next_nodes = {STREAMSWITCH, STREAMSWITCHN, LABELSWITCHBYINDEX};
+    if (IsNextNodeExpected(node, next_nodes, out_anchor_idx)) {
+      string tensor_name = op_desc->GetOutputNameByIndex(out_anchor_idx);
+      (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name);
+      (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, out_anchor_idx);
+    }
+  }
+  GELOGD("MarkSameAddrPass end.");
+  return SUCCESS;
+}
+} // namespace ge
diff --git a/src/ge/graph/passes/mark_same_addr_pass.h b/src/ge/graph/passes/mark_same_addr_pass.h
new file mode 100644
index 00000000..ebfcf6b2
--- /dev/null
+++ b/src/ge/graph/passes/mark_same_addr_pass.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/graph.h" +#include "inc/graph_pass.h" + +#ifndef GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ +#define GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ + +namespace ge { +class MarkSameAddrPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + bool IsNextNodeExpected(const ge::NodePtr &cur_node, const vector &next_nodes, int &out_anchor_idx); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.cc b/src/ge/graph/passes/merge_to_stream_merge_pass.cc new file mode 100644 index 00000000..b785ddfa --- /dev/null +++ b/src/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -0,0 +1,234 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "common/ge/ge_util.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" + +namespace ge { +Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { + GELOGD("MergeToStreamMergePass Enter"); + + bypass_nodes_.clear(); + for (const auto &node : graph->GetDirectNode()) { + if ((node->GetType() != MERGE) && (node->GetType() != REFMERGE)) { + continue; + } + + OpDescPtr merge_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, true), "Merge add memcpy node failed."); + GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); + } else { + GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); + } + } + + for (const auto &node : bypass_nodes_) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED, + "Remove merge node failed."); + } + + GELOGD("MergeToStreamMergePass Leave"); + return SUCCESS; +} + +/// +/// @brief Replace Merge Op +/// @param [in] graph +/// @param [in] merge_node +/// @return Status +/// +Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node) { + OpDescPtr merge_op_desc = merge_node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + + const std::string &node_name = merge_node->GetName(); + GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamMerge:%s.", node_name.c_str()); + return FAILED; + } + + for (const InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add input desc failed."); + } + + for (const OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { + 
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add output desc failed."); + } + + NodePtr stream_merge = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node failed."); + GE_CHK_STATUS_RET(MoveEdges(merge_node, stream_merge), "Move edges failed."); + bypass_nodes_.insert(merge_node); + + if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + std::string next_iteration_name; + GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), + GELOGE(INTERNAL_ERROR, "Get ATTR_NAME_NEXT_ITERATION failed"); + return INTERNAL_ERROR); + GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); + } + + return AddMemcpyAsyncNodes(graph, stream_merge, false); +} + +/// +/// @brief Add MemcpyAsync Op as StreamMerge in_node +/// @param [in] graph +/// @param [in] node +/// @param [in] multi_batch_flag +/// @return Status +/// +Status MergeToStreamMergePass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, + bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); + for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + const std::string &type = in_node->GetType(); + // For WhileLoop no need memcpy & active for merge. + GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), + continue); + + const std::string &memcpy_name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()); + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, memcpy_name, peer_out_anchor, multi_batch_flag); + GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create MemcpyAsync node failed."); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(0)), + "MemcpyAsync node add edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), in_data_anchor), + "MemcpyAsync node add edge failed."); + + NodePtr active_node = CreateActiveNode(graph, memcpy_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "StreamActive add ctrl edge failed."); + if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Add MemcpyAsync Node +/// @param [in] graph +/// @param [in] name +/// @param [in] out_data_anchor +/// @param [in] multi_batch_flag +/// @return ge::NodePtr +/// +NodePtr MergeToStreamMergePass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); + OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre 
node is invalid."); + + const std::string &memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; + const std::string &node_name = name + "_" + memcpy_type; + GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, memcpy_type); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, MemcpyAsync:%s.", node_name.c_str()); + return nullptr; + } + + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add input desc failed."); + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add output desc failed."); + + return graph->AddNode(op_desc); +} + +/// +/// @brief Create Active Op +/// @param [in] graph +/// @param [in] node +/// @return ge::NodePtr +/// +NodePtr MergeToStreamMergePass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) { + const std::string &node_name = node->GetName() + "_" + STREAMACTIVE; + GELOGI("Create StreamActive op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamActive:%s.", node_name.c_str()); + return nullptr; + } + + NodePtr active_node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node failed."); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, + GELOGE(INTERNAL_ERROR, "add edge failed"); + return nullptr); + GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, + GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); + return nullptr); + + return active_node; +} + +/// +/// @brief move edges from old_node to new_node +/// @param [in] old_node +/// @param [in] new_node +/// @return Status +/// +Status MergeToStreamMergePass::MoveEdges(const NodePtr &old_node, const NodePtr &new_node) { + for (const InDataAnchorPtr &in_data_anchor : old_node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Merge remove in data edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, new_node->GetInDataAnchor(in_data_anchor->GetIdx())), + "StreamMerge add in data edge failed."); + } + + for (const OutDataAnchorPtr &out_data_anchor : old_node->GetAllOutDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Merge remove out data edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), + "StreamMerge add out data edge failed."); + } + } + + for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), + "Merge remove in ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), + "StreamMerge add in ctrl edge failed."); + } + + for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), 
out_ctrl_node->GetInControlAnchor()), + "Merge remove out ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), + "StreamMerge add out ctrl edge failed."); + } + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.h b/src/ge/graph/passes/merge_to_stream_merge_pass.h new file mode 100644 index 00000000..9f713989 --- /dev/null +++ b/src/ge/graph/passes/merge_to_stream_merge_pass.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ +#define GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { +class MergeToStreamMergePass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + /// + /// @brief Replace Merge Op + /// @param [in] graph + /// @param [in] merge_node + /// @return Status + /// + Status ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node); + + /// + /// @brief Add MemcpyAsync Op as StreamMerge in_node + /// @param [in] graph + /// @param [in] node + /// @param [in] multi_batch_flag + /// @return Status + /// + Status AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); + + /// + /// @brief Add MemcpyAsync Node + /// @param [in] graph + /// @param [in] name + /// @param [in] out_data_anchor + /// @param [in] multi_batch_flag + /// @return ge::NodePtr + /// + NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); + + /// + /// @brief Create Active Op + /// @param [in] graph + /// @param [in] node + /// @return ge::NodePtr + /// + NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// @brief move edges from old_node to new_node + /// @param [in] old_node + /// @param [in] new_node + /// @return Status + /// + Status MoveEdges(const NodePtr &old_node, const NodePtr &new_node); + + std::set bypass_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index bb0050be..26190168 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -32,7 +32,7 @@ namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); - GE_CHECK_NOTNULL(graph); + if (graph->GetParentGraph() != nullptr) { GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str()); return SUCCESS; @@ -44,26 +44,26 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return SUCCESS; } if (ret != SUCCESS) { - GELOGE(FAILED, "FindPredValue fail."); + GELOGE(FAILED, "FindPredValue failed."); return FAILED; } std::vector> batch_shape; if (!CheckSwitchN(batch_shape)) { - GELOGE(FAILED, 
"CheckSwitchN fail."); + GELOGE(FAILED, "CheckSwitchN failed."); return FAILED; } FindSwitchOutNodes(batch_shape.size()); if (ReplaceSwitchN(graph, pred_value, batch_shape) != SUCCESS) { - GELOGE(FAILED, "Replace SwitchN nodes fail."); + GELOGE(FAILED, "Replace SwitchN nodes failed."); return FAILED; } - for (NodePtr &node : bypass_nodes_) { - if (graph->RemoveNode(node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN nodes %s fail.", node->GetName().c_str()); + for (const NodePtr &node : bypass_nodes_) { + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str()); return FAILED; } } @@ -79,19 +79,19 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { /// @return Status /// Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) { - for (NodePtr &node : graph->GetDirectNode()) { + for (const NodePtr &node : graph->GetDirectNode()) { if (node->GetType() != SWITCHN) { continue; } InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "FindPredInput fail, in_data_anchor is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); return FAILED; } OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); if (pred_input == nullptr) { - GELOGE(FAILED, "FindPredInput fail, pred_input is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); return FAILED; } @@ -110,7 +110,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor } if (pred_value == nullptr) { - GELOGE(FAILED, "FindPredInput fail, pred_value is null."); + GELOGE(FAILED, "FindPredInput failed, pred_value is null."); return FAILED; } @@ -126,7 +126,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape) { // Check if output_num of different SwitchN is same uint32_t batch_num = 0; - for (NodePtr &node : switch_n_nodes_) { + for (const NodePtr &node : switch_n_nodes_) { uint32_t tmp_num = node->GetAllOutDataAnchorsSize(); if (batch_num == 0) { batch_num = tmp_num; @@ -140,21 +140,21 @@ bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape std::vector> idx_batch_shape; for (uint32_t i = 0; i < batch_num; i++) { idx_batch_shape.clear(); - for (NodePtr &node : switch_n_nodes_) { + for (const NodePtr &node : switch_n_nodes_) { std::vector output_dims; OpDescPtr op_desc = node->GetOpDesc(); if (op_desc == nullptr) { - GELOGE(FAILED, "CheckDims fail, get op_desc fail, node: %s.", node->GetName().c_str()); + GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); return false; } if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { - GELOGE(FAILED, "CheckDims fail, get attr ATTR_NAME_SWITCHN_PRED_VALUE fail, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); return false; } idx_batch_shape.emplace_back(output_dims); } if (!CheckDims(idx_batch_shape)) { - GELOGE(FAILED, "CheckDims fail, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i); return false; } @@ -187,11 +187,11 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { std::vector 
output_nodes;
   for (uint32_t i = 0; i < batch_num; i++) {
     output_nodes.clear();
-    for (NodePtr &node : switch_n_nodes_) {
+    for (const NodePtr &node : switch_n_nodes_) {
       // idx is promised to be valid
       OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
       GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor);
-      for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
+      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
         output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
       }
     }
@@ -208,33 +208,33 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
 /// @param [in] batch_shape
 /// @return Status
 ///
-Status MultiBatchPass::ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value,
+Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
                                       const std::vector<std::vector<int64_t>> &batch_shape) {
   NodePtr pred_value_node = pred_value->GetOwnerNode();
   // Create SwitchCase node
-  const std::string switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
+  const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
   NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape);
   if (switch_case == nullptr) {
-    GELOGE(FAILED, "CreateSwitchCaseNode %s fail.", switch_case_name.c_str());
+    GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str());
     return FAILED;
   }
-  for (NodePtr &switch_n_node : switch_n_nodes_) {
+  for (const NodePtr &switch_n_node : switch_n_nodes_) {
     if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
-      GELOGE(FAILED, "Bypass SwitchN %s fail.", switch_case_name.c_str());
+      GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str());
       return FAILED;
     }
   }
   // Add switchCase input edge
   if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
-    GELOGE(FAILED, "Add SwitchCase in_data_edge fail, %s->%s.", pred_value_node->GetName().c_str(),
+    GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(),
           switch_case->GetName().c_str());
     return FAILED;
   }
   if (AttachLabel(switch_case) != SUCCESS) {
-    GELOGE(FAILED, "AttachLabel fail.");
+    GELOGE(FAILED, "AttachLabel failed.");
     return FAILED;
   }
@@ -248,7 +248,7 @@ Status MultiBatchPass::ReplaceSwitchN(ComputeGraphPtr &
 ///
 bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
   if (output_shape.empty()) {
-    GELOGE(FAILED, "CheckDims fail: output_shape is empty.");
+    GELOGE(FAILED, "CheckDims failed: output_shape is empty.");
     return false;
   }
@@ -257,7 +257,7 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
   for (size_t i = 1; i < num; i++) {
     size_t tmp_dim_num = output_shape[i].size();
     if (dim_num != tmp_dim_num) {
-      GELOGE(FAILED, "CheckDims fail: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
+      GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
       return false;
     }
   }
@@ -271,7 +271,7 @@ bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_s
     for (size_t j = 1; j < num; j++) {
       int64_t tmp_dim_value = output_shape[j][i];
       if (dim_value != tmp_dim_value) {
-        GELOGE(FAILED, "CheckDims fail: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
+        GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
               dim_value, j, tmp_dim_value);
         return false;
       }
@@
-289,41 +289,41 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s /// @param [in] batch_shape /// @return ge::NodePtr /// -NodePtr MultiBatchPass::CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, +NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape) { OpDescPtr op_desc = MakeShared(name, STREAMSWITCHN); if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } GELOGI("Create StreamSwitchN op:%s.", name.c_str()); OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc(); if (pred_desc == nullptr) { - GELOGE(FAILED, "Get pred_desc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) { - GELOGE(FAILED, "AddInputDesc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } NodePtr switch_case_node = graph->AddNode(op_desc); if (switch_case_node == nullptr) { - GELOGE(FAILED, "Create node fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } uint32_t batch_num = static_cast(batch_shape.size()); if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } for (uint32_t i = 0; i < batch_num; i++) { - const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) { - GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } } @@ -337,43 +337,43 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(ComputeGraphPtr &graph, const std:: /// @param [in] switch_case /// @return Status /// -Status MultiBatchPass::BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case) { +Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) { InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "Check in_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_data_anchor == nullptr) { - GELOGE(FAILED, "Check peer_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } NodePtr data_input = peer_data_anchor->GetOwnerNode(); // Remove SwitchN data input if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN in_data_edge fail, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", 
data_input->GetName().c_str(), switch_n_node->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add StreamSwitchN in_control_edge fail, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(), switch_case->GetName().c_str()); return FAILED; } // Add SwitchCase control output - for (OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { - for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { NodePtr data_output = peer_in_anchor->GetOwnerNode(); if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "Bypass SwitchN data_edge fail, %s->%s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(), switch_n_node->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add SwitchCase out_control_edge fail, %s->%s.", switch_case->GetName().c_str(), + GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } @@ -390,17 +390,17 @@ Status MultiBatchPass::BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_cas /// @param [in] switch_case_node /// @return Status /// -Status MultiBatchPass::AttachLabel(NodePtr &switch_case_node) { +Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) { std::vector stream_label_list; for (uint32_t i = 0; i < static_cast(batch_head_nodes_.size()); i++) { if (AttachBatchLabel(i) != SUCCESS) { - GELOGE(FAILED, "AttachBatchLabel fail, batch_idx=%u", i); + GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i); return FAILED; } - const std::string stream_label = "stream_label_batch_" + std::to_string(i); + const std::string &stream_label = "stream_label_batch_" + std::to_string(i); if (AttachStreamLabel(i, stream_label) != SUCCESS) { - GELOGE(FAILED, "AttachStreamLabel fail, stream_label=%s", stream_label.c_str()); + GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str()); return FAILED; } stream_label_list.emplace_back(stream_label); @@ -416,11 +416,11 @@ Status MultiBatchPass::AttachLabel(NodePtr &switch_case_node) { /// Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { std::stack nodes; - for (auto &node : batch_head_nodes_[batch_idx]) { + for (const auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } - const std::string batch_label = "Batch_" + std::to_string(batch_idx); + const std::string &batch_label = "Batch_" + std::to_string(batch_idx); std::unordered_set handled_nodes; while (!nodes.empty()) { NodePtr cur_node = nodes.top(); @@ -434,7 +434,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { std::string tmp_label; if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) { - GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL fail, node: %s.", cur_desc->GetName().c_str()); + 
GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str()); return FAILED; } if (tmp_label != batch_label) { @@ -445,14 +445,14 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { } GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str()); if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL fail, node:%s.", cur_desc->GetName().c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str()); return FAILED; } - for (auto &out_node : cur_node->GetOutAllNodes()) { + for (const auto &out_node : cur_node->GetOutAllNodes()) { OpDescPtr op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - const std::string type = op_desc->GetType(); + const std::string &type = op_desc->GetType(); if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { continue; } @@ -476,7 +476,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { /// Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) { std::stack nodes; - for (auto &node : batch_head_nodes_[batch_idx]) { + for (const auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } @@ -493,11 +493,11 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { - GELOGE(FAILED, "SetStreamLabel fail, node:%s.", cur_node->GetName().c_str()); + GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); return FAILED; } - for (auto &out_node : cur_node->GetOutAllNodes()) { + for (const auto &out_node : cur_node->GetOutAllNodes()) { nodes.push(out_node); } diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index 6e3f5e46..2e83262c 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -31,14 +31,15 @@ class MultiBatchPass : public GraphPass { Status FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value); bool CheckSwitchN(std::vector> &batch_shape); void FindSwitchOutNodes(uint32_t batch_num); - Status ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value, + Status ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape); bool CheckDims(const std::vector> &output_shape) const; - NodePtr CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, + NodePtr CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape); - Status BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case_node); - Status AttachLabel(NodePtr &switch_case_node); + Status BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case_node); + Status AttachLabel(const NodePtr &switch_case_node); Status AttachBatchLabel(uint32_t batch_idx); Status AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label); diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index 138ad86b..c664ac53 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -16,19 +16,8 @@ #include "graph/passes/next_iteration_pass.h" -#include 
-#include -#include -#include -#include - #include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { @@ -41,24 +30,24 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { if ((type != ENTER) && (type != REFENTER)) { continue; } - if (HandleEnterNode(node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "HandleEnterNode for node %s fail.", node->GetName().c_str()); + if (GroupEnterNode(node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str()); return INTERNAL_ERROR; } } if (FindWhileGroups() != SUCCESS) { - GELOGE(INTERNAL_ERROR, "FindWhileGroups fail"); + GELOGE(INTERNAL_ERROR, "Find while groups failed."); return INTERNAL_ERROR; } if (!VerifyWhileGroup()) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail"); + GELOGE(INTERNAL_ERROR, "Verify while groups failed."); return INTERNAL_ERROR; } if (HandleWhileGroup(graph) != SUCCESS) { - GELOGE(FAILED, "HandleWhileGroup fail"); + GELOGE(FAILED, "Handle while groups failed."); return FAILED; } @@ -67,16 +56,16 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { } /// -/// @brief Handle Enter node +/// @brief Group Enter node /// @param [in] enter_node /// @return Status /// -Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { +Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { OpDescPtr enter_desc = enter_node->GetOpDesc(); GE_CHECK_NOTNULL(enter_desc); std::string frame_name; if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) { - GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME fail, node: %s", enter_desc->GetName().c_str()); + GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str()); return FAILED; } @@ -84,7 +73,7 @@ Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { if (iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { - GELOGE(FAILED, "MakeShared for LoopCondGroup fail."); + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); @@ -101,30 +90,30 @@ Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { /// @return Status /// Status NextIterationPass::FindWhileGroups() { - for (auto &loop_group_iter : loop_group_map_) { - const std::string frame_name = loop_group_iter.first; - for (auto &enter_node : loop_group_iter.second->enter_nodes) { - for (auto &out_node : enter_node->GetOutAllNodes()) { - const std::string type = out_node->GetType(); + for (const auto &loop_group_iter : loop_group_map_) { + const std::string &frame_name = loop_group_iter.first; + for (const auto &enter_node : loop_group_iter.second->enter_nodes) { + for (const auto &out_node : enter_node->GetOutAllNodes()) { + const std::string &type = out_node->GetType(); if ((type != MERGE) && (type != REFMERGE)) { continue; } NodePtr next_node = nullptr; if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get NextIteration node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } NodePtr switch_node = nullptr; if 
(FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get Switch node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } NodePtr loop_cond = nullptr; if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get LoopCond node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } @@ -148,21 +137,21 @@ Status NextIterationPass::FindWhileGroups() { /// bool NextIterationPass::VerifyWhileGroup() { // map - for (auto &loop_group_iter : loop_group_map_) { - const std::string frame_name = loop_group_iter.first; + for (const auto &loop_group_iter : loop_group_map_) { + const std::string &frame_name = loop_group_iter.first; if (frame_name.empty()) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, frame_name is empty."); + GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); return false; } if (loop_group_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, LoopCond is null, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); return false; } - for (auto &pair_iter : loop_group_iter.second->merge_next_pairs) { + for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, merge_node/next_node is null, frame_name: %s.", + GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", frame_name.c_str()); return false; } @@ -178,51 +167,51 @@ bool NextIterationPass::VerifyWhileGroup() { /// @return Status /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { - for (auto &loop_cond_iter : loop_group_map_) { - std::string cond_name = loop_cond_iter.second->loop_cond->GetName(); - GELOGI("HandleWhileGroup, LoopCond node: %s.", cond_name.c_str()); + for (const auto &loop_cond_iter : loop_group_map_) { + const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); + GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); - // Create Active node, Enter->Active->Merge, NextItaration->Active->Merge + // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "CreateActiveNode fail, cond_name: %s.", cond_name.c_str()); + GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); return INTERNAL_ERROR; } - for (auto &enter_node : loop_cond_iter.second->enter_nodes) { + for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { // Enter --> Active if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } } - for (auto &pair : loop_cond_iter.second->merge_next_pairs) { + for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { NodePtr merge_node = pair.first; NodePtr 
next_node = pair.second; // Active --> Merge if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } // NextIteration --> Active if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } // break link between NextIteration and Merge if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "BreakNextIteration failed"); + GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); return INTERNAL_ERROR; } } if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "SetActiveLabelList failed"); + GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); return INTERNAL_ERROR; } } @@ -245,12 +234,12 @@ NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::s GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); NodePtr active_node = graph->AddNode(op_desc); if (active_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Create node[%s] fail.", name.c_str()); + GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str()); return nullptr; } if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetSwitchBranchNodeLabel for node: %s failed.", active_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str()); return nullptr; } @@ -268,18 +257,18 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & GELOGE(PARAM_INVALID, "merge node or next node is null."); return PARAM_INVALID; } - for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { + for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) { OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor(); if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) { continue; } if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Remove data edge fail, %s->%s.", next_node->GetName().c_str(), + GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(), merge_node->GetName().c_str()); return INTERNAL_ERROR; } if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetNextIteration for node %s fail.", merge_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); return INTERNAL_ERROR; } } @@ -302,16 +291,16 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } std::vector nodes; if (is_input) { - for (auto &tmp_node : node->GetInDataNodes()) { + for (const auto &tmp_node : node->GetInDataNodes()) { nodes.emplace_back(tmp_node); } } else { - for (auto &tmp_node : node->GetOutDataNodes()) { + for (const auto &tmp_node : node->GetOutDataNodes()) { nodes.emplace_back(tmp_node); } } - for (auto &tmp_node : nodes) { + for (const auto &tmp_node : nodes) { const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -323,13 +312,14 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, 
const std::string } if (target_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Find node %s fail", target_type.c_str()); + GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str()); return INTERNAL_ERROR; } return SUCCESS; } + /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { diff --git a/src/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h index 4bbced4f..4cdf4b51 100644 --- a/src/ge/graph/passes/next_iteration_pass.h +++ b/src/ge/graph/passes/next_iteration_pass.h @@ -17,12 +17,6 @@ #ifndef GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ #define GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ -#include -#include -#include -#include -#include - #include "inc/graph_pass.h" struct LoopCondGroup { @@ -37,15 +31,64 @@ namespace ge { class NextIterationPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return SUCCESS + /// Status ClearStatus() override; private: - Status HandleEnterNode(const NodePtr &enter_node); + /// + /// @brief Group Enter node + /// @param [in] enter_node + /// @return Status + /// + Status GroupEnterNode(const NodePtr &enter_node); + + /// + /// @brief Find while groups + /// @return Status + /// Status FindWhileGroups(); + + /// + /// @brief Verify if valid + /// @return bool + /// bool VerifyWhileGroup(); + + /// + /// @brief Handle while group + /// @param [in] graph + /// @return Status + /// Status HandleWhileGroup(ComputeGraphPtr &graph); + + /// + /// @brief Create Active Node + /// @param [in] graph + /// @param [in] name + /// @return ge::NodePtr + /// NodePtr CreateActiveNode(ComputeGraphPtr &graph, const std::string &name); + + /// + /// @brief Break NextIteration Link & add name to merge attr + /// @param [in] next_node + /// @param [in] merge_node + /// @return Status + /// Status BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node); + + /// + /// @brief find target node + /// @param [in] node + /// @param [in] target_type + /// @param [in] is_input + /// @param [out] target_node + /// @return Status + /// Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); // map diff --git a/src/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc index eec33eef..5be54f0a 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -19,6 +19,7 @@ #include "common/types.h" #include "common/util.h" #include "graph/utils/node_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "omg/omg_inner_types.h" namespace ge { diff --git a/src/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc index f5fd9dc5..3c0dfd4e 100644 --- a/src/ge/graph/passes/permute_pass.cc +++ b/src/ge/graph/passes/permute_pass.cc @@ -33,7 +33,6 @@ using domi::TENSORFLOW; namespace ge { Status PermutePass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(PermutePass); GE_CHECK_NOTNULL(graph); std::vector isolate_nodes; for (NodePtr &node : graph->GetDirectNode()) { @@ -116,8 +115,6 @@ Status PermutePass::Run(ComputeGraphPtr graph) { GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node), "[%s]:remove permute node failed", node->GetOpDesc()->GetName().c_str()); }); - - GE_TIMESTAMP_END(PermutePass, "GraphManager::PermutePass"); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/print_op_pass.h 
b/src/ge/graph/passes/print_op_pass.h index 64bf6573..15b0badc 100644 --- a/src/ge/graph/passes/print_op_pass.h +++ b/src/ge/graph/passes/print_op_pass.h @@ -31,6 +31,6 @@ class PrintOpPass : public BaseNodePass { public: Status Run(ge::NodePtr &node) override; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_PASSES_PRINT_OP_PASS_H_ diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.cc b/src/ge/graph/passes/ref_identity_delete_op_pass.cc new file mode 100644 index 00000000..5bc0fad6 --- /dev/null +++ b/src/ge/graph/passes/ref_identity_delete_op_pass.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ref_identity_delete_op_pass.h" +#include <map> +#include <stack> +#include "graph/common/transop_util.h" + +namespace ge { +Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetAllNodes()) { + if (node->GetType() != REFIDENTITY) { + continue; + } + int input_index = 0; + NodePtr ref_node = GetRefNode(node, input_index); + CHECK_FALSE_EXEC(ref_node != nullptr, + GELOGE(FAILED, "Ref node of RefIdentity[%s] not found", node->GetName().c_str()); + return FAILED); + CHECK_FALSE_EXEC(DealNoOutputRef(ref_node, node, input_index, graph) == SUCCESS, + GELOGE(FAILED, "Ref identity [%s] delete failed", node->GetName().c_str()); + return FAILED); + } + return SUCCESS; +} + +NodePtr RefIdentityDeleteOpPass::GetRefNode(const NodePtr &node, int &input_index) { + OutDataAnchorPtr out_anchor = node->GetOutDataAnchor(0); + CHECK_FALSE_EXEC(out_anchor != nullptr, return nullptr); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + CHECK_FALSE_EXEC(peer_in_anchor != nullptr, continue); + auto peer_node = peer_in_anchor->GetOwnerNode(); + CHECK_FALSE_EXEC(peer_node != nullptr, continue); + const auto &peer_op_desc = peer_node->GetOpDesc(); + CHECK_FALSE_EXEC(peer_op_desc != nullptr, return nullptr); + const auto &peer_input_desc = peer_op_desc->GetInputDescPtr(static_cast<uint32_t>(peer_in_anchor->GetIdx())); + if (!peer_input_desc->GetRefPortIndex().empty()) { + input_index = peer_in_anchor->GetIdx(); + return peer_node; + } + } + return nullptr; +} + +Status RefIdentityDeleteOpPass::DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, + const ComputeGraphPtr &graph) { + NodePtr first_node = nullptr; + NodePtr variable_ref = GetVariableRef(node, ref_identity, first_node); + if (variable_ref == nullptr) { + GELOGE(FAILED, "[RefIdentityDeleteOpPass]Can not find variable ref for %s:%d", node->GetName().c_str(), + input_index); + return FAILED; + } + if (first_node->GetName() != variable_ref->GetName()) { + // Remove the control edge between ref node and variable ref + // Add a control edge between ref node and trans node + // +-----------+ +-----------+ + // +---------+RefIdentity| +-----------+RefIdentity| + // | +-----+-----+ | +-----+-----+ + // | | | | + // | v | v + //
+-----v-----+ +----+----+ +-----v-----+ +----+----+ + // | TransNode | | RefNode | ==> | TransNode +<--C--+ RefNode | + // +-----+-----+ +----+----+ +-----+-----+ +---------+ + // | | | + // v C v + // +-----+-----+ | +-----+-----+ + // |VariableRef+<--------+ |VariableRef| + // +-----------+ +-----------+ + auto ret = ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), first_node->GetInControlAnchor()); + if (ret != SUCCESS) { + GELOGE(FAILED, "Add control edge between ref node and trans node failed"); + return FAILED; + } + ret = ge::GraphUtils::RemoveEdge(node->GetOutControlAnchor(), variable_ref->GetInControlAnchor()); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove control edge between ref node and its peer node failed"); + return FAILED; + } + } else { + // +-----------+ +-----------+ + // +-----------+RefIdentity| +-----------+RefIdentity| + // | +-----+-----+ | +-----+-----+ + // | | | | + // | v | v + // +-----v-----+ +----+----+ +-----v-----+ +----+----+ + // |VariableRef+<--C--+ RefNode | ==> |VariableRef+<--C--+ RefNode | + // +-----+-----+ +----+----+ +-----------+ +----+----+ + // | | | + // | v v + // | +---+----+ +---+----+ + // +-----C------>+ | | | + // +--------+ +--------+ + auto ret = RemoveUselessControlEdge(node, variable_ref); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove useless control edge failed."); + return FAILED; + } + } + // remove ref identity + if (GraphUtils::IsolateNode(ref_identity, {0}) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", ref_identity->GetName().c_str(), + ref_identity->GetType().c_str()); + return FAILED; + } + if (GraphUtils::RemoveNodeWithoutRelink(graph, ref_identity) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", ref_identity->GetName().c_str(), + ref_identity->GetType().c_str()); + return FAILED; + } + return SUCCESS; +} + +ge::NodePtr RefIdentityDeleteOpPass::GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, + NodePtr &first_node) { + const auto &ref_identity_out_anchor = ref_identity->GetOutDataAnchor(0); + if (ref_identity_out_anchor == nullptr) { + return nullptr; + } + for (auto &peer_in_anchor : ref_identity_out_anchor->GetPeerInDataAnchors()) { + const auto &peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr || peer_node->GetName() == ref->GetName()) { + continue; + } + // DFS to find variable ref node. + std::stack<NodePtr> nodes_to_check; + nodes_to_check.push(peer_node); + GELOGI("[RefIdentityDeleteOpPass]Start to search variable ref node from %s.", peer_node->GetName().c_str()); + NodePtr cur_node = nullptr; + while (!nodes_to_check.empty()) { + cur_node = nodes_to_check.top(); + nodes_to_check.pop(); + const auto &type = cur_node->GetType(); + if (type == VARIABLE && CheckControlEdge(ref, cur_node)) { + // Target variable ref node found.
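+ // At this point the DFS has walked from the RefIdentity output through a chain of trans ops (nodes with + // a valid TransOpUtil data index) and reached a Variable that the ref node drives via a control edge, + // i.e. the variable-ref counterpart of this assignment; first_node records the first node on that path.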
+ GELOGI("[RefIdentityDeleteOpPass]variable ref node[%s] found.", cur_node->GetName().c_str()); + first_node = peer_node; + return cur_node; + } + + int data_index = TransOpUtil::GetTransOpDataIndex(type); + if (data_index < 0) { + GELOGI("[RefIdentityDeleteOpPass]Find node[%s] that is not trans op[%s], stop to search its output.", + cur_node->GetName().c_str(), type.c_str()); + continue; + } + const auto &cur_out_anchor = cur_node->GetOutDataAnchor(0); + if (cur_out_anchor == nullptr) { + GELOGI("[RefIdentityDeleteOpPass]Get out anchor of [%s] failed, stop to search its output.", + cur_node->GetName().c_str()); + continue; + } + for (const auto &cur_peer_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) { + const auto &cur_peer_node = cur_peer_in_anchor->GetOwnerNode(); + if (cur_peer_node == nullptr) { + continue; + } + nodes_to_check.push(cur_peer_node); + } + } + GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node from %s.", peer_node->GetName().c_str()); + } + GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node, return nullptr."); + return nullptr; +} + +bool RefIdentityDeleteOpPass::CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { + const auto &control_out_anchor = ref->GetOutControlAnchor(); + if (control_out_anchor == nullptr) { + return false; + } + const string &variable_ref_name = variable_ref->GetName(); + for (const auto &peer_in_control_anchor : control_out_anchor->GetPeerInControlAnchors()) { + const auto &node = peer_in_control_anchor->GetOwnerNode(); + if (node != nullptr && node->GetName() == variable_ref_name) { + return true; + } + } + return false; +} + +Status RefIdentityDeleteOpPass::RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { + map out_nodes_map; + for (const auto &out_anchor : ref->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { + const auto &peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + out_nodes_map[peer_node->GetName()] = peer_node; + } + } + const auto &out_control_anchor = variable_ref->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_control_anchor); + for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { + const auto &peer_node = peer_in_control_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + if (out_nodes_map.find(peer_node->GetName()) != out_nodes_map.end()) { + auto ret = ge::GraphUtils::RemoveEdge(out_control_anchor, peer_in_control_anchor); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove control edge between variable ref node[%s] and ref node's peer node[%s] failed", + variable_ref->GetName().c_str(), peer_node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.h b/src/ge/graph/passes/ref_identity_delete_op_pass.h new file mode 100644 index 00000000..3e42def4 --- /dev/null +++ b/src/ge/graph/passes/ref_identity_delete_op_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ +#define GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ + +#include +#include +#include "framework/common/ge_inner_error_codes.h" +#include "inc/graph_pass.h" + +namespace ge { +class RefIdentityDeleteOpPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + Status DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, + const ComputeGraphPtr &graph); + NodePtr GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, NodePtr &first_node); + bool CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref); + Status RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref); + NodePtr GetRefNode(const NodePtr &node, int &input_index); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ diff --git a/src/ge/graph/passes/reshape_recovery_pass.cc b/src/ge/graph/passes/reshape_recovery_pass.cc index 787c8d83..07b08de9 100644 --- a/src/ge/graph/passes/reshape_recovery_pass.cc +++ b/src/ge/graph/passes/reshape_recovery_pass.cc @@ -30,6 +30,10 @@ NodePtr CreateReshape(const ConstGeTensorDescPtr &src, const ConstGeTensorDescPt if (ret != GRAPH_SUCCESS) { return nullptr; } + ret = reshape->AddInputDesc("shape", GeTensorDesc(GeShape(), Format(), DT_INT32)); + if (ret != GRAPH_SUCCESS) { + return nullptr; + } ret = reshape->AddOutputDesc("y", *dst); if (ret != GRAPH_SUCCESS) { return nullptr; @@ -49,7 +53,10 @@ Status InsertReshapeIfNeed(const NodePtr &node) { GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node->GetOpDesc()); auto dst_tensor = dst_node->GetOpDesc()->GetInputDescPtr(dst_anchor->GetIdx()); - if (src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims()) { + bool is_need_insert_reshape = src_tensor->GetShape().GetDims() != UNKNOWN_RANK && + dst_tensor->GetShape().GetDims() != UNKNOWN_RANK && + src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims(); + if (is_need_insert_reshape) { auto reshape = CreateReshape(src_tensor, dst_tensor, node->GetOwnerComputeGraph()); GE_CHECK_NOTNULL(reshape); auto ret = GraphUtils::InsertNodeBetweenDataAnchors(src_anchor, dst_anchor, reshape); diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index 3b4e4c19..d51f52e1 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -22,7 +22,6 @@ #include #include "common/ge_inner_error_codes.h" #include "common/types.h" -#include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" @@ -117,20 +116,44 @@ void SameTransdataBreadthFusionPass::InsertSameTransdataNodeIndex(int anchors_in same_transdata_nodes.push_back(anchors_index); } +std::set SameTransdataBreadthFusionPass::GetInControlIdentityNodes(const NodePtr &node, + int subgraph_index) { + std::set in_node_names; + for (const auto &in_node : node->GetInControlNodes()) { + if 
(in_node->GetType() == IDENTITY) { + in_node_names.insert(in_node->GetName()); + } + } + for (const auto &subgraph_node : before_transdata_nodes_[subgraph_index]) { + for (const auto &in_node : subgraph_node->GetInControlNodes()) { + if (in_node->GetType() == IDENTITY) { + in_node_names.insert(in_node->GetName()); + } + } + } + GELOGD("control in nodes for %s(%d): %zu", node->GetName().c_str(), subgraph_index, in_node_names.size()); + return in_node_names; +} + void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_transdata_nodes) { auto iter = all_transdata_nodes_.begin(); same_transdata_nodes.push_back(iter->first); + auto node_for_compare_in_anchor = iter->second; GE_CHECK_NOTNULL_JUST_RETURN(node_for_compare_in_anchor); auto node_for_compare = node_for_compare_in_anchor->GetOwnerNode(); + + // Get op-desc, input/output desc, in-control-edges-from-identity, as the compare-key auto op_desc_for_compare = node_for_compare->GetOpDesc(); GE_CHECK_NOTNULL_JUST_RETURN(op_desc_for_compare); string op_compare_stream_label; (void)AttrUtils::GetStr(op_desc_for_compare, ATTR_NAME_STREAM_LABEL, op_compare_stream_label); + auto op_compare_in_ctrl_nodes = GetInControlIdentityNodes(node_for_compare, iter->first); auto input_desc_for_compare = op_desc_for_compare->GetInputDescPtr(node_for_compare_in_anchor->GetIdx()); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_for_compare); auto output_desc_for_compare = op_desc_for_compare->GetOutputDescPtr(0); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_for_compare); + iter = all_transdata_nodes_.erase(iter); while (iter != all_transdata_nodes_.end()) { auto in_anchor = iter->second; @@ -149,12 +172,14 @@ void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_tran auto output_desc_tmp = op_desc_tmp->GetOutputDescPtr(0); string op_tmp_stream_label; (void)AttrUtils::GetStr(op_desc_tmp, ATTR_NAME_STREAM_LABEL, op_tmp_stream_label); + auto op_tmp_in_ctrl_nodes = GetInControlIdentityNodes(node_tmp, iter->first); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_tmp); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_tmp); if ((op_compare_stream_label == op_tmp_stream_label) && (input_desc_tmp->GetFormat() == input_desc_for_compare->GetFormat()) && - (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat())) { + (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat()) && + (op_compare_in_ctrl_nodes == op_tmp_in_ctrl_nodes)) { GELOGD("same transdata node:%s, src node:%s", node_tmp->GetName().c_str(), node_for_compare->GetName().c_str()); InsertSameTransdataNodeIndex(iter->first, same_transdata_nodes); iter = all_transdata_nodes_.erase(iter); @@ -339,14 +364,13 @@ graphStatus SameTransdataBreadthFusionPass::ReLinkTransdataControlOutput2PreNode } graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(SameTransdataBreadthFusionPass); GELOGI("[SameTransdataBreadthFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; } for (auto &node : graph->GetDirectNode()) { - if (IsTransOp(node) || node->GetOutDataNodes().size() <= 1) { + if (IsTransOp(node) || node->GetOutDataNodesSize() <= 1) { continue; } @@ -374,7 +398,6 @@ graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { } GELOGI("[SameTransdataBreadthFusionPass]: Optimize success."); - GE_TIMESTAMP_END(SameTransdataBreadthFusionPass, "GraphManager::SameTransdataBreadthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h 
b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h index f4b44a59..a6a3bb26 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h @@ -42,7 +42,7 @@ class SameTransdataBreadthFusionPass : public GraphPass { void GetSubGraphNodesInfo(); void EraseInvalidAnchorsPair(); - + std::set GetInControlIdentityNodes(const NodePtr &node, int subgraph_index); OpDescPtr GetCastOp(const GeTensorDesc &in_desc, const GeTensorDesc &out_desc); graphStatus AddCastNode(const ComputeGraphPtr &graph, int anchors_index, OutDataAnchorPtr &pre_out_anchor, diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc index d759aa12..80ce995a 100644 --- a/src/ge/graph/passes/subgraph_pass.cc +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -15,7 +15,6 @@ */ #include "graph/passes/subgraph_pass.h" -#include #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -67,13 +66,13 @@ Status SubgraphPass::Run(ComputeGraphPtr graph) { /** * @ingroup ge - * @brief Check Subgraph NetOutput node + * @brief Check Subgraph Input node * @param [in] graph: ComputeGraph. - * @param [in] node: NetOutput node in Subgraph. + * @param [in] node: Data node in Subgraph. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { - GELOGD("Hadle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); + GELOGD("Handle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); // Data has and only has one output bool input_continues_required_flag = false; OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); @@ -86,7 +85,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP // Data->InputContinuesRequiredOp in subgraph need memcpy. if (input_continues_required_flag) { GELOGD("Data %s output_node required continues input.", node->GetName().c_str()); - std::string name = node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + std::string name = node->GetName() + "_output_0_Memcpy"; if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy after %s failed.", node->GetName().c_str()); return FAILED; @@ -123,7 +122,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP GE_CHECK_NOTNULL(peer_out_anchor); GELOGD("Constant input %s links to While %s.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + std::string name = parent_node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(parent_graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); @@ -136,7 +135,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP /** * @ingroup ge - * @brief Check Subgraph NetOutput node + * @brief Check Subgraph Output node * @param [in] graph: ComputeGraph. * @param [in] node: NetOutput node in Subgraph. * @return: 0 for SUCCESS / others for FAILED @@ -153,14 +152,14 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node // 1. 
Const->NetOutput in subgraph // 2. AtomicOp->NetOutput in subgraph // 3. OutputContinuesRequiredOp->NetOutput in subgraph - // 4. Data->NetOutput in subgraph but not while body + // 4. Data->NetOutput in subgraph but parent_node is not while std::string op_type; bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || - ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)); + ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)); if (insert_flag) { - GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); + std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -186,8 +185,8 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr GE_CHECK_NOTNULL(in_node); // Input->While and Input link to other nodes need insert memcpy if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGI("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); + std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -206,231 +205,121 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { - ComputeGraphPtr while_body = GetWhileBodySubgraph(graph, node); + // index of body_subgraph is 1 + ComputeGraphPtr while_body = NodeUtils::GetSubgraph(*node, 1); if (while_body == nullptr) { GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); return FAILED; } - NodePtr output_node = while_body->FindFirstNodeMatchType(NETOUTPUT); - if (output_node == nullptr) { - GELOGE(FAILED, "net_output_node not exist in graph %s.", while_body->GetName().c_str()); - return FAILED; - } - OpDescPtr output_desc = output_node->GetOpDesc(); - GE_CHECK_NOTNULL(output_desc); - std::unordered_map> node_to_attr_index; - for (const InDataAnchorPtr &in_data_anchor : output_node->GetAllInDataAnchors()) { - uint32_t index = 0; - if (!AttrUtils::GetInt(output_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index)) { - GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node %s:%u.", output_node->GetName().c_str(), - in_data_anchor->GetIdx()); - return FAILED; + std::vector data_nodes; + std::set bypass_index; + NodePtr output_node = nullptr; + for (const auto &n : while_body->GetDirectNode()) { + const std::string &type = n->GetType(); + if (type 
== DATA) { + if (CheckInsertInputMemcpy(n, bypass_index)) { + data_nodes.emplace_back(n); + } + } else if (type == NETOUTPUT) { + if (output_node == nullptr) { + output_node = n; + } else { + GELOGE(FAILED, "while_body %s exists multi NetOutput nodes.", while_body->GetName().c_str()); + return FAILED; + } } - MarkOutputIndex(in_data_anchor->GetPeerOutAnchor(), index, node_to_attr_index); } - - std::set data_nodes; - std::set netoutput_input_indexes; - GetExchangeInOut(node_to_attr_index, data_nodes, netoutput_input_indexes); - return InsertMemcpyInWhileBody(while_body, data_nodes, output_node, netoutput_input_indexes); -} - -/** - * @ingroup ge - * @brief Get body subgraph of While op - * @param [in] graph: ComputeGraph. - * @param [in] node: While node. - * @return: body subgraph - */ -ComputeGraphPtr SubgraphPass::GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(FAILED, "op_desc is NULL."); - return nullptr; - } - - const std::vector &subgraph_instance_names = op_desc->GetSubgraphInstanceNames(); - std::string body_instance_name; - for (const std::string &instance_name : subgraph_instance_names) { - std::string subgraph_name; - if (op_desc->GetSubgraphNameByInstanceName(instance_name, subgraph_name) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Get subgraph_name by instance_name %s failed, node:%s.", instance_name.c_str(), - node->GetName().c_str()); - return nullptr; - } - if (subgraph_name == ATTR_NAME_WHILE_BODY) { - body_instance_name = instance_name; - break; - } + if (output_node == nullptr) { + GELOGE(FAILED, "while_body %s has no output.", while_body->GetName().c_str()); + return FAILED; } - ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - GELOGE(FAILED, "root_graph is NULL."); - return nullptr; + if ((InsertInputMemcpy(while_body, data_nodes) != SUCCESS) || + (InsertOutputMemcpy(while_body, output_node, bypass_index) != SUCCESS)) { + GELOGE(FAILED, "Insert memcpy node in while_body %s failed.", while_body->GetName().c_str()); + return FAILED; } - return root_graph->GetSubgraph(body_instance_name); + return SUCCESS; } /** * @ingroup ge - * @brief Mark output parent_node_index - * @param [in] peer_out_anchor: peer_out_anchor of NetOutput - * @param [in] index: parent_node_index of NetOutput - * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @return: void + * @brief Insert input memcpy node in while_body + * @param [in] graph: while_body + * @param [in] data_nodes: data_nodes + * @return: 0 for SUCCESS / others for FAILED */ -void SubgraphPass::MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, - std::unordered_map> &node_to_attr_index) { - if (peer_out_anchor == nullptr) { - return; - } - std::set visited_nodes; - std::stack nodes; - nodes.emplace(peer_out_anchor->GetOwnerNode()); - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (visited_nodes.count(cur_node) > 0) { - continue; - } - node_to_attr_index[cur_node].emplace_back(index); - for (const NodePtr &in_node : cur_node->GetInDataNodes()) { - nodes.emplace(in_node); - } - visited_nodes.emplace(cur_node); +Status SubgraphPass::InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes) { + if (data_nodes.empty()) { + GELOGD("No need to insert input memcpy node in while_body %s.", graph->GetName().c_str()); + return SUCCESS; } -} - -/** - * @ingroup ge - * @brief Get 
data_nodes / input_indexes of netoutput if need insert memcpy - * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @param [out] data_nodes: data_nodes need insert memcpy - * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: void - */ -void SubgraphPass::GetExchangeInOut(const std::unordered_map> &node_to_attr_index, - std::set &data_nodes, std::set &netoutput_input_indexes) { - for (const auto &item : node_to_attr_index) { - NodePtr node = item.first; - uint32_t input_index = 0; - if ((node->GetType() != DATA) || !AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { - continue; - } - if (item.second.empty() || ((item.second.size() == 1) && (item.second[0] == input_index))) { - continue; - } - data_nodes.emplace(node); + std::string in_name = graph->GetName() + "_input_Memcpy"; + OpDescBuilder in_builder(in_name, MEMCPYASYNC); + for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); - if (out_data_anchor == nullptr) { - continue; - } - for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - NodePtr out_node = peer_in_anchor->GetOwnerNode(); - if ((out_node->GetType() != NETOUTPUT) || (out_node->GetOpDesc() == nullptr)) { - continue; - } - uint32_t output_index = 0; - GeTensorDesc input_tensor = out_node->GetOpDesc()->GetInputDesc(peer_in_anchor->GetIdx()); - if (!AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, output_index)) { - continue; - } - if (input_index != output_index) { - netoutput_input_indexes.emplace(peer_in_anchor->GetIdx()); - } - } + in_builder.AddInput("x" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)) + .AddOutput("y" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)); } -} - -/** - * @ingroup ge - * @brief Insert memcpy node in while_body - * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes need insert memcpy - * @param [in] output_node: NetOutput in while_body - * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: 0 for SUCCESS / others for FAILED - */ -Status SubgraphPass::InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, - const NodePtr &output_node, - const std::set &netoutput_input_indexes) { - for (const NodePtr &data_node : data_nodes) { + GELOGD("Insert memcpy after data_nodes of while_body %s.", graph->GetName().c_str()); + NodePtr in_memcpy = graph->AddNode(in_builder.Build()); + GE_CHECK_NOTNULL(in_memcpy); + for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = data_node->GetOutDataAnchor(0); + OutDataAnchorPtr out_data_anchor = data_nodes[i]->GetOutDataAnchor(0); std::vector in_anchors; for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { in_anchors.emplace_back(peer_in_anchor); } - std::string name = data_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - GELOGD("Insert memcpy after while_body %s input_node %s.", graph->GetName().c_str(), data_node->GetName().c_str()); - if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), data_node->GetName().c_str()); - return FAILED; - } - } - - for (uint32_t index : netoutput_input_indexes) { - 
InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - std::string name = - peer_out_anchor->GetOwnerNode()->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - GELOGD("Insert memcpy after while_body %s output %u.", graph->GetName().c_str(), index); - if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), - peer_out_anchor->GetOwnerNode()->GetName().c_str()); + if (InsertNodeBetween(out_data_anchor, in_anchors, in_memcpy, i, i) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", in_name.c_str(), graph->GetName().c_str()); return FAILED; } } - std::set memcpy_nodes; - std::set loop_body_nodes; - for (const NodePtr &data_node : data_nodes) { - // data_node has only one output node - NodePtr memcpy_node = data_node->GetOutDataNodes().at(0); - GE_CHECK_NOTNULL(memcpy_node); - memcpy_nodes.emplace(memcpy_node); - for (const NodePtr &out_node : memcpy_node->GetOutDataNodes()) { - loop_body_nodes.insert(out_node); - } - } - return InsertNoOp(graph, memcpy_nodes, loop_body_nodes); + return SUCCESS; } /** * @ingroup ge - * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @brief Insert output memcpy node in while_body * @param [in] graph: while_body - * @param [in] memcpy_nodes - * @param [in] loop_body_nodes + * @param [in] output_node: NetOutput + * @param [in] bypass_index * @return: 0 for SUCCESS / others for FAILED */ -Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, - const std::set &loop_body_nodes) { - if (memcpy_nodes.empty() || loop_body_nodes.empty()) { +Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, + const std::set &bypass_index) { + if (output_node->GetAllInDataAnchorsSize() == bypass_index.size()) { + GELOGD("No need to insert output memcpy node in while_body %s, output_size=%zu, bypass_num=%zu.", + graph->GetName().c_str(), output_node->GetAllInDataAnchorsSize(), bypass_index.size()); return SUCCESS; } - OpDescBuilder noop_desc_builder("NoOp_for_Control", NOOP); - OpDescPtr noop_desc = noop_desc_builder.Build(); - NodePtr noop_node = graph->AddNode(noop_desc); - GE_CHECK_NOTNULL(noop_node); - for (const NodePtr &memcpy_node : memcpy_nodes) { - if (GraphUtils::AddEdge(memcpy_node->GetOutControlAnchor(), noop_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add ctrl edge %s->%s failed.", memcpy_node->GetName().c_str(), noop_node->GetName().c_str()); - return FAILED; + std::string out_name = graph->GetName() + "_output_Memcpy"; + OpDescBuilder out_builder(out_name, MEMCPYASYNC); + for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { + if (bypass_index.count(i) == 0) { + out_builder.AddInput("x" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)) + .AddOutput("y" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)); } } - for (const NodePtr &loop_body_node : loop_body_nodes) { - if (GraphUtils::AddEdge(noop_node->GetOutControlAnchor(), loop_body_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add ctrl edge %s->%s failed.", noop_node->GetName().c_str(), loop_body_node->GetName().c_str()); - return FAILED; + GELOGD("Insert memcpy before NetOutput of while_body %s.", graph->GetName().c_str()); 
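+ // Note: a single fused MemcpyAsync node is materialized below; the builder already holds an x/y pair + // for every NetOutput input that is not in bypass_index, and each such input anchor is rewired through + // its own slot (cnt) of this one node via InsertNodeBetween.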
+ NodePtr out_memcpy = graph->AddNode(out_builder.Build()); + GE_CHECK_NOTNULL(out_memcpy); + size_t cnt = 0; + for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { + if (bypass_index.count(i) == 0) { + InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(i); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (InsertNodeBetween(peer_out_anchor, {in_data_anchor}, out_memcpy, cnt, cnt) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", out_name.c_str(), graph->GetName().c_str()); + return FAILED; + } + cnt++; } } @@ -439,28 +328,39 @@ Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::setnetoutput in while body - * @param [in] in_data_anchor - * @return: true for data->netoutput in while body / for false for others + * @brief Check is data->netoutput without change in while body + * @param [in] node: data node + * @param [out] bypass_index + * @return: false for data->netoutput without change in while body / for true for others */ -bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { - // Check is subgraph - NodePtr parent_node = in_data_anchor->GetOwnerNode()->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; +bool SubgraphPass::CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index) { + uint32_t input_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { + return true; } - // Check if parent_node is While - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - return false; + // Data node has and only has one output + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); + if ((out_data_anchor == nullptr) || (out_data_anchor->GetPeerInDataAnchors().size() != 1)) { + return true; + } + InDataAnchorPtr peer_in_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); + if (peer_in_anchor->GetOwnerNode()->GetType() != NETOUTPUT) { + return true; } - // While cond / body - OpDescPtr op_desc = in_data_anchor->GetOwnerNode()->GetOpDesc(); - if (op_desc == nullptr) { - return false; + OpDescPtr op_desc = peer_in_anchor->GetOwnerNode()->GetOpDesc(); + uint32_t output_index = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(peer_in_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, output_index)) { + return true; } - return AttrUtils::HasAttr(op_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX); + + if (input_index != output_index) { + return true; + } + bypass_index.insert(peer_in_anchor->GetIdx()); + return false; } /** @@ -542,7 +442,7 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); - if (GraphUtils::InsertNodeBefore(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } @@ -550,4 +450,33 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat return SUCCESS; } +/// +/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst +/// @param [in] src +/// @param [in] dsts +/// @param [in] insert_node +/// @param [in] input_index +/// @param [in] 
output_index +/// @return Status +/// +Status SubgraphPass::InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { + if (GraphUtils::AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data_edge %s:%d->%s:%u failed.", src->GetOwnerNode()->GetName().c_str(), src->GetIdx(), + insert_node->GetName().c_str(), input_index); + return FAILED; + } + for (const auto &dst : dsts) { + GELOGD("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src->GetOwnerNode()->GetName().c_str(), + dst->GetOwnerNode()->GetName().c_str()); + if ((GraphUtils::RemoveEdge(src, dst) != GRAPH_SUCCESS) || + (GraphUtils::AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Replace data_edge %s:%d->%s:%d by %s:%u->%s:%d failed.", src->GetOwnerNode()->GetName().c_str(), + src->GetIdx(), dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx(), insert_node->GetName().c_str(), + output_index, dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx()); + return FAILED; + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h index 2308b1bd..7ff2019f 100644 --- a/src/ge/graph/passes/subgraph_pass.h +++ b/src/ge/graph/passes/subgraph_pass.h @@ -17,12 +17,6 @@ #ifndef GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ #define GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ -#include -#include -#include -#include - -#include "graph/types.h" #include "inc/graph_pass.h" namespace ge { @@ -75,65 +69,32 @@ class SubgraphPass : public GraphPass { /** * @ingroup ge - * @brief Get body subgraph of While op - * @param [in] graph: ComputeGraph. - * @param [in] node: While node. 
- * @return: body subgraph - */ - ComputeGraphPtr GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node); - - /** - * @ingroup ge - * @brief Mark output parent_node_index - * @param [in] peer_out_anchor: peer_out_anchor of NetOutput - * @param [in] index: parent_node_index of NetOutput - * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @return: void - */ - void MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, - std::unordered_map> &node_to_attr_index); - - /** - * @ingroup ge - * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy - * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @param [out] data_nodes: data_nodes need insert memcpy - * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: void - */ - void GetExchangeInOut(const std::unordered_map> &node_to_attr_index, - std::set &data_nodes, std::set &netoutput_input_indexes); - - /** - * @ingroup ge - * @brief Insert memcpy node in while_body + * @brief Insert input memcpy node in while_body * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes need insert memcpy - * @param [in] output_node: NetOutput in while_body - * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @param [in] data_nodes: data_nodes * @return: 0 for SUCCESS / others for FAILED */ - Status InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, - const NodePtr &output_node, const std::set &netoutput_input_indexes); + Status InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes); /** * @ingroup ge - * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @brief Insert output memcpy node in while_body * @param [in] graph: while_body - * @param [in] memcpy_nodes - * @param [in] loop_body_nodes + * @param [in] output_node: NetOutput + * @param [in] bypass_index * @return: 0 for SUCCESS / others for FAILED */ - Status InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, - const std::set &loop_body_nodes); + Status InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, + const std::set &bypass_index); /** * @ingroup ge - * @brief Check is Data->NetOutput in while body - * @param [in] in_data_anchor - * @return: true for Data->NetOutput in while body / false for others + * @brief Check is data->netoutput without change in while body + * @param [in] node: data node + * @param [out] bypass_index + * @return: false for data->netoutput without change in while body / for true for others */ - bool IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor); + bool CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index); /** * @ingroup ge @@ -172,8 +133,17 @@ class SubgraphPass : public GraphPass { Status InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const std::vector &in_anchors, const std::string &name); - // Append index for new memcpy node. 
- uint32_t memcpy_num_{0}; + /// + /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst + /// @param [in] src + /// @param [in] dsts + /// @param [in] insert_node + /// @param [in] input_index + /// @param [in] output_index + /// @return Status + /// + Status InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index); }; } // namespace ge #endif // GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc deleted file mode 100644 index ed3e9b36..00000000 --- a/src/ge/graph/passes/switch_op_pass.cc +++ /dev/null @@ -1,1227 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/switch_op_pass.h" -#include -#include -#include -#include -#include -#include -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/ge_context.h" -#include "graph/utils/type_utils.h" - -namespace ge { -Status SwitchOpPass::Run(ComputeGraphPtr graph) { - GELOGD("SwitchOpPass Enter"); - GE_CHK_STATUS_RET(CheckCycleDependence(graph), "CheckCycleDependence fail."); - - for (auto &switch_node : switch_nodes_) { - GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Add StreamSwitch node fail."); - } - - for (auto &merge_node : merge_nodes_) { - OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, merge_node, true), "Merge add memcpy node fail."); - GE_CHK_STATUS_RET(SetStreamLabel(merge_node, merge_node->GetName()), "Set stream label failed"); - } else { - GE_CHK_STATUS_RET(ReplaceMergeNode(graph, merge_node), "Add StreamMerge node fail."); - } - } - - GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes fail."); - - for (auto &node : bypass_nodes_) { - GE_CHK_BOOL_EXEC(graph->RemoveNode(node) == GRAPH_SUCCESS, return FAILED, "Remove switch node fail."); - } - - for (auto &node : stream_switch_nodes_) { - for (auto &out_ctrl_node : node->GetOutControlNodes()) { - MarkHeadNodes(out_ctrl_node, node); - } - } - - for (auto &node : need_label_nodes_) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Set cond branch fail, start node:%s", node->GetName().c_str()); - } - } - - GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode fail."); - - GELOGD("SwitchOpPass Leave"); - return SUCCESS; -} - -/// -/// @brief Replace Switch Op -/// @param [in] graph -/// @param [in] switch_node -/// @return Status 
-/// -Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node) { - std::string type; - GE_CHK_STATUS_RET(GetOriginalType(switch_node, type), "Get node type fail."); - GE_CHK_BOOL_EXEC((type == SWITCH) || (type == REFSWITCH), return FAILED, "Type of input node is not switch."); - - OutDataAnchorPtr peer_data_anchor = nullptr; - OutDataAnchorPtr peer_cond_anchor = nullptr; - GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, - "Bypass switch node %s fail.", switch_node->GetName().c_str()); - GE_CHECK_NOTNULL(peer_data_anchor); - GE_CHECK_NOTNULL(peer_cond_anchor); - OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(cond_desc); - DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType(); - GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL, return FAILED, - "SwitchNode not support datatype %s, datatype of cond_input should be bool", - TypeUtils::DataTypeToSerialString(cond_data_type).c_str()); - - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG); - - std::set out_node_list; - for (OutDataAnchorPtr &out_data_anchor : switch_node->GetAllOutDataAnchors()) { - bool true_branch_flag = (static_cast(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT); - NodePtr stream_switch = nullptr; - out_node_list.clear(); - for (auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) { - GE_IF_BOOL_EXEC(stream_switch == nullptr, { - std::string suffix = (true_branch_flag ? "_t" : "_f"); - stream_switch = CreateStreamSwitchNode(graph, switch_node, suffix, peer_cond_anchor); - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node fail."); - if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - if (MarkBranchs(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "MarkBranchs for stream_switch %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - - if (!cyclic_flag) { - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), - stream_switch->GetInControlAnchor()), - "StreamSwitch node add ctl edge fail."); - } - }); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output fail."); - - NodePtr out_node = peer_in_anchor->GetOwnerNode(); - GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); - if ((type == MERGE) || (type == REFMERGE)) { - NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); - GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), peer_in_anchor), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), memcpy_node->GetInControlAnchor()), - "MemcpyAsync node add ctl edge fail."); - out_node_list.insert(memcpy_node->GetName()); - } else { - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), 
out_node->GetInControlAnchor()), - "StreamSwitch node add ctl edge fail."); - out_node_list.insert(out_node->GetName()); - } - } - GE_IF_BOOL_EXEC(stream_switch != nullptr, { - CopyControlEdges(switch_node, stream_switch, true); - switch_node_map_[stream_switch] = out_node_list; - if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { - GELOGE(FAILED, "SetOriginalNodeName for node %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - }); - } - - RemoveControlEdges(switch_node); - (void)bypass_nodes_.insert(switch_node); - - return SUCCESS; -} - -/// -/// @brief Replace Merge Op -/// @param [in] graph -/// @param [in] merge_node -/// @return Status -/// -Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node) { - std::string type; - GE_CHK_STATUS_RET(GetOriginalType(merge_node, type), "Get node type fail."); - GE_CHK_BOOL_EXEC((type == MERGE) || (type == REFMERGE), return FAILED, "Type of input node is not merge."); - - OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - - const std::string node_name = merge_node->GetName(); - GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamMerge:%s.", node_name.c_str()); - return FAILED; - } - - for (InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add input desc fail."); - } - - for (OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add output desc fail."); - } - - NodePtr stream_merge = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node fail."); - - for (InDataAnchorPtr &in_data_anchor : merge_node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Remove Merge data input fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, stream_merge->GetInDataAnchor(in_data_anchor->GetIdx())), - "StreamMerge node add edge fail."); - } - - for (OutDataAnchorPtr &out_data_anchor : merge_node->GetAllOutDataAnchors()) { - for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Merge data output fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_merge->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), - "StreamMerge node add edge fail."); - } - } - - ReplaceControlEdges(merge_node, stream_merge); - - if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - std::string next_iteration_name; - GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), - GELOGE(INTERNAL_ERROR, "get ATTR_NAME_NEXT_ITERATION failed"); - return INTERNAL_ERROR); - - GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "set next iteration failed"); - } else { - need_label_nodes_.emplace_back(stream_merge); - } - - (void)bypass_nodes_.insert(merge_node); - - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, 
-
-///
-/// @brief Create StreamSwitch Node
-/// @param [in] graph
-/// @param [in] switch_node
-/// @param [in] suffix
-/// @param [in] peer_cond_anchor
-/// @return ge::NodePtr
-///
-NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node,
-                                             const std::string &suffix, OutDataAnchorPtr &peer_cond_anchor) {
-  GE_CHK_BOOL_EXEC(switch_node != nullptr, return nullptr, "Param of merge node is null.");
-  OpDescPtr switch_op_desc = switch_node->GetOpDesc();
-  GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid.");
-  GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, {
-    GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u", switch_op_desc->GetInputsSize(),
-           SWITCH_INPUT_NUM);
-    return nullptr;
-  });
-
-  const std::string node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix;
-  GELOGI("Create StreamSwitch, name=%s.", node_name.c_str());
-  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMSWITCH);
-  if (op_desc == nullptr) {
-    GELOGE(FAILED, "Create op_desc fail, StreamSwitch:%s.", node_name.c_str());
-    return nullptr;
-  }
-  // mark hccl group id
-  std::string hccl_group_id;
-  if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
-    (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id);
-    GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch%s, value is %s.", node_name.c_str(),
-           hccl_group_id.c_str());
-  } else {
-    GELOGI("Can not find attr ATTR_NAME_HCCL_FUSED_GROUP for node %s.", switch_node->GetName().c_str());
-  }
-
-  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) ||
-      !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) {
-    GELOGE(INTERNAL_ERROR, "set int failed");
-    return nullptr;
-  }
-
-  // Already checked, first input is Variable will passed, second is condition will checked.
-  GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT);
-  GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32);
-  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr,
-                   "Create StreamSwitch node: add input desc fail.");
-  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr,
-                   "Create StreamSwitch node: add input desc fail.");
-
-  NodePtr stream_switch = graph->AddNode(op_desc);
-  GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node fail.");
-
-  GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
-                "StreamSwitch node add cond edge fail.");
-
-  return stream_switch;
-}
-
-///
-/// @brief Add MemcpyAsync Node
-/// @param [in] graph
-/// @param [in] in_node
-/// @param [in] multi_batch_flag
-/// @return ge::NodePtr
-///
-NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor,
-                                            bool multi_batch_flag) {
-  GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null.");
-  OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
-  GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid.");
-
-  std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC;
-  std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type;
-  node_name = CheckDuplicateName(node_name);
-  GELOGI("Create MemcpyAsync op:%s.", node_name.c_str());
-  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, memcpy_type);
-  if (op_desc == nullptr) {
-    GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str());
-    return nullptr;
-  }
-
-  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS,
-                   return nullptr, "Create MemcpyAsync op: add input desc fail.");
-  GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS,
-                   return nullptr, "Create MemcpyAsync op: add output desc fail.");
-
-  NodePtr memcpy_node = graph->AddNode(op_desc);
-  GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return nullptr, "Insert MemcpyAsync node fail.");
-
-  return memcpy_node;
-}
"_t" : "_f"); - node_name = CheckDuplicateName(node_name); - switch_desc->SetName(node_name); - stream_switch_nodes_.emplace_back(stream_switch); - need_label_nodes_.emplace_back(stream_switch); - - // 0_input: original pred input, 1_input: constant node - GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node fail"); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), - "Cast add data edge fail."); - - for (NodePtr &node : switch_list) { - GE_CHECK_NOTNULL(node); - GE_IF_BOOL_EXEC(node != stream_switch, { - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - }); - GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges fail"); - GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges fail"); - } - - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - } - } - } - return SUCCESS; -} - -/// -/// @brief Create Active Op -/// @param [in] graph -/// @param [in] cond_node -/// @return ge::NodePtr -/// -NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Param of pre cond_node is null."); - std::string node_name = node->GetName() + "_" + STREAMACTIVE; - node_name = CheckDuplicateName(node_name); - GELOGI("Create StreamActive op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamActive:%s.", node_name.c_str()); - return nullptr; - } - - NodePtr active_node = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node fail."); - - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, - GELOGE(INTERNAL_ERROR, "add edge failed"); - return nullptr); - - GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, - GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); - return nullptr); - - return active_node; -} - -/// -/// @brief Add MemcpyAsync Op as StreamMerge in_node -/// @param [in] graph -/// @param [in] node -/// @param [in] multi_batch_flag -/// @return Status -/// -Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { - GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); - for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - NodePtr in_node = peer_out_anchor->GetOwnerNode(); - - const std::string type = in_node->GetType(); - // For WhileLoop no need memcpy & active for merge. 
- GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), - continue); - - GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { - in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); - GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), in_data_anchor), - "MemcpyAsync node add edge fail."); - }); - - NodePtr active_node = CreateActiveNode(graph, in_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -/// -/// @brief Bypass Switch Node -/// @param [in] switch_node -/// @param [out] peer_data_anchor -/// @param [out] peer_cond_anchor -/// @return Status -/// -Status SwitchOpPass::BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, - OutDataAnchorPtr &peer_cond_anchor) { - GE_CHK_BOOL_EXEC(switch_node != nullptr, return FAILED, "Switch_node is null."); - for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { - InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); - GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "node[%s]Check Switch input anchor fail.", - switch_node->GetName().c_str()); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s]Check Pre node output anchor fail.", - switch_node->GetName().c_str()); - // Remove Switch data input. 
-
-///
-/// @brief Bypass Switch Node
-/// @param [in] switch_node
-/// @param [out] peer_data_anchor
-/// @param [out] peer_cond_anchor
-/// @return Status
-///
-Status SwitchOpPass::BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
-                                      OutDataAnchorPtr &peer_cond_anchor) {
-  GE_CHK_BOOL_EXEC(switch_node != nullptr, return FAILED, "Switch_node is null.");
-  for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) {
-    InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx);
-    GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "node[%s]Check Switch input anchor fail.",
-                     switch_node->GetName().c_str());
-    OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
-    GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s]Check Pre node output anchor fail.",
-                     switch_node->GetName().c_str());
-    // Remove Switch data input.
-    GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "remove edge failed");
-
-    if (idx == SWITCH_DATA_INPUT) {
-      peer_data_anchor = peer_out_anchor;
-    } else {
-      if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) {
-        GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", switch_node->GetName().c_str());
-        return FAILED;
-      }
-      peer_cond_anchor = peer_out_anchor;
-    }
-  }
-
-  return SUCCESS;
-}
-
-///
-/// @brief Find Switch cond input
-/// @param [in] pass_switch_flag
-/// @param [out] peer_cond_anchor
-/// @return Status
-///
-Status SwitchOpPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) {
-  NodePtr tmp_node = nullptr;
-  string type;
-  bool need_pass_type = true;
-  while (need_pass_type) {
-    if (tmp_node == nullptr) {
-      GE_CHECK_NOTNULL(peer_cond_anchor);
-      tmp_node = peer_cond_anchor->GetOwnerNode();
-    } else {
-      InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT);
-      GE_CHECK_NOTNULL(in_data_anchor);
-      peer_cond_anchor = in_data_anchor->GetPeerOutAnchor();
-      GE_CHECK_NOTNULL(peer_cond_anchor);
-      tmp_node = peer_cond_anchor->GetOwnerNode();
-    }
-
-    GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type fail");
-    need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH)));
-  }
-
-  return SUCCESS;
-}
-
-int64_t SwitchOpPass::GetGroupId(const NodePtr &node) {
-  string tailing_optimization_option;
-  bool is_tailing_optimization = false;
-  auto ret = GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option);
-  if (ret == GRAPH_SUCCESS) {
-    // "1" means it's True from frontend option
-    is_tailing_optimization = (tailing_optimization_option == "1");
-    GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str());
-  }
-  if (!is_tailing_optimization) {
-    return 0;
-  }
-
-  string hccl_group_id;
-  if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
-    GELOGI("Node is %s, can not find hccl group id", node->GetName().c_str());
-    return 0;
-  }
-  auto key_index = hccl_group_id.find_last_of('_');
-  auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index);
-  GELOGI("Node is %s,Hccl group id is %s, key_num is %s", node->GetName().c_str(), hccl_group_id.c_str(),
-         key_num.c_str());
-  int64_t num = atoi(key_num.c_str());
-  if (num == 0) {
-    return 0;
-  }
-  GELOGI("Hccl group id is %s, group id is %ld", hccl_group_id.c_str(), num);
-  return num;
-}
-
-///
-/// @brief Mark Switch Branch
-/// @param [in] peer_cond_anchor
-/// @param [in] stream_switch
-/// @param [in] true_branch_flag
-/// @return Status
-///
-Status SwitchOpPass::MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch, bool true_branch_flag) {
-  uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT;
-  GE_CHECK_NOTNULL(stream_switch);
-  auto it = cond_node_map_.find(peer_cond_anchor);
-  if (it != cond_node_map_.end()) {
-    int64_t switch_group_id = GetGroupId(stream_switch);
-    auto switch_group_it = it->second.find(switch_group_id);
-    if (switch_group_it == it->second.end()) {
-      std::list<NodePtr> false_node_list;
-      std::list<NodePtr> true_node_list;
-      std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
-      node_list.emplace_back(stream_switch);
-      std::vector<std::list<NodePtr>> switch_list;
-      switch_list.emplace_back(false_node_list);
-      switch_list.emplace_back(true_node_list);
-      (void)it->second.emplace(switch_group_id, switch_list);
-    } else {
-      GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, {
-        GELOGE(INTERNAL_ERROR, "cond_node_map_ check size fail, node: %s", stream_switch->GetName().c_str());
-        return FAILED;
-      });
-      switch_group_it->second[index].emplace_back(stream_switch);
-    }
-  } else {
-    int64_t switch_group_id = GetGroupId(stream_switch);
-    map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
-    std::list<NodePtr> false_node_list;
-    std::list<NodePtr> true_node_list;
-    std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
-    node_list.emplace_back(stream_switch);
-    std::vector<std::list<NodePtr>> switch_list;
-    switch_list.emplace_back(false_node_list);
-    switch_list.emplace_back(true_node_list);
-    (void)switch_group_map.emplace(switch_group_id, switch_list);
-    auto result = cond_node_map_.insert(
-      std::pair<OutDataAnchorPtr, map<int64_t, std::vector<std::list<NodePtr>>>>(peer_cond_anchor, switch_group_map));
-    GE_IF_BOOL_EXEC(!result.second, {
-      GELOGE(INTERNAL_ERROR, "cond_node_map_ insert fail, node: %s", stream_switch->GetName().c_str());
-      return FAILED;
-    });
-  }
-  return SUCCESS;
-}
-
-///
-/// @brief Create cast node
-/// @param [in] graph
-/// @param [in] peer_cond_anchor
-/// @return NodePtr
-///
-NodePtr SwitchOpPass::CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor) {
-  GE_CHK_BOOL_EXEC(peer_cond_anchor != nullptr, return nullptr, "Param of pre cond_node is null.");
-  OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc();
-  GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc fail.");
-
-  std::string cast_name = cond_desc->GetName() + "_" + CAST;
-  cast_name = CheckDuplicateName(cast_name);
-  GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str());
-  OpDescPtr cast_desc = MakeShared<OpDesc>(cast_name, CAST);
-  if (cast_desc == nullptr) {
-    GELOGE(FAILED, "Create op_desc fail, Cast:%s.", cast_name.c_str());
-    return nullptr;
-  }
-  if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) &&
-        AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) &&
-        AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) &&
-        AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) {
-    GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE fail, node: %s.",
-           cast_name.c_str());
-    return nullptr;
-  }
-  GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx());
-  tensor_desc.SetDataType(DT_BOOL);
-  GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add input desc fail.");
-  tensor_desc.SetDataType(DT_INT32);
-  GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add output desc fail.");
-
-  NodePtr cast_node = graph->AddNode(cast_desc);
-  GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node fail.");
-
-  GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge fail.");
-
-  return cast_node;
-}
-
-///
-/// @brief Add const node as switch input1
-/// @param [in] graph
-/// @param [in] stream_switch
-/// @return Status
-///
-Status SwitchOpPass::AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch) {
-  GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "stream_switch is null.");
-  OpDescPtr op_desc = stream_switch->GetOpDesc();
-  GE_CHECK_NOTNULL(op_desc);
-  bool value = false;
-  GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED,
-                   "StreamSwitch get attr TRUE_BRANCH_STREAM fail.");
-
-  const std::string const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f");
-  GELOGI("Create const op: %s", const_node_name.c_str());
-  OpDescPtr const_op_desc = MakeShared<OpDesc>(const_node_name, CONSTANT);
-  if (const_op_desc == nullptr) {
-    GELOGE(FAILED, "Create op_desc fail, Constant:%s.", const_node_name.c_str());
-    return FAILED;
-  }
-
-  auto resize_value = (int32_t)value;
-  GeTensorDesc data_desc = op_desc->GetInputDesc(1);
-  GeTensorPtr const_value =
-    MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&resize_value), sizeof(int32_t));
-  if (const_value == nullptr) {
-    GELOGE(FAILED, "Create tensor fail.");
-    return FAILED;
-  }
-  GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED);
-  GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED,
-                   "Create Const op: add output desc fail.");
-
-  NodePtr const_node = graph->AddNode(const_op_desc);
-  GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node fail.");
-  GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)),
-                "StreamSwitch node add ctl edge fail.");
-
-  return SUCCESS;
-}
-
-///
-/// @brief update cond branch
-/// @param [in] node
-/// @return Status
-///
-Status SwitchOpPass::UpdateCondBranch(NodePtr &node) {
-  std::string stream_label;
-  std::unordered_set<NodePtr> branch_nodes;
-  std::unordered_set<NodePtr> handled_set;
-  std::stack<NodePtr> nodes;
-  nodes.push(node);
-
-  static const std::set<std::string> end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE};
-  bool merge_flag = false;
-  bool exit_flag = false;
-  bool net_output_flag = false;
-
-  while (!nodes.empty()) {
-    NodePtr cur_node = nodes.top();
-    nodes.pop();
-    if (handled_set.count(cur_node) > 0) {
-      continue;
-    }
-    GE_CHECK_NOTNULL(cur_node);
-    if (UpdateAttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) {
-      GELOGE(FAILED, "UpdateAttachFlag fail, cur_node: %s.", cur_node->GetName().c_str());
-      return FAILED;
-    }
-
-    const std::string type = cur_node->GetType();
-    for (auto &out_node : cur_node->GetOutAllNodes()) {
-      const std::string out_type = out_node->GetType();
-      bool stop_flag = (end_type_set.count(out_type) > 0) ||
-                       ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) ||
-                       (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE));
-      if (!stop_flag) {
-        nodes.push(out_node);
-        GELOGD("branch_nodes insert %s", out_node->GetName().c_str());
-        branch_nodes.insert(out_node);
-      }
-    }
-    handled_set.insert(cur_node);
-  }
-
-  if (node->GetType() == STREAMSWITCH) {
-    GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed");
-  }
-
-  bool attach_flag = (merge_flag || exit_flag) && net_output_flag;
-  if (attach_flag) {
-    GELOGI("No need to keep on attaching label.");
-    return SUCCESS;
-  }
-
-  for (NodePtr tmp_node : branch_nodes) {
-    GELOGD("Attach label %s to node: %s", stream_label.c_str(), tmp_node->GetName().c_str());
-    GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "set stream label failed");
-  }
-
-  return SUCCESS;
-}
-
-///
-/// @brief update attach flag
-/// @param [in] node
-/// @param [out] stream_label
-/// @param [out] merge_flag
-/// @param [out] exit_flag
-/// @param [out] net_output_flag
-/// @return Status
-///
-Status SwitchOpPass::UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag,
-                                      bool &net_output_flag) {
-  const std::string type = node->GetType();
-  if (type == STREAMSWITCH) {
-    if (node->GetInDataNodes().empty()) {
-      GELOGE(INTERNAL_ERROR, "cur_node %s has no input_data_node", node->GetName().c_str());
-      return INTERNAL_ERROR;
-    }
-    stream_label = node->GetInDataNodes().at(0)->GetName();
-    GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed");
-    bool value = false;
-    OpDescPtr op_desc = node->GetOpDesc();
-    GE_CHECK_NOTNULL(op_desc);
-    GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED,
-                     "StreamSwitch get attr TRUE_BRANCH_STREAM fail.");
-    stream_label += (value ? "_t" : "_f");
-  } else if (type == STREAMMERGE) {
-    stream_label = node->GetName();
-    GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed");
-    merge_flag = true;
-  } else if ((type == EXIT) || (type == REFEXIT)) {
-    GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed");
-    exit_flag = true;
-  } else if (type == NETOUTPUT) {
-    net_output_flag = true;
-  }
-
-  return SUCCESS;
-}
-
-///
-/// @brief update loop branch
-/// @param [in] enter_nodes
-/// @param [in] stream_label
-/// @return Status
-///
-Status SwitchOpPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label) {
-  std::stack<NodePtr> nodes(enter_nodes);
-  NodePtr cur_node = nullptr;
-  while (!nodes.empty()) {
-    cur_node = nodes.top();
-    nodes.pop();
-    for (NodePtr &out_node : cur_node->GetOutAllNodes()) {
-      OpDescPtr out_desc = out_node->GetOpDesc();
-      GE_CHECK_NOTNULL(out_desc);
-      if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) {
-        continue;
-      }
-      GELOGD("Attach label %s to node: %s", stream_label.c_str(), out_node->GetName().c_str());
-      GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "set stream label failed");
-      nodes.push(out_node);
-    }
-  }
-
-  return SUCCESS;
-}
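UpdateLoopBranch above is a plain iterative traversal: it walks forward from the Enter nodes and stamps the loop's stream label onto every reachable node that does not yet carry one. A self-contained sketch of the same idea over a toy adjacency map (all names here are illustrative, not GE APIs):

#include <stack>
#include <string>
#include <unordered_map>
#include <vector>

// Toy graph: node name -> successor names.
using Graph = std::unordered_map<std::string, std::vector<std::string>>;

// Stamp `label` on every node reachable from `roots` that is still unlabeled.
void PropagateLabel(const Graph &g, std::stack<std::string> roots, const std::string &label,
                    std::unordered_map<std::string, std::string> &labels) {
  while (!roots.empty()) {
    std::string cur = roots.top();
    roots.pop();
    auto it = g.find(cur);
    if (it == g.end()) continue;
    for (const auto &next : it->second) {
      if (labels.count(next) > 0) continue;  // already labeled: stop, like the HasAttr(ATTR_NAME_STREAM_LABEL) check
      labels[next] = label;
      roots.push(next);
    }
  }
}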
-
-///
-/// @brief update enter nodes
-/// @return Status
-///
-Status SwitchOpPass::UpdateEnterNode() {
-  std::unordered_map<NodePtr, std::vector<NodePtr>> enter_active_map;
-  for (auto &enter_node : enter_nodes_) {
-    for (auto &out_ctrl_node : enter_node->GetOutControlNodes()) {
-      if (out_ctrl_node->GetType() != STREAMACTIVE) {
-        continue;
-      }
-      auto iter = enter_active_map.find(out_ctrl_node);
-      if (iter == enter_active_map.end()) {
-        enter_active_map[out_ctrl_node] = {enter_node};
-      } else {
-        iter->second.emplace_back(enter_node);
-      }
-    }
-  }
-
-  for (auto &pair : enter_active_map) {
-    std::string stream_label;
-    NodePtr active_node = pair.first;
-    GE_CHECK_NOTNULL(active_node);
-    OpDescPtr active_desc = active_node->GetOpDesc();
-    GE_CHECK_NOTNULL(active_desc);
-    (void)AttrUtils::GetStr(active_desc, ATTR_NAME_STREAM_LABEL, stream_label);
-    if (stream_label.empty()) {
-      stream_label = active_desc->GetName();
-      GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "set stream label failed");
-    }
-    std::stack<NodePtr> enter_nodes;
-    for (auto &enter_node : pair.second) {
-      GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "set stream label failed");
-      enter_nodes.emplace(enter_node);
-    }
-
-    std::vector<std::string> active_label_list;
-    if (!AttrUtils::GetListStr(active_desc, ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) ||
-        (active_label_list.size() != 1) || active_label_list[0].empty()) {
-      GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST fail, node: %s", active_desc->GetName().c_str());
-      return INTERNAL_ERROR;
-    }
-    if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
-      GELOGE(FAILED, "UpdateLoopBranch fail.");
-      return FAILED;
-    }
-  }
-
-  return SUCCESS;
-}
-
-///
-/// @brief Check duplicate node_name
-/// @param [in] node_name
-/// @return std::string
-///
-std::string SwitchOpPass::CheckDuplicateName(const std::string &node_name) {
-  std::string tmp_name = node_name;
-  auto iter = node_num_map_.find(tmp_name);
-  if (iter != node_num_map_.end()) {
-    tmp_name = tmp_name + "_" + std::to_string(iter->second);
-    (iter->second)++;
-  } else {
-    node_num_map_[tmp_name] = 1;
-  }
-  return tmp_name;
-}
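CheckDuplicateName keeps a per-name counter: the first request returns the name untouched, later requests get "_1", "_2", ... appended. A runnable sketch of the same policy, with node_num_map_ modeled as a plain std::map:

#include <iostream>
#include <map>
#include <string>

std::map<std::string, int> node_num_map;  // stands in for node_num_map_

std::string DedupName(const std::string &node_name) {
  auto iter = node_num_map.find(node_name);
  if (iter == node_num_map.end()) {
    node_num_map[node_name] = 1;
    return node_name;
  }
  return node_name + "_" + std::to_string(iter->second++);
}

int main() {
  std::cout << DedupName("a") << "\n";  // a
  std::cout << DedupName("a") << "\n";  // a_1
  std::cout << DedupName("a") << "\n";  // a_2
}

Note that the suffixed result itself is never registered in the map, so a caller that independently requests a literal "a_1" could still collide; the pass appears to rely on generated names staying distinct.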
-
-///
-/// @brief Check cyclic dependence
-/// @param [in] graph
-/// @return Status
-///
-Status SwitchOpPass::CheckCycleDependence(ComputeGraphPtr &graph) {
-  std::string type;
-  std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
-  for (NodePtr &node : graph->GetDirectNode()) {
-    GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type fail");
-    if ((type == SWITCH) || (type == REFSWITCH)) {
-      InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
-      GE_CHK_BOOL_EXEC(in_cond_anchor != nullptr, return INTERNAL_ERROR, "Check Switch in_cond_anchor fail.");
-      OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
-      GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return INTERNAL_ERROR, "Check Switch peer_out_anchor fail.");
-      if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) {
-        GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", node->GetName().c_str());
-        return FAILED;
-      }
-
-      NodePtr cond_node = peer_out_anchor->GetOwnerNode();
-      auto iter = cond_switch_map.find(cond_node);
-      if (iter == cond_switch_map.end()) {
-        cond_switch_map[cond_node] = {node};
-      } else {
-        iter->second.emplace_back(node);
-      }
-
-      switch_nodes_.emplace_back(node);
-    } else if ((type == MERGE) || (type == REFMERGE)) {
-      merge_nodes_.emplace_back(node);
-    } else if ((type == ENTER) || (type == REFENTER)) {
-      enter_nodes_.emplace_back(node);
-    }
-  }
-
-  MarkCycleDependence(cond_switch_map);
-
-  return SUCCESS;
-}
-
-///
-/// @brief Mark cyclic dependence
-/// @param [in] graph
-/// @param [in] cond_switch_map
-/// @return void
-///
-void SwitchOpPass::MarkCycleDependence(const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map) {
-  std::stack<NodePtr> out_nodes;
-  NodePtr tmp_node = nullptr;
-  std::unordered_set<NodePtr> handled_set;
-  for (auto &iter : cond_switch_map) {
-    std::set<NodePtr> switch_nodes(iter.second.begin(), iter.second.end());
-    for (auto &switch_node : switch_nodes) {
-      GE_CHECK_NOTNULL_JUST_RETURN(switch_node);
-      GELOGD("CheckCycleDependence: cond_node=%s, switch=%s", iter.first->GetName().c_str(),
-             switch_node->GetName().c_str());
-      for (const NodePtr &node : switch_node->GetOutAllNodes()) {
-        out_nodes.push(node);
-      }
-    }
-    handled_set.clear();
-    while (!out_nodes.empty()) {
-      tmp_node = out_nodes.top();
-      GE_CHECK_NOTNULL_JUST_RETURN(tmp_node);
-      out_nodes.pop();
-      if (handled_set.count(tmp_node) > 0) {
-        continue;
-      }
-      GELOGD("CheckCycleDependence: tmp_node=%s", tmp_node->GetName().c_str());
-      for (NodePtr &out_node : tmp_node->GetOutAllNodes()) {
-        if (switch_nodes.find(out_node) == switch_nodes.end()) {
-          out_nodes.push(out_node);
-          continue;
-        }
-        GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence failed"); return );
-        auto map_iter = switch_cyclic_map_.find(out_node);
-        if (map_iter == switch_cyclic_map_.end()) {
-          switch_cyclic_map_[out_node] = {tmp_node->GetName()};
-        } else {
-          map_iter->second.insert(tmp_node->GetName());
-        }
-      }
-      handled_set.insert(tmp_node);
-    }
-  }
-
-  return;
-}
-
-///
-/// @brief Modify in ctl edge for switch_node
-/// @param [in] switch_node
-/// @param [in] cast_node
-/// @param [in] same_cond_switch
-/// @return Status
-///
-Status SwitchOpPass::ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node,
-                                            const std::set<NodePtr> &same_cond_switch) {
-  GE_CHECK_NOTNULL(switch_node);
-  GE_CHECK_NOTNULL(cast_node);
-  GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(),
-         cast_node->GetName().c_str());
-
-  std::string orig_switch_name = switch_node->GetName();
-  OpDescPtr switch_desc = switch_node->GetOpDesc();
-  GE_CHECK_NOTNULL(switch_desc);
-  if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) {
-    GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s", switch_desc->GetName().c_str());
-    return INTERNAL_ERROR;
-  }
-
-  for (NodePtr &in_ctl_node : switch_node->GetInControlNodes()) {
-    GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
-                  "Remove ctl edge fail.");
-    GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
-      GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
-                    "Add ctl edge fail.");
-    });
-
-    GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue);
-    if (same_cond_switch.count(in_ctl_node) > 0) {
-      GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
-                    "Remove ctl edge fail.");
-      continue;
-    }
-    auto find_res1 = switch_node_map_.find(in_ctl_node);
-    GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), {
-      GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str());
-      return INTERNAL_ERROR;
-    });
-    auto find_res2 = find_res1->second.find(orig_switch_name);
-    auto find_res3 = find_res1->second.find(cast_node->GetName());
-    GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), {
-      find_res1->second.erase(find_res2);
-      find_res1->second.insert(cast_node->GetName());
-      continue;
-    });
-  }
-
-  return SUCCESS;
-}
-
-///
-/// @brief Modify out ctl edge for switch_node
-/// @param [in] switch_node
-/// @param [in] stream_switch
-/// @param [in] active_node
-/// @return Status
-///
-Status SwitchOpPass::ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, NodePtr &active_node) {
-  GE_CHECK_NOTNULL(switch_node);
-  GE_CHECK_NOTNULL(stream_switch);
-  GE_CHECK_NOTNULL(active_node);
-  GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(),
-         stream_switch->GetName().c_str(), active_node->GetName().c_str());
-  auto find_res = switch_node_map_.find(switch_node);
-  GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), {
-    GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str());
-    return INTERNAL_ERROR;
-  });
-  GE_IF_BOOL_EXEC(find_res->second.empty(), {
-    GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str());
-    return INTERNAL_ERROR;
-  });
-
-  for (NodePtr &node : switch_node->GetOutControlNodes()) {
-    GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()),
-                  "Remove ctl edge fail.");
-    OpDescPtr op_desc = node->GetOpDesc();
-    GE_CHECK_NOTNULL(op_desc);
-    std::string orig_name = op_desc->GetName();
-    GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), {
-      if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) {
-        GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s.", op_desc->GetName().c_str());
-        return INTERNAL_ERROR;
-      }
-    });
-    if (find_res->second.find(orig_name) == find_res->second.end()) {
-      auto active_out_control_anchor = active_node->GetOutControlAnchor();
-      GE_CHECK_NOTNULL(active_out_control_anchor);
-      GE_IF_BOOL_EXEC(!active_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), {
-        GE_CHK_STATUS(GraphUtils::AddEdge(active_out_control_anchor, node->GetInControlAnchor()), "Add ctl edge fail.");
-      });
-    } else {
-      auto stream_switch_out_control_anchor = stream_switch->GetOutControlAnchor();
-      GE_CHECK_NOTNULL(stream_switch_out_control_anchor);
-      GE_IF_BOOL_EXEC(!stream_switch_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), {
-        GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch_out_control_anchor, node->GetInControlAnchor()),
-                      "Add ctl edge fail.");
-      });
-    }
-  }
-
-  GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node));
-
-  return SUCCESS;
-}
-
-///
-/// @brief Copy Control Edges
-/// @param [in] old_node
-/// @param [in] new_node
-/// @param [in] input_check_flag
-/// @return void
-///
-void SwitchOpPass::CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag) {
-  GE_CHECK_NOTNULL_JUST_RETURN(old_node);
-  GE_CHECK_NOTNULL_JUST_RETURN(new_node);
-  GE_IF_BOOL_EXEC(old_node == new_node, return );
-  auto iter = switch_cyclic_map_.find(old_node);
-  bool check_flag = input_check_flag && (iter != switch_cyclic_map_.end());
-  for (NodePtr &node : old_node->GetInControlNodes()) {
-    if (check_flag && (iter->second.count(node->GetName()) > 0)) {
-      for (auto &out_node : old_node->GetOutAllNodes()) {
-        auto out_control_anchor = node->GetOutControlAnchor();
-        GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
-        GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(out_node->GetInControlAnchor()), {
-          GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, out_node->GetInControlAnchor()), "Add ctl edge fail.");
-        });
-      }
-    } else {
-      auto out_control_anchor = node->GetOutControlAnchor();
-      GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
-      GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(new_node->GetInControlAnchor()), {
-        GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, new_node->GetInControlAnchor()), "Add in ctl edge fail.");
-      });
-    }
-  }
-
-  for (NodePtr &node : old_node->GetOutControlNodes()) {
-    GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor()), {
-      GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), node->GetInControlAnchor()),
-                    "Add out ctl edge fail.");
-    });
-  }
-}
-
-///
-/// @brief Remove Control Edges
-/// @param [in] node
-/// @return void
-///
-void SwitchOpPass::RemoveControlEdges(NodePtr &node) {
-  GE_CHECK_NOTNULL_JUST_RETURN(node);
-  for (NodePtr &in_node : node->GetInControlNodes()) {
-    GE_CHK_STATUS(GraphUtils::RemoveEdge(in_node->GetOutControlAnchor(), node->GetInControlAnchor()),
-                  "Remove in ctl edge fail.");
-  }
-
-  for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
-    for (auto &in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) {
-      GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, in_ctrl_anchor), "Remove in ctl edge fail.");
-    }
-  }
-
-  auto out_control_anchor = node->GetOutControlAnchor();
-  GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor);
-  for (auto &peer_anchor : out_control_anchor->GetPeerAnchors()) {
-    GE_CHK_STATUS(GraphUtils::RemoveEdge(out_control_anchor, peer_anchor), "Remove out ctl edge fail.");
-  }
-}
-
-///
-/// @brief Replace Control Edges
-/// @param [in] old_node
-/// @param [in] new_node
-/// @return void
-///
-void SwitchOpPass::ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node) {
-  GE_IF_BOOL_EXEC(old_node == new_node, return );
-  CopyControlEdges(old_node, new_node);
-  RemoveControlEdges(old_node);
-}
-
-///
-/// @brief Mark node as head_node of stream_switch
-/// @param [in] node
-/// @param [in] stream_switch
-/// @return void
-///
-void SwitchOpPass::MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch) {
-  std::stack<NodePtr> nodes;
-  nodes.push(node);
-  std::set<NodePtr> visited;
-  while (!nodes.empty()) {
-    NodePtr cur_node = nodes.top();
-    nodes.pop();
-    if (visited.count(cur_node) > 0) {
-      continue;
-    }
-    GELOGD("branch_head_node %s of stream_switch %s", cur_node->GetName().c_str(), stream_switch->GetName().c_str());
-    branch_head_nodes_[cur_node] = stream_switch;
-    if ((cur_node->GetType() == IDENTITY) || (cur_node->GetType() == IDENTITYN)) {
-      for (auto &out_node : cur_node->GetOutAllNodes()) {
-        nodes.push(out_node);
-      }
-    }
-    visited.insert(cur_node);
-  }
-}
-
-///
-/// @brief Clear Status, uesd for subgraph pass
-/// @return
-///
-Status SwitchOpPass::ClearStatus() {
-  switch_nodes_.clear();
-  merge_nodes_.clear();
-  enter_nodes_.clear();
-  switch_cyclic_map_.clear();
-  bypass_nodes_.clear();
-  branch_head_nodes_.clear();
-  stream_switch_nodes_.clear();
-  need_label_nodes_.clear();
-  cond_node_map_.clear();
-  switch_node_map_.clear();
-  node_num_map_.clear();
-  return SUCCESS;
-}
-}  // namespace ge
diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.cc b/src/ge/graph/passes/switch_to_stream_switch_pass.cc
new file mode 100644
index 00000000..ef8879dd
--- /dev/null
+++ b/src/ge/graph/passes/switch_to_stream_switch_pass.cc
@@ -0,0 +1,755 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "graph/passes/switch_to_stream_switch_pass.h"
+#include <stack>
+#include "common/ge/ge_util.h"
+#include "framework/common/debug/ge_log.h"
+#include "framework/common/debug/log.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/common/types.h"
+#include "ge/ge_api_types.h"
+#include "graph/common/omg_util.h"
+#include "graph/debug/ge_attr_define.h"
+#include "graph/ge_context.h"
+#include "graph/utils/type_utils.h"
+
+namespace ge {
+Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) {
+  GELOGD("SwitchToStreamSwitchPass Enter");
+
+  GE_CHK_STATUS_RET(CheckCycleDependence(graph), "Check cyclic dependence failed.");
+  for (const auto &switch_node : switch_nodes_) {
+    GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Replace Switch by StreamSwitch failed.");
+  }
+  GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes failed.");
+
+  for (const auto &node : bypass_nodes_) {
+    GE_CHK_BOOL_EXEC(graph->IsolateNode(node) == GRAPH_SUCCESS, return FAILED, "Isolate node failed.");
+    GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED,
+                     "Remove switch node failed.");
+  }
+
+  GELOGD("SwitchToStreamSwitchPass Leave");
+  return SUCCESS;
+}
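Run is the whole driver: collect Switch nodes, rewrite each one, combine the per-cond StreamSwitches, then isolate and drop the bypassed originals. A minimal invocation sketch, assuming a ge::ComputeGraphPtr named graph has already been built elsewhere (that variable is hypothetical here):

// Sketch only: `graph` is assumed to be a valid ge::ComputeGraphPtr.
SwitchToStreamSwitchPass pass;
if (pass.Run(graph) != SUCCESS) {
  GELOGE(FAILED, "SwitchToStreamSwitchPass failed.");
}
(void)pass.ClearStatus();  // reset per-graph state before reusing the pass on another (sub)graph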
+
+///
+/// @brief Clear Status, used for subgraph pass
+/// @return
+///
+Status SwitchToStreamSwitchPass::ClearStatus() {
+  switch_nodes_.clear();
+  switch_cyclic_map_.clear();
+  bypass_nodes_.clear();
+  stream_switch_nodes_.clear();
+  cond_node_map_.clear();
+  switch_node_map_.clear();
+  node_num_map_.clear();
+  return SUCCESS;
+}
+
+///
+/// @brief Check cyclic dependence
+/// @param [in] graph
+/// @return Status
+///
+Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &graph) {
+  std::string type;
+  std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
+  for (const NodePtr &node : graph->GetDirectNode()) {
+    GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
+    if ((type == SWITCH) || (type == REFSWITCH)) {
+      InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
+      GE_CHECK_NOTNULL(in_cond_anchor);
+      OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
+      GE_CHECK_NOTNULL(peer_out_anchor);
+      if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) {
+        GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
+        return FAILED;
+      }
+
+      NodePtr cond_node = peer_out_anchor->GetOwnerNode();
+      auto iter = cond_switch_map.find(cond_node);
+      if (iter == cond_switch_map.end()) {
+        cond_switch_map[cond_node] = {node};
+      } else {
+        iter->second.emplace_back(node);
+      }
+      switch_nodes_.emplace_back(node);
+    }
+  }
+
+  MarkCycleDependence(cond_switch_map);
+  return SUCCESS;
+}
+
+///
+/// @brief Mark cyclic dependence
+/// @param [in] graph
+/// @param [in] cond_switch_map
+/// @return void
+///
+void SwitchToStreamSwitchPass::MarkCycleDependence(
+  const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map) {
+  std::stack<NodePtr> out_nodes;
+  NodePtr tmp_node = nullptr;
+  std::unordered_set<NodePtr> visited;
+  for (const auto &iter : cond_switch_map) {
+    std::set<NodePtr> switch_nodes(iter.second.begin(), iter.second.end());
+    for (const auto &switch_node : switch_nodes) {
+      GELOGD("MarkCycleDependence: cond_node=%s, switch=%s.", iter.first->GetName().c_str(),
+             switch_node->GetName().c_str());
+      for (const auto &node : switch_node->GetOutAllNodes()) {
+        out_nodes.push(node);
+      }
+    }
+    visited.clear();
+    while (!out_nodes.empty()) {
+      tmp_node = out_nodes.top();
+      out_nodes.pop();
+      if (visited.count(tmp_node) > 0) {
+        continue;
+      }
+      GELOGD("MarkCycleDependence: tmp_node=%s.", tmp_node->GetName().c_str());
+      for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) {
+        if (switch_nodes.find(out_node) == switch_nodes.end()) {
+          out_nodes.push(out_node);
+          continue;
+        }
+        GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence attr failed.");
+                        return );
+        auto map_iter = switch_cyclic_map_.find(out_node);
+        if (map_iter == switch_cyclic_map_.end()) {
+          switch_cyclic_map_[out_node] = {tmp_node->GetName()};
+        } else {
+          map_iter->second.insert(tmp_node->GetName());
+        }
+      }
+      visited.insert(tmp_node);
+    }
+  }
+
+  return;
+}
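MarkCycleDependence is a forward reachability walk: starting from each Switch's consumers, any path that arrives back at one of the Switch nodes hanging off the same cond marks that Switch as cyclically dependent. A standalone sketch of the reachability core on a toy int-indexed graph (illustrative only, not GE code):

#include <set>
#include <stack>
#include <vector>

// Returns the subset of `targets` reachable from `roots` in adjacency list `g`.
std::set<int> ReachableTargets(const std::vector<std::vector<int>> &g, const std::vector<int> &roots,
                               const std::set<int> &targets) {
  std::set<int> hit;
  std::set<int> visited;
  std::stack<int> work;
  for (int r : roots) work.push(r);
  while (!work.empty()) {
    int cur = work.top();
    work.pop();
    if (!visited.insert(cur).second) continue;  // already handled
    for (int next : g[cur]) {
      if (targets.count(next) > 0) {
        hit.insert(next);  // analogous to SetCyclicDependenceFlag on the Switch node
      } else {
        work.push(next);
      }
    }
  }
  return hit;
}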
"_t" : "_f", peer_cond_anchor); + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node failed."); + if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + if (MarkBranches(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "Mark branches for stream_switch %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + + if (!cyclic_flag) { + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), + stream_switch->GetInControlAnchor()), + "StreamSwitch node add ctl edge failed."); + } + }); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output failed."); + + NodePtr out_node = peer_in_anchor->GetOwnerNode(); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "StreamSwitch node add ctl edge failed."); + out_node_list.insert(out_node->GetName()); + } + + GE_IF_BOOL_EXEC(stream_switch != nullptr, { + MoveCtrlEdges(switch_node, stream_switch); + switch_node_map_[stream_switch] = out_node_list; + if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { + GELOGE(FAILED, "SetOriginalNodeName for node %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + }); + } + + (void)bypass_nodes_.insert(switch_node); + return SUCCESS; +} + +/// +/// @brief Bypass Switch Node +/// @param [in] switch_node +/// @param [out] peer_data_anchor +/// @param [out] peer_cond_anchor +/// @return Status +/// +Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, + OutDataAnchorPtr &peer_cond_anchor) { + for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { + InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + // Remove Switch data input. 
+
+///
+/// @brief Bypass Switch Node
+/// @param [in] switch_node
+/// @param [out] peer_data_anchor
+/// @param [out] peer_cond_anchor
+/// @return Status
+///
+Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
+                                                  OutDataAnchorPtr &peer_cond_anchor) {
+  for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) {
+    InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx);
+    GE_CHECK_NOTNULL(in_data_anchor);
+    OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
+    GE_CHECK_NOTNULL(peer_out_anchor);
+    // Remove Switch data input.
+    if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
+      GELOGE(FAILED, "Remove data edge %s->%s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
+             switch_node->GetName().c_str());
+      return FAILED;
+    }
+
+    if (idx == SWITCH_DATA_INPUT) {
+      peer_data_anchor = peer_out_anchor;
+    } else {
+      if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) {
+        GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str());
+        return FAILED;
+      }
+      peer_cond_anchor = peer_out_anchor;
+    }
+  }
+
+  return SUCCESS;
+}
+
+///
+/// @brief Find Switch cond input
+/// @param [in] pass_switch_flag
+/// @param [out] peer_cond_anchor
+/// @return Status
+///
+Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) {
+  NodePtr tmp_node = nullptr;
+  string type;
+  bool need_pass_type = true;
+  while (need_pass_type) {
+    if (tmp_node == nullptr) {
+      tmp_node = peer_cond_anchor->GetOwnerNode();
+    } else {
+      InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT);
+      GE_CHECK_NOTNULL(in_data_anchor);
+      peer_cond_anchor = in_data_anchor->GetPeerOutAnchor();
+      GE_CHECK_NOTNULL(peer_cond_anchor);
+      tmp_node = peer_cond_anchor->GetOwnerNode();
+    }
+
+    GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
+    need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH)));
+  }
+
+  return SUCCESS;
+}
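When pass_switch_flag is true, FindSwitchCondInput keeps hopping from a Switch's data input to that input's producer until the producer is no longer a (Ref)Switch, so cascaded Switches all resolve to the real predicate source. The same loop on a toy singly-linked producer chain (illustrative names only):

#include <string>

struct ToyNode {
  std::string type;        // e.g. "Switch" or "Less"
  ToyNode *data_producer;  // producer of input 0, null for sources
};

// Walk back through consecutive Switch nodes to the real cond producer.
ToyNode *FindCondSource(ToyNode *producer) {
  while (producer != nullptr && producer->type == "Switch" && producer->data_producer != nullptr) {
    producer = producer->data_producer;
  }
  return producer;
}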
+
+///
+/// @brief Create StreamSwitch Node
+/// @param [in] graph
+/// @param [in] switch_node
+/// @param [in] suffix
+/// @param [in] peer_cond_anchor
+/// @return ge::NodePtr
+///
+NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node,
+                                                         const std::string &suffix,
+                                                         const OutDataAnchorPtr &peer_cond_anchor) {
+  OpDescPtr switch_op_desc = switch_node->GetOpDesc();
+  GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid.");
+  GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, {
+    GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u.", switch_op_desc->GetInputsSize(),
+           SWITCH_INPUT_NUM);
+    return nullptr;
+  });
+
+  const std::string &node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix;
+  GELOGI("Create StreamSwitch, name=%s.", node_name.c_str());
+  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMSWITCH);
+  if (op_desc == nullptr) {
+    GELOGE(FAILED, "Create op_desc failed, StreamSwitch:%s.", node_name.c_str());
+    return nullptr;
+  }
+
+  // mark hccl group id
+  std::string hccl_group_id;
+  if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
+    (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id);
+    GELOGD("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(),
+           hccl_group_id.c_str());
+  }
+
+  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) ||
+      !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) {
+    GELOGE(INTERNAL_ERROR, "Set int attr failed.");
+    return nullptr;
+  }
+
+  // Inputs were already validated above: the Switch's first input (data) passes through; the second (pred) is checked.
+  GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT);
+  GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32);
+  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr,
+                   "Create StreamSwitch node: add input desc failed.");
+  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr,
+                   "Create StreamSwitch node: add input desc failed.");
+
+  NodePtr stream_switch = graph->AddNode(op_desc);
+  GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node failed.");
+  GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
+                "StreamSwitch node add cond edge failed.");
+
+  return stream_switch;
+}
+
+///
+/// @brief Mark Switch Branch
+/// @param [in] peer_cond_anchor
+/// @param [in] stream_switch
+/// @param [in] true_branch_flag
+/// @return Status
+///
+Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch,
+                                              bool true_branch_flag) {
+  uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT;
+  auto it = cond_node_map_.find(peer_cond_anchor);
+  if (it != cond_node_map_.end()) {
+    int64_t switch_group_id = GetGroupId(stream_switch);
+    auto switch_group_it = it->second.find(switch_group_id);
+    if (switch_group_it == it->second.end()) {
+      std::list<NodePtr> false_node_list;
+      std::list<NodePtr> true_node_list;
+      std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
+      node_list.emplace_back(stream_switch);
+      std::vector<std::list<NodePtr>> switch_list;
+      switch_list.emplace_back(false_node_list);
+      switch_list.emplace_back(true_node_list);
+      it->second[switch_group_id] = switch_list;
+    } else {
+      GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, {
+        GELOGE(INTERNAL_ERROR, "Check size failed, node: %s", stream_switch->GetName().c_str());
+        return FAILED;
+      });
+      switch_group_it->second[index].emplace_back(stream_switch);
+    }
+  } else {
+    int64_t switch_group_id = GetGroupId(stream_switch);
+    map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
+    std::list<NodePtr> false_node_list;
+    std::list<NodePtr> true_node_list;
+    std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
+    node_list.emplace_back(stream_switch);
+    std::vector<std::list<NodePtr>> switch_list;
+    switch_list.emplace_back(false_node_list);
+    switch_list.emplace_back(true_node_list);
+    switch_group_map[switch_group_id] = switch_list;
+    cond_node_map_[peer_cond_anchor] = switch_group_map;
+  }
+  return SUCCESS;
+}
+
+///
+/// @brief Get group_id for switch_node
+/// @param [in] node
+/// @return group_id
+///
+int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
+  string tailing_optimization_option;
+  bool is_tailing_optimization = false;
+  if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) {
+    // "1" means it's True from frontend option
+    is_tailing_optimization = (tailing_optimization_option == "1");
+    GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str());
+  }
+  if (!is_tailing_optimization) {
+    return 0;
+  }
+
+  string hccl_group_id;
+  if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
+    GELOGI("Node %s can not find hccl group id.", node->GetName().c_str());
+    return 0;
+  }
+  auto key_index = hccl_group_id.find_last_of('_');
+  auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index);
+  GELOGI("Node:%s, hccl_group_id=%s, key_num=%s", node->GetName().c_str(), hccl_group_id.c_str(), key_num.c_str());
+  int64_t num = atoi(key_num.c_str());
+  if (num == 0) {
+    return 0;
+  }
+
+  GELOGI("Hccl_group_id is %s, group_id is %ld", hccl_group_id.c_str(), num);
+  return num;
+}
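GetGroupId parses the substring after the last '_' with atoi, and atoi quietly returns 0 on non-numeric input, which the code folds into the "no group" case. A standalone illustration of exactly that parsing behavior (the group-id strings are examples, not GE data):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

int64_t ParseTrailingGroupId(const std::string &group_id) {
  auto pos = group_id.find_last_of('_');
  std::string tail = group_id.substr(pos + 1);  // also works when pos == npos, since npos + 1 wraps to 0
  return std::atoi(tail.c_str());               // non-numeric tail -> 0 -> treated as "no group"
}

int main() {
  std::cout << ParseTrailingGroupId("hccl_fused_group_3") << "\n";  // 3
  std::cout << ParseTrailingGroupId("hccl_fused_group_x") << "\n";  // 0
}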
+
+///
+/// @brief Combine switch nodes link to same cond
+/// @param [in] graph
+/// @return Status
+///
+Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
+  for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
+    for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
+      std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
+      std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
+      std::set<NodePtr> same_cond_switch;
+      same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
+      same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
+
+      OutDataAnchorPtr peer_cond_anchor = iter->first;
+      NodePtr cond_node = peer_cond_anchor->GetOwnerNode();
+      GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str());
+
+      NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor);
+      GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node failed.");
+
+      NodePtr active_node = CreateActiveNode(graph, cond_node);
+      GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed.");
+      GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()),
+                    "StreamActive add ctl edge failed.");
+      if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) {
+        GELOGE(FAILED, "Set active_label_list attr for node %s failed.", active_node->GetName().c_str());
+        return FAILED;
+      }
+
+      const std::string &cond_group = cond_node->GetName();
+      for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
+        bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
+        std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
+        GE_IF_BOOL_EXEC(switch_list.empty(), continue);
+
+        // select first stream_switch
+        NodePtr stream_switch = switch_list.front();
+        OpDescPtr switch_desc = stream_switch->GetOpDesc();
+        GE_CHECK_NOTNULL(switch_desc);
+        switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f")));
+        stream_switch_nodes_.emplace_back(stream_switch);
+
+        // 0_input: original pred input, 1_input: constant node
+        GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node failed.");
+        GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
+                      "StreamSwitch remove data edge failed.");
+        GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
+                      "Cast add data edge failed.");
+
+        for (const NodePtr &node : switch_list) {
+          GE_IF_BOOL_EXEC(node != stream_switch, {
+            GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),
+                          "StreamSwitch remove data edge failed.");
+          });
+          GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges failed.");
+          GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges failed.");
+        }
+
+        GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()),
+                      "StreamActive add ctl edge failed.");
+      }
+    }
+  }
+  return SUCCESS;
+}
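The bookkeeping behind MarkBranches/CombineSwitchNode is a nested container: cond output anchor -> hccl group id -> a two-slot vector of lists (slot 0 = false-branch StreamSwitches, slot 1 = true-branch). A sketch of that layout with the template arguments written out; the typing mirrors the reconstruction above but uses toy string stand-ins for GE's smart-pointer types:

#include <cstdint>
#include <list>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

using ToyAnchor = std::string;  // stands in for OutDataAnchorPtr
using ToyNode = std::string;    // stands in for NodePtr

// anchor -> group_id -> {list of false-branch switches, list of true-branch switches}
std::unordered_map<ToyAnchor, std::map<int64_t, std::vector<std::list<ToyNode>>>> cond_node_map;

void Mark(const ToyAnchor &cond, int64_t group, bool true_branch, const ToyNode &sw) {
  auto &slots = cond_node_map[cond][group];
  if (slots.empty()) slots.resize(2);  // slot 0: SWITCH_FALSE_OUTPUT, slot 1: SWITCH_TRUE_OUTPUT
  slots[true_branch ? 1 : 0].push_back(sw);
}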
AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE failed, node: %s.", + cast_name.c_str()); + return nullptr; + } + + GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()); + tensor_desc.SetDataType(DT_BOOL); + GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add input desc failed."); + tensor_desc.SetDataType(DT_INT32); + GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, + "Cast_node add output desc failed."); + + NodePtr cast_node = graph->AddNode(cast_desc); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); + + return cast_node; +} + +/// +/// @brief Add const node as switch input 1 +/// @param [in] graph +/// @param [in] stream_switch +/// @return Status +/// +Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch) { + OpDescPtr op_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool value = false; + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch get attr ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG failed."); + + const std::string &const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f"); + GELOGI("Create const op: %s", const_node_name.c_str()); + OpDescPtr const_op_desc = MakeShared(const_node_name, CONSTANT); + if (const_op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, Constant:%s.", const_node_name.c_str()); + return FAILED; + } + + auto resize_value = (int32_t)value; + GeTensorDesc data_desc = op_desc->GetInputDesc(1); + GeTensorPtr const_value = + MakeShared(data_desc, reinterpret_cast(&resize_value), sizeof(int32_t)); + if (const_value == nullptr) { + GELOGE(FAILED, "Create tensor failed."); + return FAILED; + } + GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED); + GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED, + "Create Const op: add output desc failed."); + + NodePtr const_node = graph->AddNode(const_op_desc); + GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)), + "StreamSwitch node add data edge failed."); + + return SUCCESS; +} + +/// +/// @brief Modify in ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] cast_node +/// @param [in] same_cond_switch +/// @return Status +/// +Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, + const std::set &same_cond_switch) { + GELOGI("ModifySwitchInCtlEdges: switch_node=%s, cast_node=%s", switch_node->GetName().c_str(), + cast_node->GetName().c_str()); + std::string orig_switch_name = switch_node->GetName(); + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s", switch_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + + for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { + 
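// Re-route every control edge that entered the original Switch so that it now enters the shared Cast node instead. +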
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + "Remove ctl edge failed."); + GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + "Add ctl edge failed."); + }); + + GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); + if (same_cond_switch.count(in_ctl_node) > 0) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + "Remove ctl edge failed."); + continue; + } + + auto find_res1 = switch_node_map_.find(in_ctl_node); + GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + auto find_res2 = find_res1->second.find(orig_switch_name); + auto find_res3 = find_res1->second.find(cast_node->GetName()); + GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), { + find_res1->second.erase(find_res2); + find_res1->second.insert(cast_node->GetName()); + continue; + }); + } + + return SUCCESS; +} + +/// +/// @brief Modify out ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] stream_switch +/// @param [in] active_node +/// @return Status +/// +Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, + const NodePtr &active_node) { + GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), + stream_switch->GetName().c_str(), active_node->GetName().c_str()); + auto find_res = switch_node_map_.find(switch_node); + GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + GE_IF_BOOL_EXEC(find_res->second.empty(), { + GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + + for (const NodePtr &node : switch_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Remove ctl edge failed."); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string orig_name = op_desc->GetName(); + GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), { + if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + }); + if (find_res->second.find(orig_name) == find_res->second.end()) { + auto active_out_ctrl_anchor = active_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(active_out_ctrl_anchor); + GE_IF_BOOL_EXEC(!active_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(active_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); + }); + } else { + auto switch_out_ctrl_anchor = stream_switch->GetOutControlAnchor(); + GE_CHECK_NOTNULL(switch_out_ctrl_anchor); + GE_IF_BOOL_EXEC(!switch_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(switch_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); + 
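// The IsLinkedWith guards above keep the re-routed control edges free of duplicates. +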
}); + } + } + + GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node)); + return SUCCESS; +} + +/// +/// @brief Check duplicate node_name +/// @param [in] node_name +/// @return std::string +/// +std::string SwitchToStreamSwitchPass::CheckDuplicateName(const std::string &node_name) { + std::string tmp_name = node_name; + auto iter = node_num_map_.find(tmp_name); + if (iter != node_num_map_.end()) { + tmp_name = tmp_name + "_" + std::to_string(iter->second); + (iter->second)++; + } else { + node_num_map_[tmp_name] = 1; + } + return tmp_name; +} + +/// +/// @brief Move Control Edges +/// @param [in] old_node +/// @param [in] new_node +/// @return void +/// +void SwitchToStreamSwitchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) { + GE_IF_BOOL_EXEC(old_node == new_node, return ); + auto iter = switch_cyclic_map_.find(old_node); + bool check_flag = (iter != switch_cyclic_map_.end()); + for (const NodePtr &in_node : old_node->GetInControlNodes()) { + auto out_ctrl_anchor = in_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_ctrl_anchor); + if (check_flag && (iter->second.count(in_node->GetName()) > 0)) { + for (const auto &out_node : old_node->GetOutAllNodes()) { + GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(out_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, out_node->GetInControlAnchor()), + "Add in ctrl edge failed."); + }); + } + } else { + GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, new_node->GetInControlAnchor()), "Add in ctrl edge failed."); + }); + } + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_ctrl_anchor, old_node->GetInControlAnchor()), + "Remove in ctrl edge failed."); + } + + for (const NodePtr &out_node : old_node->GetOutControlNodes()) { + GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "Add out ctrl edge failed."); + }); + GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "Remove out ctrl edge failed."); + } +} +} // namespace ge diff --git a/src/ge/graph/passes/switch_op_pass.h b/src/ge/graph/passes/switch_to_stream_switch_pass.h similarity index 61% rename from src/ge/graph/passes/switch_op_pass.h rename to src/ge/graph/passes/switch_to_stream_switch_pass.h index 202b919c..15fe9dce 100644 --- a/src/ge/graph/passes/switch_op_pass.h +++ b/src/ge/graph/passes/switch_to_stream_switch_pass.h @@ -14,15 +14,9 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ - -#include -#include -#include -#include -#include -#include +#ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ +#define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ + #include "inc/graph_pass.h" namespace ge { @@ -91,78 +85,158 @@ namespace ge { +-----------+ +-----------+ +-----------+ +-----| Less |----+ +-----------+ */ -class SwitchOpPass : public GraphPass { +class SwitchToStreamSwitchPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return + /// Status ClearStatus() override; private: - Status ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node); - - Status ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node); - - NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, - OutDataAnchorPtr &peer_cond_anchor); - - NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); - - Status CombineSwitchNode(ComputeGraphPtr &graph); - - NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); - - Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); - - Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); + /// + /// @brief Check cyclic dependence + /// @param [in] graph + /// @return Status + /// + Status CheckCycleDependence(const ComputeGraphPtr &graph); + + /// + /// @brief Mark cyclic dependence + /// @param [in] graph + /// @param [in] cond_switch_map + /// @return void + /// + void MarkCycleDependence(const std::unordered_map> &cond_switch_map); + /// + /// @brief Replace Switch Op + /// @param [in] graph + /// @param [in] switch_node + /// @return Status + /// + Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node); + + /// + /// @brief Bypass Switch Node + /// @param [in] switch_node + /// @param [out] peer_data_anchor + /// @param [out] peer_cond_anchor + /// @return Status + /// + Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, + OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Find Switch cond input + /// @param [in] pass_switch_flag + /// @param [out] peer_cond_anchor + /// @return Status + /// Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); - Status MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch_node, bool true_branch_flag); - - NodePtr CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor); - - Status AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch_node); - - Status UpdateCondBranch(NodePtr &node); - - Status UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, - bool &net_output_flag); - - Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); - - Status UpdateEnterNode(); + /// + /// @brief Create StreamSwitch Node + /// @param [in] graph + /// @param [in] switch_node + /// @param [in] suffix + /// @param [in] peer_cond_anchor + /// @return ge::NodePtr + /// + NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, + const OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Mark Switch Branch + /// @param [in] peer_cond_anchor + /// 
@param [in] stream_switch + /// @param [in] true_branch_flag + /// @return Status + /// + Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node, + bool true_branch_flag); + + /// + /// @brief Get group_id for switch_node + /// @param [in] node + /// @return group_id + /// + int64_t GetGroupId(const NodePtr &node); + /// + /// @brief Combine switch nodes link to same cond + /// @param [in] graph + /// @return Status + /// + Status CombineSwitchNode(const ComputeGraphPtr &graph); + + /// + /// @brief Create cast node + /// @param [in] graph + /// @param [in] peer_cond_anchor + /// @return NodePtr + /// + NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Create Active Op + /// @param [in] graph + /// @param [in] cond_node + /// @return ge::NodePtr + /// + NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// @brief Add const node as switch input1 + /// @param [in] graph + /// @param [in] stream_switch + /// @return Status + /// + Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node); + + /// + /// @brief Modify in ctl edge for switch_node + /// @param [in] switch_node + /// @param [in] cast_node + /// @param [in] same_cond_switch + /// @return Status + /// + Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, + const std::set &same_cond_switch); + + /// + /// @brief Modify out ctl edge for switch_node + /// @param [in] switch_node + /// @param [in] stream_switch + /// @param [in] active_node + /// @return Status + /// + Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node); + + /// + /// @brief Check duplicate node_name + /// @param [in] node_name + /// @return std::string + /// std::string CheckDuplicateName(const std::string &node_name); - Status CheckCycleDependence(ComputeGraphPtr &graph); - - void MarkCycleDependence(const std::unordered_map> &cond_switch_map); - - Status ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node, const std::set &same_cond_switch); - - Status ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, NodePtr &active_node); - - void CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag = false); - - void RemoveControlEdges(NodePtr &node); - - void ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node); - - int64_t GetGroupId(const NodePtr &node); - - void MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch); + /// + /// @brief Move Control Edges + /// @param [in] old_node + /// @param [in] new_node + /// @return void + /// + void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node); std::vector switch_nodes_; - std::vector merge_nodes_; - std::vector enter_nodes_; std::unordered_map> switch_cyclic_map_; - std::set bypass_nodes_; - std::unordered_map branch_head_nodes_; std::vector stream_switch_nodes_; - std::vector need_label_nodes_; std::unordered_map>>> cond_node_map_; std::unordered_map> switch_node_map_; std::unordered_map node_num_map_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ +#endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc b/src/ge/graph/passes/transop_breadth_fusion_pass.cc index 53f9e825..d8df4a22 100644 --- a/src/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_breadth_fusion_pass.cc @@ 
-19,14 +19,12 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "common/types.h" #include "graph/common/transop_util.h" #include "graph/utils/node_utils.h" namespace ge { Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpBreadthFusionPass); if (graph == nullptr) { return SUCCESS; } @@ -47,7 +45,6 @@ Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { } } } - GE_TIMESTAMP_END(TransOpBreadthFusionPass, "GraphManager::TransOpBreadthFusionPass"); return SUCCESS; } diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index c0c854b6..afeca3c4 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -17,7 +17,6 @@ #include "graph/passes/transop_depth_fusion_pass.h" #include -#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" #include "graph/compute_graph.h" @@ -29,7 +28,6 @@ namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpDepthFusionPass); GELOGI("[TransOpDepthFusionPass]: optimize in depth begin..."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -53,7 +51,6 @@ graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpDepthFusionPass]: Optimize in depth success..."); - GE_TIMESTAMP_END(TransOpDepthFusionPass, "GraphManager::TransOpDepthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc index 38b6684b..2ff7cd82 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -24,7 +24,6 @@ namespace { const int kTransOpOutIndex = 0; static std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; - } // namespace namespace ge { Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index ba4cd031..1d97d9a1 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -22,7 +22,6 @@ #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" -#include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" @@ -733,7 +732,6 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g } graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpWithoutReshapeFusionPass); GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -786,7 +784,6 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end."); - GE_TIMESTAMP_END(TransOpWithoutReshapeFusionPass, "GraphManager::TransOpWithoutReshapeFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index 175a049a..8c34cd36 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -20,7 +20,6 @@ #include "common/formats/formats.h" #include 
"common/formats/utils/formats_trans_utils.h" -#include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "graph/graph.h" #include "graph/manager/graph_var_manager.h" @@ -115,7 +114,6 @@ bool IsTransSupport(const TransNodeInfo &trans_info) { } // namespace Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(VariableOpPass); if (graph == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph"); return INTERNAL_ERROR; @@ -190,9 +188,15 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) { return GE_GRAPH_VARIABLE_OP_PASS_FAILED; } + + // renew var desc if the trans_road is all reshape or reformat + ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road); + if (ret != SUCCESS) { + GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); + return FAILED; + } } - GE_TIMESTAMP_END(VariableOpPass, "GraphManager::VariableOpPass"); return SUCCESS; } @@ -604,4 +608,28 @@ Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) { } return SUCCESS; } + +Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) { + // renew var desc if the trans_road is all reshape or reformat + for (auto &road : fusion_road) { + if (road.node_type != RESHAPE && road.node_type != REFORMAT) { + return SUCCESS; + } + } + + if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) { + GELOGD("var manager does not exist var node[%s]", node->GetName().c_str()); + return SUCCESS; + } + GELOGD("var manager exist var node[%s]", node->GetName().c_str()); + GE_CHECK_NOTNULL(node->GetOpDesc()); + Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc()); + if (ret != SUCCESS) { + GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + } // namespace ge diff --git a/src/ge/graph/passes/variable_op_pass.h b/src/ge/graph/passes/variable_op_pass.h index 4e194a0c..e17980e9 100644 --- a/src/ge/graph/passes/variable_op_pass.h +++ b/src/ge/graph/passes/variable_op_pass.h @@ -66,6 +66,7 @@ class VariableOpPass : public GraphPass { Status UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set &nodes); Status RenewVarDesc(ge::ComputeGraphPtr &graph); + Status RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road); std::map> var_and_var_ref_map_; diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index 4db78a46..d93e1003 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -30,6 +30,7 @@ namespace ge { std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ {REFSWITCH, {{0, 0}, {0, 1}}}}; + Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { @@ -62,7 +63,6 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GELOGI("{ %d : %d }", index_iter->first, index_iter->second); } } - return SUCCESS; } @@ -73,10 +73,13 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { GE_CHECK_NOTNULL(dst_node); InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; GE_CHECK_NOTNULL(dst_in_data_anchor); - int out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); + auto input_index = 
dst_in_data_anchor->GetIdx(); + int out_index = GetWritableNodeOutIndex(dst_node, input_index); if (out_index >= 0) { - Status ret = DealWritableNode(dst_node, var_node, out_index); + Status ret = DealWritableNode(dst_node, input_index, var_node); if (ret != SUCCESS) { + GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", dst_node->GetName().c_str(), + input_index, var_node->GetName().c_str()); return FAILED; } } @@ -84,84 +87,97 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { return SUCCESS; } -Status VariablePrepareOpPass::DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index) { - GE_CHECK_NOTNULL(writable_node); - GE_CHECK_NOTNULL(var_node); - NodePtr final_writable_node = writable_node; - bool is_have_peer_node = false; - for (auto &dst_node_and_inanchor : writable_node->GetOutDataNodesAndAnchors()) { - NodePtr dst_node = dst_node_and_inanchor.first; - GE_CHECK_NOTNULL(dst_node); - InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; - GE_CHECK_NOTNULL(dst_in_data_anchor); - is_have_peer_node = true; - int current_out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); - if (current_out_index >= 0) { - final_writable_node = GetFinalWritableNode(dst_node, current_out_index); - out_index = current_out_index; - } - - GE_CHECK_NOTNULL(final_writable_node); - Status ret = AddVariableRef(final_writable_node, var_node, out_index); - if (ret != SUCCESS) { - GELOGE(FAILED, "add variable ref failed"); - return FAILED; +Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, + const ge::NodePtr &var_node) { + // Find the last ref node: + // If the ref input has corresponding output, add variable ref after it. + // If the ref input has no corresponding output, insert RefIdentity and variable ref before it. + // If ref node with control output was found while finding the last ref node, add variable ref after it. 
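+ // Iterative depth-first traversal: each stack entry is a (node, ref-input-index) pair that still has to be resolved.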
+ std::stack> nodes_to_check; + nodes_to_check.push({writable_node, input_index}); + while (!nodes_to_check.empty()) { + auto node_index = nodes_to_check.top(); + nodes_to_check.pop(); + auto cur_node = node_index.first; + int cur_input_index = node_index.second; + // Collect ref node after cur node + const auto nodes_size = nodes_to_check.size(); + // Add peer ref output node of current node to stack + CHECK_FALSE_EXEC(GetPeerNodeOfRefInput(cur_node, cur_input_index, nodes_to_check) == SUCCESS, + GELOGE(FAILED, "GetPeerNodeOfRefInput for node[%s] failed.", cur_node->GetName().c_str()); + return FAILED); + auto output_index = GetWritableNodeOutIndex(cur_node, cur_input_index); + CHECK_FALSE_EXEC(output_index >= 0, + GELOGE(FAILED, "Get writable node[%s] ref input[%d]'s corresponding out index failed: %d.", + cur_node->GetName().c_str(), cur_input_index, output_index); + return FAILED); + if (nodes_size == nodes_to_check.size()) { + const auto &op_desc = cur_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // No need to add variable_ref for frameworkop + if (op_desc->GetType() == FRAMEWORKOP) { + GELOGD("No need to add variable_ref for frameworkop"); + continue; + } + if (static_cast(output_index) < op_desc->GetOutputsSize()) { + // Add variable ref node after ref output for final ref node + CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + GELOGE(FAILED, "Add variable ref failed"); + return FAILED); + } else { + // Insert variable ref node before ref input without corresponding ref output + CHECK_FALSE_EXEC(InsertVariableRef(cur_node, cur_input_index, var_node) == SUCCESS, + GELOGE(FAILED, "Insert variable ref and ref identity failed"); + return FAILED); + } + continue; } - } - if (final_writable_node->GetName() == writable_node->GetName() && !is_have_peer_node) { - Status ret = AddVariableRef(final_writable_node, var_node, out_index); - if (ret != SUCCESS) { - return FAILED; + if (HasControlOut(cur_node)) { + // Add variable ref node after ref output for ref node has control output. 
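+ // This ref node is not the last one in the chain, but its control successors require that its result be published through a variable ref as well.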
+ CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + GELOGE(FAILED, "Add variable ref failed"); + return FAILED); } } return SUCCESS; } -NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index) { - NodePtr current_node = writable_node; - std::unordered_set seen_node; - while (true) { - if (seen_node.count(current_node.get())) { - GELOGE(FAILED, "There is a ring structure in the graph"); - return nullptr; - } - seen_node.insert(current_node.get()); - OutDataAnchorPtr out_anchor = current_node->GetOutDataAnchor(out_index); - if (out_anchor == nullptr) { - GELOGE(FAILED, "Failed to get data anchor by index %d", out_index); - return nullptr; - } - bool found_writeable_node = false; - auto peer_in_anchors = out_anchor->GetPeerInDataAnchors(); - for (auto &peer_in_anchor : peer_in_anchors) { - if (peer_in_anchor == nullptr) { - GELOGE(FAILED, "peer in data anchor is nullptr, node %s:%s", current_node->GetType().c_str(), - current_node->GetName().c_str()); - continue; - } - - NodePtr peer_node = peer_in_anchor->GetOwnerNode(); - int current_out_index = GetWritableNodeOutIndex(peer_node, peer_in_anchor->GetIdx()); - if (current_out_index >= 0) { - current_node = peer_node; - out_index = current_out_index; - found_writeable_node = true; - break; - } +Status VariablePrepareOpPass::GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, + std::stack> &nodes) { + auto output_index = GetWritableNodeOutIndex(node, input_index); + if (output_index == -1) { + GELOGE(PARAM_INVALID, "Node[%s] is not a ref node.", node->GetName().c_str()); + return PARAM_INVALID; + } + const auto &op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (static_cast(output_index) == op_desc->GetOutputsSize()) { + return SUCCESS; + } + if (output_index >= static_cast(node->GetAllOutDataAnchorsSize())) { + GELOGW("Can not get %d th output anchor of %s", output_index, node->GetName().c_str()); + return SUCCESS; + } + const auto &out_anchor = node->GetOutDataAnchor(output_index); + GE_CHECK_NOTNULL(out_anchor); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + auto peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; } - if (!found_writeable_node) { - GELOGD("final writable node is %s", current_node->GetName().c_str()); - return current_node; + const int peer_in_index = peer_in_anchor->GetIdx(); + if (GetWritableNodeOutIndex(peer_node, peer_in_index) != -1) { + nodes.push({peer_node, peer_in_index}); } } + return SUCCESS; } -Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node, int index) { +Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, const ge::NodePtr &var_node, int index) { GE_CHECK_NOTNULL(final_writable_node); GE_CHECK_NOTNULL(var_node); - - if (final_writable_node->GetType() == FRAMEWORKOP) { - GELOGD("No need to add variable_ref for frameworkop"); + if (index >= static_cast(final_writable_node->GetAllOutDataAnchorsSize())) { + GELOGW("Can not get %d th output anchor of %s", index, final_writable_node->GetName().c_str()); return SUCCESS; } // Check for duplicate creation @@ -181,7 +197,8 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g // creat variable_ref std::stringstream variable_ref_name; variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; - NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), 
var_node); + NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + GE_CHECK_NOTNULL(variable_ref_node); Status ret_check = CheckStreamLabel(variable_ref_node, final_writable_node); if (ret_check != SUCCESS) { GELOGE(FAILED, "check stream label failed"); @@ -189,23 +206,12 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g } GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); - GE_CHECK_NOTNULL(variable_ref_node); - // add control anchor between variable_ref and final peer node + // add control anchor between variable_ref and final peer node // variable_ref_node need to execute before other nodes - auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); - for (auto &final_writable_outAnchor : final_writable_outAnchors) { - GE_CHECK_NOTNULL(final_writable_outAnchor); - for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { - GE_CHECK_NOTNULL(final_writable_peerAnchor); - NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); - graphStatus ret = - ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); - return FAILED; - } - } - } + CHECK_FALSE_EXEC(AddControlEdge(final_writable_node, variable_ref_node) == SUCCESS, + GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); + return FAILED); + graphStatus ret = ge::GraphUtils::AddEdge(out_anchor, variable_ref_node->GetInDataAnchor(0)); if (ret != GRAPH_SUCCESS) { GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed"); @@ -214,7 +220,110 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g return SUCCESS; } -ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { +Status VariablePrepareOpPass::InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(var_node); + // Check connection between two nodes + const auto in_anchor = node->GetInDataAnchor(in_index); + GE_CHECK_NOTNULL(in_anchor); + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + auto peer_in_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_node); + + // Create ref_identity + std::stringstream ref_identity_name; + ref_identity_name << "RefIdentity_" << peer_in_node->GetName() << "_" << peer_out_anchor->GetIdx() << "_TO_" + << node->GetName() << "_" << in_index; + NodePtr ref_identity_node = CreateRefIdentity(ref_identity_name.str(), node, static_cast(in_index)); + GE_CHECK_NOTNULL(ref_identity_node); + + // Create variable_ref + std::stringstream variable_ref_name; + variable_ref_name << "_TO_" << node->GetName() << "_REF_" << in_index; + NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + GE_CHECK_NOTNULL(variable_ref_node); + Status ret_check = CheckStreamLabel(variable_ref_node, node); + if (ret_check != SUCCESS) { + GELOGE(FAILED, "check stream label failed"); + return FAILED; + } + + GELOGI("Insert variable_ref of [%s] between [%s] and [%s]", var_node->GetName().c_str(), + peer_in_node->GetName().c_str(), node->GetName().c_str()); + // add control anchor between 
variable_ref and node + // variable_ref_node need to execute before other nodes + CHECK_FALSE_EXEC(AddControlEdge(node, variable_ref_node) == SUCCESS, + GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); + return FAILED); + + // Insert variable ref node between two nodes and remove the original edge. + CHECK_FALSE_EXEC(ge::GraphUtils::RemoveEdge(peer_out_anchor, in_anchor) == SUCCESS, + GELOGE(FAILED, "Remove edge between ref node and its peer node failed"); + return FAILED); + CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(peer_out_anchor, ref_identity_node->GetInDataAnchor(0)) == SUCCESS, + GELOGE(FAILED, "Add data edge between pre node and ref_identity failed"); + return FAILED); + CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), in_anchor) == SUCCESS, + GELOGE(FAILED, "Add data edge between ref_identity and ref node failed"); + return FAILED); + + // Add edge from ref identity node to variable ref node. + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), variable_ref_node->GetInDataAnchor(0)) == SUCCESS, + GELOGE(FAILED, "Add data edge between ref_identity and variable_ref failed"); + return FAILED); + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), variable_ref_node->GetInControlAnchor()) == SUCCESS, + GELOGE(FAILED, "Add control edge between ref_identity and variable_ref failed"); + return FAILED); + return SUCCESS; +} + +Status VariablePrepareOpPass::AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node) { + auto out_anchors = node->GetAllOutAnchors(); + for (auto &out_anchor : out_anchors) { + GE_CHECK_NOTNULL(out_anchor); + for (auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + NodePtr peer_node = peer_in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()) == SUCCESS, + GELOGE(FAILED, "Add control edge between variable_ref and ref node's peer node failed"); + return FAILED); + } + } + return SUCCESS; +} + +ge::NodePtr VariablePrepareOpPass::CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, + uint32_t input_index) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "opdesc is nullptr"); + return nullptr; + } + + OpDescPtr ref_identity_op_desc = MakeShared(ref_identity_name.c_str(), REFIDENTITY); + if (ref_identity_op_desc == nullptr) { + GELOGE(FAILED, "ref_identity op desc is nullptr"); + return nullptr; + } + + GE_IF_BOOL_EXEC(ref_identity_op_desc->AddOutputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, + GELOGW("add output desc edge failed"); + return nullptr); + GE_IF_BOOL_EXEC(ref_identity_op_desc->AddInputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, + GELOGW("add input desc edge failed"); + return nullptr); + NodePtr ref_identity_node = node->GetOwnerComputeGraph()->AddNode(ref_identity_op_desc); + GE_IF_BOOL_EXEC(ref_identity_node == nullptr, GELOGW("ref_identity_node is null"); return nullptr); + return ref_identity_node; +} + +ge::NodePtr VariablePrepareOpPass::CreateVariableRef(const std::string &variable_ref_name, + const ge::NodePtr &var_node) { OpDescPtr var_op_desc = var_node->GetOpDesc(); if (var_op_desc == nullptr) { GELOGE(FAILED, "get var opdesc is nullptr"); @@ -250,7 +359,6 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu } 
GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index); auto node_type = node->GetType(); - if (node_type == FRAMEWORKOP) { std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); @@ -266,25 +374,17 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node GELOGW("op_desc in null, please check node:[%s]", node->GetName().c_str()); return; } - for (const auto &out_ancohor : node->GetAllOutDataAnchors()) { - int output_index = out_ancohor->GetIdx(); - string output_name = op_desc->GetOutputNameByIndex(output_index); - GELOGD("output name:[%s]", output_name.c_str()); - - int input_index = op_desc->GetInputIndexByName(output_name); - if (input_index == -1) { + for (const auto &name_index : op_desc->GetAllInputName()) { + // Record the index of output with the same name as input, thinking of them as a pair of ref input and output. + const int out_index = op_desc->GetOutputIndexByName(name_index.first); + if (out_index != -1) { + ref_input_output_map_[node->GetType()][name_index.second] = out_index; continue; } - auto ref_type_and_input_output_iter = ref_input_output_map_.find(node->GetType()); - if (ref_type_and_input_output_iter != ref_input_output_map_.end()) { - auto &input_output_index_map = ref_type_and_input_output_iter->second; - if (input_output_index_map.find(input_index) == input_output_index_map.end()) { - input_output_index_map.emplace(input_index, output_index); - GELOGD("Add RefInputOutputMap %s:{ %d, %d }", node->GetType().c_str(), input_index, output_index); - } - } else { - ref_input_output_map_.insert({node->GetType(), {{input_index, output_index}}}); - GELOGD("Create RefInputOutputMap { %s:{ %d, %d } }", node->GetType().c_str(), input_index, output_index); + // Record the ref input without corresponding output. 
+ const auto &input_desc = op_desc->GetInputDesc(name_index.second); + if (!input_desc.GetRefPortIndex().empty()) { + ref_input_output_map_[node->GetType()][name_index.second] = static_cast(op_desc->GetOutputsSize()); } } } @@ -317,4 +417,15 @@ Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node, } return SUCCESS; } + +bool VariablePrepareOpPass::HasControlOut(const ge::NodePtr &node) { + const auto &out_control_anchor = node->GetOutControlAnchor(); + for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { + if (peer_in_control_anchor == nullptr || peer_in_control_anchor->GetOwnerNode() == nullptr) { + continue; + } + return true; + } + return false; +} } // namespace ge diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index c8b9883e..f024a464 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -18,6 +18,7 @@ #define GE_GRAPH_PASSES_VARIABLE_PREPARE_OP_PASS_H_ #include +#include #include #include "framework/common/ge_inner_error_codes.h" @@ -30,15 +31,19 @@ class VariablePrepareOpPass : public GraphPass { private: Status DealVariableNode(ge::NodePtr &node); - Status DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index); - NodePtr GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index); - Status AddVariableRef(ge::NodePtr &node, ge::NodePtr &var_node, int index); - NodePtr CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node); + Status DealWritableNode(const ge::NodePtr &writable_node, int input_index, const ge::NodePtr &var_node); + Status GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, std::stack> &nodes); + Status AddVariableRef(ge::NodePtr &node, const ge::NodePtr &var_node, int index); + Status InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node); + Status AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node); + NodePtr CreateVariableRef(const std::string &variable_ref_name, const ge::NodePtr &var_node); + NodePtr CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, uint32_t input_index); int GetWritableNodeOutIndex(const NodePtr &node, int input_index); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); int FindRefOutIndex(const std::string &node_type, int input_index, const std::map> &ref_map); Status CheckStreamLabel(const ge::NodePtr &var_ref_node, const ge::NodePtr &final_writable_node); + bool HasControlOut(const ge::NodePtr &node); std::map> ref_input_output_map_; static std::map> ref_node_without_prototype_map_; diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index cd5b9fe9..32236814 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -16,18 +16,10 @@ #include "graph/passes/variable_ref_delete_op_pass.h" #include -#include "framework/common/debug/ge_log.h" namespace ge { Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(VariableRefDeleteOpPass); GE_CHECK_NOTNULL(graph); - - for (auto &node : graph->GetDirectNode()) { - GELOGD("before VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), - node->GetName().c_str()); - } - for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string 
ref_var_src_var_name; @@ -42,13 +34,6 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { return FAILED; } } - - for (auto &node : graph->GetDirectNode()) { - GELOGD("after VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), - node->GetName().c_str()); - } - GE_TIMESTAMP_END(VariableRefDeleteOpPass, "GraphManager::VariableRefDeleteOpPass"); - return SUCCESS; } @@ -68,21 +53,21 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: // get previous node of variable_ref NodePtr peer_node = inAnchor0->GetPeerOutAnchor()->GetOwnerNode(); - // add attr [REF_VAR_SRC_VAR_NAME] to the previous node of the variable_ref - GE_CHECK_NOTNULL(peer_node->GetOpDesc()); - bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - + // add attr [REF_VAR_SRC_VAR_NAME] to the previous op output desc of the variable_ref + auto op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto out_desc = op_desc->GetOutputDesc(static_cast(index)); + bool is_set_str = ge::AttrUtils::SetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + (void)op_desc->UpdateOutputDesc(static_cast(index), out_desc); ge::NodePtr ref_var_src_var = GraphUtils::FindNodeFromAllNodes(graph, ref_var_src_var_name); if (ref_var_src_var == nullptr) { - GELOGE(FAILED, "get ref_var_src_var failed"); + GELOGE(FAILED, "Can not find source variable[%s] of variable ref[%s]", ref_var_src_var_name.c_str(), + variable_ref->GetName().c_str()); return FAILED; } - - GE_CHECK_NOTNULL(ref_var_src_var->GetOpDesc()); - bool is_set_index = ge::AttrUtils::SetInt(ref_var_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, index); - if (is_set_str && is_set_index) { - GELOGI("[%s]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), ref_var_src_var_name.c_str()); - GELOGI("[%s]: add attr [REF_VAR_PRE_PEER_OUT_INDEX: %d]", ref_var_src_var->GetName().c_str(), index); + if (is_set_str) { + GELOGI("[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), index, + ref_var_src_var_name.c_str()); } // remove variable_ref diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc index bd153184..1321cf20 100644 --- a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc +++ b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc @@ -17,7 +17,6 @@ #include "variable_ref_useless_control_out_delete_pass.h" namespace ge { - Status VariableRefUselessControlOutDeletePass::Run(ge::ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 9c82a06d..94818698 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -19,9 +19,12 @@ #include #include #include +#include "common/formats/format_transfers/format_transfer_fractal_nz.h" +#include "common/formats/format_transfers/format_transfer_fractal_z.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" +#include "common/formats/utils/formats_trans_utils.h" #include "common/helper/model_helper.h" #include "common/math/math_util.h" #include "common/op/ge_op_utils.h" @@ 
-80,7 +83,9 @@ #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_fusion_pass.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/switch_op_pass.h" +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "graph/passes/switch_to_stream_switch_pass.h" +#include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/switch_split_pass.h" #include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_op_remove_pass.h" @@ -96,7 +101,6 @@ #include "runtime/dev.h" #include "graph/passes/dimension_adjust_pass.h" -#include "graph/passes/identify_reference_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/permute_pass.h" #include "graph/passes/reshape_remove_pass.h" @@ -134,14 +138,14 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { auto dim_cnt = static_cast(dst_ge_shape.GetDimNum()); if (dim_cnt == 0) { // if the dim_cnt is 0, the tensor is a scalar tensor->MutableTensorDesc().SetShape(GeShape()); - int64_t dst_shape = 1; - if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int64_t)) != GRAPH_SUCCESS) { + int32_t dst_shape = 1; + if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int32_t)) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr; } } else { tensor->MutableTensorDesc().SetShape(GeShape(std::vector({dim_cnt}))); - unique_ptr dst_shape(new (std::nothrow) int64_t[dim_cnt]()); + unique_ptr dst_shape(new (std::nothrow) int32_t[dim_cnt]()); if (dst_shape == nullptr) { GELOGE(INTERNAL_ERROR, "Create unique ptr failed"); return nullptr; @@ -151,7 +155,7 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { } GE_IF_BOOL_EXEC( - tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int64_t)) != GRAPH_SUCCESS, + tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int32_t)) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr;) } @@ -648,7 +652,39 @@ Status ModifyFormatAndShapeForSingleTensor(const GeTensorDescPtr &input_output) input_output->SetShape(ge::GeShape(dst_shape_dims)); return SUCCESS; } +Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, Format storage_format, + vector &dst_shape_dims) { + GE_CHECK_NOTNULL(op_desc); + const GeTensorDescPtr &input = op_desc->MutableInputDesc(index); + GE_CHECK_NOTNULL(input); + ge::Format old_format = input->GetFormat(); + std::vector old_shape = input->GetShape().GetDims(); + + input->SetShape(ge::GeShape(dst_shape_dims)); + input->SetFormat(storage_format); + auto output = op_desc->MutableOutputDesc(index); + GE_CHECK_NOTNULL(output); + output->SetShape(ge::GeShape(dst_shape_dims)); + output->SetFormat(storage_format); + + int64_t size = 0; + graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(*output, size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(graph_status, "GetTensorMemorySizeInBytes failed!"); + return FAILED; + } + ge::TensorUtils::SetSize(*input, size); + ge::TensorUtils::SetSize(*output, size); + + GELOGI( + "Modify Data NetOutput format and shape success, node:%s, index:%u, old_shape:%s, old_format:%s, " + "new_shape:%s, new_format:%s, new_size:%ld", + op_desc->GetName().c_str(), index, formats::JoinToString(old_shape).c_str(), + ge::TypeUtils::FormatToSerialString(old_format).c_str(), formats::JoinToString(dst_shape_dims).c_str(), + ge::TypeUtils::FormatToSerialString(storage_format).c_str(), size); + return SUCCESS; +} Status 
ProcessInputNC1HWC0(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); @@ -1054,7 +1090,6 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP return SUCCESS; } input->SetDataType(DT_FLOAT16); - input->SetOriginDataType(DT_FLOAT16); int64_t input_shape_size = 0; int64_t output_shape_size = 0; ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); @@ -1067,7 +1102,6 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(DT_FLOAT16); - output->SetOriginDataType(DT_FLOAT16); ge::TensorUtils::SetSize(*output, output_shape_size); if (is_dynamic_batch) { GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); @@ -1076,12 +1110,10 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP auto switchn_input = switchn_op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(switchn_input); switchn_input->SetDataType(DT_FLOAT16); - switchn_input->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(switchn_output); switchn_output->SetDataType(DT_FLOAT16); - switchn_output->SetOriginDataType(DT_FLOAT16); } } return SUCCESS; @@ -1100,10 +1132,6 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); return FAILED; } - if (old_format == FORMAT_NC1HWC0) { - GELOGI("No need to transfer format"); - return SUCCESS; - } if (ModifyInputFormatAndShape(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "modify format and shape failed"); return FAILED; @@ -1139,7 +1167,7 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { } for (auto const &next_node : node_ptr->GetOutNodes()) { if (next_node->GetType() == AIPP) { - ErrorManager::GetInstance().ATCReportErrMessage("E10049", {"opname"}, {node_ptr->GetName()}); + ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"opname"}, {node_ptr->GetName()}); GELOGE(INTERNAL_ERROR, "This input op [%s] is linked to aipp, can not be set to fp16, " "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", @@ -1171,6 +1199,42 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { return SUCCESS; } +Status GetStorageFormatAndShape(OpDescPtr &op_desc, const GeTensorDescPtr &tensor_desc_ptr, Format &storage_format, + vector &dst_shape_dims) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(tensor_desc_ptr); + + storage_format = FORMAT_RESERVED; + int64_t format = FORMAT_RESERVED; + dst_shape_dims.clear(); + if (ge::AttrUtils::GetInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_FORMAT, format)) { + storage_format = static_cast(format); + vector storage_shape; + if (ge::AttrUtils::GetListInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_SHAPE, storage_shape)) { + for (auto dim : storage_shape) { + dst_shape_dims.push_back(static_cast(dim)); + } + GELOGI("Update node by storage format, node: [%s], storage_format: [%s], storage_shape:[%s]", + op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), + formats::JoinToString(storage_shape).c_str()); + } else { + GELOGE(PARAM_INVALID, + "Update node by storage format failed, storage_shape not set. 
" + "node: [%s], storage_format [%s]", + op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str()); + return FAILED; + } + + ge::Format old_format = tensor_desc_ptr->GetFormat(); + auto old_shape = tensor_desc_ptr->GetShape().GetDims(); + if (old_format == storage_format && old_shape == dst_shape_dims) { + GELOGI("Update node by storage format, not changed."); + storage_format = FORMAT_RESERVED; + return SUCCESS; + } + } + return SUCCESS; +} Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorDescPtr &net_output_input_desc, NodePtr &node) { bool is_dynamic = CheckOpType(node, MERGE); @@ -1180,24 +1244,16 @@ Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorD ge::Format src_format = src_desc.GetFormat(); net_output_input_desc->SetDataType(DT_FLOAT16); - net_output_input_desc->SetOriginDataType(DT_FLOAT16); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); merge_output->SetDataType(DT_FLOAT16); - merge_output->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(DT_FLOAT16); - merge_input->SetOriginDataType(DT_FLOAT16); } } - - if (src_format == FORMAT_NC1HWC0) { - GELOGI("Format is NC1HWC0, no need to transfer"); - return SUCCESS; - } std::vector dst_shape_dims; std::vector src_shape_dims = src_shape.GetDims(); if (TransferShape2NC1HWC0(src_format, src_shape_dims, DT_FLOAT16, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { @@ -1291,17 +1347,14 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { GELOGI("Enter into process output_type schedule"); net_output_input_desc->SetDataType(output_data_type); - net_output_input_desc->SetOriginDataType(output_data_type); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); merge_output->SetDataType(output_data_type); - merge_output->SetOriginDataType(output_data_type); for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(output_data_type); - merge_input->SetOriginDataType(output_data_type); } } continue; @@ -1337,7 +1390,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { } return SUCCESS; } - } // namespace GraphPrepare::GraphPrepare() : compute_graph_(nullptr) {} @@ -1431,6 +1483,8 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { if (compute_graph_ != nullptr) { compute_graph_->SetSessionID(session_id); } + session_id_ = session_id; + Status ret = CheckGraph(); if (ret != SUCCESS) { GELOGE(ret, "RunGraph graph check fail, ret:%u", ret); @@ -1442,7 +1496,6 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { GELOGE(ret, "RunGraph check ref op fail, ret:%u", ret); return ret; } - return SUCCESS; } @@ -1467,13 +1520,13 @@ Status GraphPrepare::CheckGraph() { } Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &input_name, - const std::unordered_set &ref_nodes) { + const std::set &ref_nodes) { // Acceptable input types should be ref node, variable or Switch operator, which is issued by ME for dynamic - // lossscale and would be optimized in SwitchOpPass. 
Since ME dont differentiate between RefSwitch and Switch, - // and only issue Switch. - static std::unordered_set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, - ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, - ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; + // lossscale and would be optimized in SwitchToStreamSwitchPass. + // Since ME dont differentiate between RefSwitch and Switch, and only issue Switch. + static std::set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, + ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, + ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -1499,7 +1552,6 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i } } bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); - if (!is_acceptable) { GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), @@ -1512,7 +1564,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i Status GraphPrepare::CheckRefOp() { GE_CHECK_NOTNULL(compute_graph_); - std::unordered_set ref_nodes; + std::set ref_nodes; for (const NodePtr &node : compute_graph_->GetDirectNode()) { if (node == nullptr) { GELOGE(PARAM_INVALID, "param [node] must not be null."); @@ -1524,20 +1576,15 @@ Status GraphPrepare::CheckRefOp() { return PARAM_INVALID; } - auto input_names = op_desc->GetAllInputNames(); + auto input_name_index = op_desc->GetAllInputName(); auto outputs = op_desc->GetAllOutputName(); - std::unordered_set all_output_name; - - for (auto &output : outputs) { - all_output_name.insert(output.first); - } - for (const auto &input_name : input_names) { - if (all_output_name.find(input_name) != all_output_name.end()) { - if (CheckRefInputNode(node, input_name, ref_nodes) != SUCCESS) { + for (const auto &name_index : input_name_index) { + if (op_desc->GetOutputIndexByName(name_index.first) != -1) { + if (CheckRefInputNode(node, name_index.first, ref_nodes) != SUCCESS) { GELOGE(PARAM_INVALID, "CheckRefInputNode failed."); return PARAM_INVALID; } - (void)ref_nodes.insert(node); + (void)ref_nodes.insert(node); // no need to check value } } } @@ -1548,7 +1595,7 @@ Status GraphPrepare::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id_, rt_context); return SUCCESS; } @@ -1566,6 +1613,8 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { int64_t tensor_size = 0; graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); if (graph_status != GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"GetTensorMemorySizeInBytes", "opname is " + node->GetName()}); GELOGE(graph_status, "GetTensorMemorySizeInBytes failed!"); return FAILED; } @@ -1599,12 +1648,16 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); auto format = desc.GetFormat(); auto origin_format = desc.GetOriginFormat(); - bool is_internal = 
TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); - bool need_check_internal_format = (!options_.is_single_op) && is_internal; + // data may be in an internal format [FRACTAL_NZ] during single-op processing, such as GEMM. + bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op); if (need_check_internal_format) { - GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", - TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::FormatToSerialString(origin_format).c_str()); - return FAILED; + bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); + if (is_internal) { + GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not supported.", + TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(origin_format).c_str(), + return FAILED; + } } auto data_type = desc.GetDataType(); @@ -1623,7 +1676,8 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "TensorUtils GetSize failed"); return FAILED); - if ((size != 0) && (shape_size != size)) { + bool size_check = (size != 0 && shape_size != size); + if (size_check) { GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); return FAILED; } @@ -1771,6 +1825,55 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { return SUCCESS; } +Status GraphPrepare::UpdateDataNetOutputByStorageFormat() { + for (auto &node_ptr : compute_graph_->GetAllNodes()) { + GE_CHECK_NOTNULL(node_ptr); + if (node_ptr->GetType() == DATA) { + uint32_t index = 0; + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const GeTensorDescPtr input = op_desc->MutableInputDesc(index); + Format storage_format = FORMAT_RESERVED; + vector dst_shape_dims; + if (GetStorageFormatAndShape(op_desc, input, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get storage format for input failed"); + return FAILED; + } + + if (storage_format == FORMAT_RESERVED) { + continue; + } + + if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Modify format and shape for input failed"); + return FAILED; + } + } + + if (node_ptr->GetType() == ge::NETOUTPUT) { + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (uint32_t index = 0; index < op_desc->GetOutputsSize(); index++) { + const GeTensorDescPtr output = op_desc->MutableOutputDesc(index); + Format storage_format = FORMAT_RESERVED; + vector dst_shape_dims; + if (GetStorageFormatAndShape(op_desc, output, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get storage format from output failed"); + return FAILED; + } + if (storage_format == FORMAT_RESERVED) { + continue; + } + if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Modify format and shape for output failed"); + return FAILED; + } + } + } + } + return SUCCESS; +} + void GraphPrepare::ProcessCCEFormat() { static const char *const parser_priority = std::getenv("PARSER_PRIORITY"); static const bool keep_cce = parser_priority != nullptr && string(parser_priority) == "cce"; @@ -1955,9 +2058,7 @@ Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vector(options_.framework_type); const Graph &const_graph = *graph; @@ -1989,7 +2090,6 @@ Status 
GraphPrepare::PrepareRunningFormatRefiner() { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("PrepareRunningFormatRefiner::VariablePrepareOpPass", new (std::nothrow) VariablePrepareOpPass)) - GE_CHK_STATUS_RET(pass_manager.AddPass("PrepareRunningFormatRefiner::SubgraphPass", new (std::nothrow) SubgraphPass)) GE_TIMESTAMP_START(pass_manager); auto ret = pass_manager.Run(compute_graph); GE_TIMESTAMP_END(pass_manager, "GraphPrepare::PrepareRunningFormatRefiner"); @@ -2053,10 +2153,6 @@ Status GraphPrepare::GenerateInfershapeGraph(ConstGraphPtr graph) { Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id) { - // train graph flag - if (options_.train_graph_flag) { - domi::GetContext().train_flag = true; - } domi::GetContext().type = static_cast(options_.framework_type); if (graph == nullptr) { @@ -2071,7 +2167,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u } GraphOptimize graph_optimize; - if (!domi::GetContext().train_flag) { + if (!options_.train_graph_flag && !domi::GetContext().train_flag) { GE_DUMP(compute_graph_, "BeforeOriginalGraphForQuantize"); GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); @@ -2302,10 +2398,10 @@ Status GraphPrepare::PrepareOptimize() { GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; EnterPass enter_pass; - PrintOpPass print_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); CondPass cond_pass; names_to_passes.emplace_back("CondPass", &cond_pass); + PrintOpPass print_pass; if (options_.enable_print_op_pass) { names_to_passes.emplace_back("PrintOpPass", &print_pass); } @@ -2478,7 +2574,9 @@ Status GraphPrepare::OptimizeForPreprocess() { (void)graph_pass.AddPass("OptimizeForPreprocess::PrunePass", new PrunePass); (void)graph_pass.AddPass("OptimizeForPreprocess::NextIterationPass", new NextIterationPass); (void)graph_pass.AddPass("OptimizeForPreprocess::ControlTriggerPass", new ControlTriggerPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchOpPass", new SwitchOpPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::MergeToStreamMergePass", new MergeToStreamMergePass); + (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchToStreamSwitchPass", new SwitchToStreamSwitchPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::AttachStreamLabelPass", new AttachStreamLabelPass); (void)graph_pass.AddPass("OptimizeForPreprocess::HcclMemcpyPass", new HcclMemcpyPass); GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass("OptimizeForPreprocess::FlowCtrlPass", new FlowCtrlPass);); @@ -2560,8 +2658,6 @@ Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_c GEPass ge_passes_for_shape(compute_graph_); NamesToPass names_to_passes_for_shape; - IdentifyReferencePass identify_reference_pass; - names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); CastRemovePass cast_remove_pass; names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); TransposeTransDataPass transpose_transdata_pass; @@ -2693,6 +2789,12 @@ Status GraphPrepare::CheckAndUpdateInput(const std::vector &user_input return SUCCESS; } Status GraphPrepare::UpdateInputOutputByOptions() { + auto ret = UpdateDataNetOutputByStorageFormat(); + if (ret != SUCCESS) { + GELOGE(ret, "Update format according to storage format failed."); + return ret; + } + + if 
(options_.train_graph_flag) { GELOGI("This is train mode, no need to do this schedule."); return SUCCESS; } @@ -2736,6 +2838,21 @@ bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } +bool GraphPrepare::IsTansDataOpData(const ge::NodePtr &var_node) { + for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(out_anchor); + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(in_anchor); + ge::NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_RT_FALSE_CHECK_NOTNULL(dst_node); + if (dst_node->GetType() == TRANSDATA) { + return true; + } + } + } + return false; +} + bool GraphPrepare::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map> &confirm_ops, ge::NodePtr &use_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index b90caa86..bae2a885 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -59,8 +59,7 @@ class GraphPrepare { Status Init(const ge::Graph &graph, uint64_t session_id = 0); Status Preprocess(const std::vector &user_input); Status CheckGraph(); - Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, - const std::unordered_set &ref_nodes); + Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::set &ref_nodes); Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); @@ -69,6 +68,7 @@ class GraphPrepare { Status CheckConstOp(); Status VerifyConstOp(const NodePtr &node); Status CheckUserInput(const std::vector &user_input); + Status UpdateDataNetOutputByStorageFormat(); Status OptimizeForPreprocess(); Status PrepareOptimize(); Status InferShapeForPreprocess(); @@ -88,6 +88,8 @@ class GraphPrepare { Status UpdateInputOutputByOptions(); bool IsBroadCastOpData(const ge::NodePtr &var_node); + bool IsTansDataOpData(const ge::NodePtr &var_node); + void AdjustBroadCastOpData(const ge::NodePtr &var_node); bool IsAssignOpData(const ge::NodePtr &var_node); @@ -104,6 +106,7 @@ class GraphPrepare { ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; + uint64_t session_id_ = 0; }; } // namespace ge #endif // GE_GRAPH_PREPROCESS_GRAPH_PREPROCESS_H_ diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index 22128394..f35b6d3a 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -389,8 +389,8 @@ Status AippOp::SetDefaultParams() { GELOGI("parse aipp params:input_format:%s, csc_switch:%d.", domi::AippOpParams::InputFormat_Name(aipp_params_->input_format()).c_str(), aipp_params_->csc_switch()); - GELOGI("parse aipp params:mean_chn_0:%d, mean_chn_1:%d, mean_chn_2:%d.", aipp_params_->mean_chn_0(), - aipp_params_->mean_chn_1(), aipp_params_->mean_chn_2()); + GELOGI("parse aipp params:mean_chn_0:%d, mean_chn_1:%d, mean_chn_2:%d, mean_chn_3:%d.", aipp_params_->mean_chn_0(), + aipp_params_->mean_chn_1(), aipp_params_->mean_chn_2(), aipp_params_->mean_chn_3()); GELOGI("parse aipp params:min_chn_0:%f, min_chn_1:%f, min_chn_2:%f.", aipp_params_->min_chn_0(), aipp_params_->min_chn_1(), aipp_params_->min_chn_2());
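Note: the storage-format plumbing added in graph_preprocess.cc earlier in this patch is driven entirely by two tensor-desc attributes. A minimal sketch of how a producer could tag a Data node so the new UpdateDataNetOutputByStorageFormat() pass rewrites it during Prepare; the attribute names and AttrUtils calls are the ones used by the patch, while the desc variable and the concrete format/shape values here are hypothetical:

  // Sketch only, not part of the patch: tag the Data node's input desc (index 0),
  // which is what UpdateDataNetOutputByStorageFormat() inspects for DATA nodes.
  GeTensorDescPtr desc = data_op_desc->MutableInputDesc(0);  // hypothetical Data op
  (void)ge::AttrUtils::SetInt(*desc, ge::ATTR_NAME_STORAGE_FORMAT, static_cast<int64_t>(ge::FORMAT_NC1HWC0));
  (void)ge::AttrUtils::SetListInt(*desc, ge::ATTR_NAME_STORAGE_SHAPE, std::vector<int64_t>{1, 1, 224, 224, 16});
  // GetStorageFormatAndShape() then reports FORMAT_NC1HWC0 plus this shape and
  // ModifyDataNetOutputFormatAndShape() rewrites the desc; if the format and shape
  // already match, storage_format is reset to FORMAT_RESERVED and the node is skipped.

One consequence of keying the skip off FORMAT_RESERVED is that a tensor already carrying the requested storage format is left untouched, so the pass stays idempotent across repeated Prepare runs.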
diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 2963cd5a..8bb0c6c4 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -45,7 +45,7 @@ static void ConvertShape2Nhwc(Format &format, vector &shape_vec) { return; } if (format != FORMAT_NCHW) { - GELOGW("The format is not NCHW, current format is %s", TypeUtils::FormatToSerialString(format).c_str()); + GELOGW("The format is not NCHW, current format is %s.", TypeUtils::FormatToSerialString(format).c_str()); return; } vector shape_vec_tmp; @@ -245,7 +245,6 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s GELOGE(FAILED, "Can not get size from aipp [%s]", aipp_op_desc->GetName().c_str()); return FAILED; } - // Save the input size of aipp node, which will be used in dumping aipp node or fused aipp node (void)AttrUtils::SetInt(aipp_input, ATTR_NAME_INPUT_ORIGIN_SIZE, size); auto in_data_anchor = node->GetInDataAnchor(0); @@ -324,7 +323,8 @@ Status InsertNewOpUtil::UpdateDataBySwitchN(const NodePtr &switchn, const NodePt auto data_opdesc = data->GetOpDesc(); GE_CHECK_NOTNULL(data_opdesc); - Format old_format = output_desc->GetFormat(); + Format old_format = data_opdesc->MutableOutputDesc(0)->GetFormat(); + auto ret = data_opdesc->UpdateOutputDesc(0, *input_desc); if (ret != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to update data %s output using switchn %s", data->GetName().c_str(), @@ -465,15 +465,18 @@ Status InsertNewOpUtil::RecordAIPPInfoToData(const ComputeGraphPtr &graph) { GetInputOutputInfo(data_node, aipp_it, input, output); input_dims.emplace_back(input); output_dims.emplace_back(output); + + // When static aipp is set, we need to get the model input dims processed by aipp + GE_RETURN_IF_ERROR(SetModelInputDims(data_node, aipp_it)); } if (!AttrUtils::SetListStr(data_node->GetOpDesc(), ATTR_NAME_AIPP_INPUTS, input_dims)) { - GELOGE(FAILED, "SetListInt of %s failed.", ATTR_NAME_AIPP_INPUTS.c_str()); + GELOGE(FAILED, "SetListStr of %s failed.", ATTR_NAME_AIPP_INPUTS.c_str()); return FAILED; } if (!AttrUtils::SetListStr(data_node->GetOpDesc(), ATTR_NAME_AIPP_OUTPUTS, output_dims)) { - GELOGE(FAILED, "SetListInt of %s failed.", ATTR_NAME_AIPP_OUTPUTS.c_str()); + GELOGE(FAILED, "SetListStr of %s failed.", ATTR_NAME_AIPP_OUTPUTS.c_str()); return FAILED; } } @@ -518,4 +521,41 @@ Status InsertNewOpUtil::GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_nod data_node->GetName().c_str(), aipp_node->GetName().c_str(), input.c_str(), output.c_str()); return SUCCESS; } + +Status InsertNewOpUtil::SetModelInputDims(NodePtr &data_node, NodePtr &aipp_node) { + GE_CHECK_NOTNULL(data_node); + GE_CHECK_NOTNULL(aipp_node); + OpDescPtr data_opdesc = data_node->GetOpDesc(); + GE_CHECK_NOTNULL(data_opdesc); + OpDescPtr aipp_opdesc = aipp_node->GetOpDesc(); + GE_CHECK_NOTNULL(aipp_opdesc); + + // In dynamic batch/hw scenario, the new model input dims only need to be set once + if (data_node->GetOpDesc()->HasAttr(ATTR_NAME_INPUT_DIMS)) { + GELOGD("Data %s already has attribute %s", data_node->GetOpDesc()->GetName().c_str(), ATTR_NAME_INPUT_DIMS.c_str()); + return SUCCESS; + } + vector model_input_dims; + vector origin_input_dims; + if (AttrUtils::GetListInt(aipp_opdesc, ATTR_NAME_INPUT_DIMS, model_input_dims) && !model_input_dims.empty()) { + // When dynamic batch/hw is set, N or HW need to be set to -1 + if (AttrUtils::GetListInt(data_opdesc, ATTR_MBATCH_ORIGIN_INPUT_DIMS, origin_input_dims) && + !origin_input_dims.empty()) { + GELOGI("In dynamic batch/hw scenario, N or HW need to be set to -1. 
model_input_dims: %s, origin_input_dims: %s", + formats::JoinToString(model_input_dims).c_str(), formats::JoinToString(origin_input_dims).c_str()); + for (size_t i = 0; i < origin_input_dims.size(); ++i) { + // N or HW need to be set to -1 + if (origin_input_dims[i] < 0) { + model_input_dims[i] = origin_input_dims[i]; + } + } + } + GELOGD("After set H/W to -1, the model input dims: %s.", formats::JoinToString(model_input_dims).c_str()); + if (!AttrUtils::SetListInt(data_opdesc, ATTR_NAME_INPUT_DIMS, model_input_dims)) { + GELOGE(FAILED, "SetListInt of %s failed.", ATTR_NAME_INPUT_DIMS.c_str()); + return FAILED; + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h index b39b3005..93a96ca2 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.h @@ -67,6 +67,7 @@ class InsertNewOpUtil { Status GetDataRelatedNode(NodePtr &node, std::map> &data_next_node_map); Status GetAllAipps(const NodePtr &node, std::vector &aipps); Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); + Status SetModelInputDims(NodePtr &data_node, NodePtr &aipp_node); }; } // namespace ge diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index e063398f..d06a493d 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -44,6 +44,7 @@ const int kSwitchNPredIndex = 1; const int kDataOutIndex = 0; const int kDataInIndex = 0; const int kMergeDataOutIndex = 0; +const int kStaticOutput = -1; const size_t kMaxShapesCount = 100; const size_t kMinShapesCount = 2; @@ -125,8 +126,12 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { for (size_t i = 0; i < data_shape.GetDimNum(); ++i) { if (data_shape.GetDim(i) < 0) { if (batch_shape_index >= batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + + " does not match the data shape " + data_shape.ToString()}); GELOGE(PARAM_INVALID, - "Failed to calc tensor shape, the batch shape count %zu, doees not match the data shape %s", + "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; } @@ -134,6 +139,10 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { } } if (batch_shape_index != batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + " does not match the data shape " + + data_shape.ToString()}); GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; @@ -198,7 +207,7 @@ Status CheckDataShape(const std::vector &nodes) { } } if (unknown_shape_count == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10055"); + ErrorManager::GetInstance().ATCReportErrMessage("E10040"); GELOGE(PARAM_INVALID, "Need unknow shape data when user set --dynamic_batch_size or --dynamic_image_size, please check."); return PARAM_INVALID; @@ -278,6 +287,8 @@ Status 
MultiBatchGraphCopyer::CreateNewNodes() { case kNodeOutBatchBranch: ret = InsertMergeForEdgeNode(node); break; + case kNodeNotSupportNode: + break; default: GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast(branch_status), node->GetName().c_str()); @@ -290,7 +301,13 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { } return SUCCESS; } + NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { + // node with subgraph is not supported + if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) { + return kNodeNotSupportNode; + } + if (node->GetType() == NETOUTPUT) { return kNodeOutBatchBranch; } @@ -304,6 +321,7 @@ NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { } return kNodeOutBatchBranch; } + NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) { if (index < 0) { // the merge node must has data inputs, if origin connection is a control @@ -476,7 +494,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { return PARAM_INVALID; } if (shapes_.size() < kMinShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10050", {"shapesize", "minshapesize"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"shapesize", "minshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMinShapesCount)}); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -485,7 +503,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { return PARAM_INVALID; } if (shapes_.size() > kMaxShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10051", {"shapesize", "maxshapesize"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"shapesize", "maxshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMaxShapesCount)}); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -497,7 +515,7 @@ size_t shape_size = shapes_.at(0).size(); for (auto &shape : shapes_) { if (shape_size != shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10052", {"shapesize1", "shapesize2"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"}, {std::to_string(shape_size), std::to_string(shape.size())}); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -507,7 +525,7 @@ } for (auto dim : shape) { if (dim <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10053", {"dim"}, {std::to_string(dim)}); + ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)}); GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim); return PARAM_INVALID; } @@ -515,7 +533,7 @@ shapes_set.insert(shape); } if (shapes_set.size() != shapes_.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10054"); + ErrorManager::GetInstance().ATCReportErrMessage("E10039"); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size] exist duplicate shapes, please check"); return PARAM_INVALID;
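Note: in the GetDynamicOutputShape() hunk that follows, the new kStaticOutput branch records every NetOutput input that is not fed by an mbatch Merge node, encoding it as "-1,<input index>,<shape>" alongside the entries the existing merge-node path already writes into ATTR_NAME_DYNAMIC_OUTPUT_DIMS, so consumers can tell static outputs apart from dynamic ones. A made-up example for a NetOutput whose input 0 is dynamic and whose input 1 is a static [16,10] tensor:

  // Hypothetical ATTR_NAME_DYNAMIC_OUTPUT_DIMS content after this change:
  //   <dynamic entry for input 0>   // produced by the existing merge-node path, format unchanged
  //   "-1,1,16,10"                  // new: kStaticOutput(-1), NetOutput input index 1, shape 16,10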
@@ -947,15 +965,18 @@ Status GetDynamicOutputShape(ComputeGraphPtr &graph) { GELOGE(PARAM_INVALID, "Graph is null ,para is invalid"); return PARAM_INVALID; } - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetDirectNode()) { if (node->GetType() == NETOUTPUT) { auto netoutput_desc = node->GetOpDesc(); auto inputnode_to_netoutput = node->GetInAllNodes(); + std::vector dynamic_output_index; for (size_t j = 0; j < inputnode_to_netoutput.size(); j++) { bool ret = false; (void)AttrUtils::GetBool(inputnode_to_netoutput.at(j)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, ret); if (inputnode_to_netoutput.at(j)->GetType() == MERGE && ret) { - GELOGI("Find the merge node %s with mbatch attr", inputnode_to_netoutput.at(j)->GetName().c_str()); + GELOGI("Find the merge node %s with mbatch attr and the index is %zu", + inputnode_to_netoutput.at(j)->GetName().c_str(), j); + dynamic_output_index.emplace_back(j); for (size_t i = 0; i < inputnode_to_netoutput.at(j)->GetInNodes().size(); i++) { auto input_desc = inputnode_to_netoutput.at(j)->GetOpDesc(); auto input_tensor_desc = input_desc->GetInputDesc(i); @@ -967,6 +988,17 @@ Status GetDynamicOutputShape(ComputeGraphPtr &graph) { } } if (dynamic_output_dims.size() > 0) { + for (size_t k = 0; k < inputnode_to_netoutput.size(); k++) { + auto it = std::find(dynamic_output_index.begin(), dynamic_output_index.end(), k); + if (it != dynamic_output_index.end()) { + continue; + } + auto tensor_desc = netoutput_desc->GetInputDesc(k); + auto shape = tensor_desc.GetShape().ToString(); + std::string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(k) + "," + shape; + GELOGI("The static output shape msg is %s", static_output_shape.c_str()); + dynamic_output_dims.emplace_back(static_output_shape); + } if (!AttrUtils::SetListStr(netoutput_desc, ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) { GELOGE(FAILED, "Set dynamic output dims attr failed"); return FAILED; diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.h b/src/ge/graph/preprocess/multi_batch_copy_graph.h index 2500645f..bf1d53b3 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.h @@ -33,6 +33,7 @@ enum NodeStatus { kNodeInBatchBranch, kNodeOutBatchBranch, kNodeStartNode, + kNodeNotSupportNode, }; class MultiBatchGraphCopyer { diff --git a/src/ge/host_kernels/add_kernel.cc b/src/ge/host_kernels/add_kernel.cc index 6d6a049c..afef1c37 100644 --- a/src/ge/host_kernels/add_kernel.cc +++ b/src/ge/host_kernels/add_kernel.cc @@ -133,25 +133,24 @@ Status AddKernel::BCastAdd(const OpDescPtr &op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Op_desc_ptr must not be null."); + GELOGW("Op_desc_ptr must not be null."); return PARAM_INVALID; } // check how many inputs if ((input.size() != kAddInputSize) || (op_desc_ptr->GetOutputsSize() != kAddOutputSize)) { - GELOGE(PARAM_INVALID, "The number of input for add must be %zu, output number must be %zu.", kAddInputSize, - kAddOutputSize); + GELOGW("The number of input for add must be %zu, output number must be %zu.", kAddInputSize, kAddOutputSize); return PARAM_INVALID; } // input vector elements must not be null if ((input[kAddFirstInput] == nullptr) || (input[kAddSecondInput] == nullptr)) { - GELOGE(PARAM_INVALID, "Input vector elements must not be null."); + GELOGW("Input vector elements must not be null."); return PARAM_INVALID; } // Inputs must have the same datatype. 
DataType data_type_0 = input[kAddFirstInput]->GetTensorDesc().GetDataType(); DataType data_type_1 = input[kAddSecondInput]->GetTensorDesc().GetDataType(); if (data_type_0 != data_type_1) { - GELOGE(PARAM_INVALID, "Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s", + GELOGW("Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s", TypeUtils::DataTypeToSerialString(data_type_0).c_str(), TypeUtils::DataTypeToSerialString(data_type_1).c_str()); return PARAM_INVALID; @@ -192,7 +191,7 @@ Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector x2_dims; const auto &op_in_desc = op_desc_ptr->MutableInputDesc(0); GE_CHECK_NOTNULL(op_in_desc); - ; DataType data_type = op_in_desc->GetDataType(); bool result = (OpUtils::GetShapeDataFromConstTensor(input[0], data_type, x1_dims) == SUCCESS) && (OpUtils::GetShapeDataFromConstTensor(input[1], data_type, x2_dims) == SUCCESS); diff --git a/src/ge/host_kernels/concat_offset_kernel.cc b/src/ge/host_kernels/concat_offset_kernel.cc index 2e609d68..0a870949 100644 --- a/src/ge/host_kernels/concat_offset_kernel.cc +++ b/src/ge/host_kernels/concat_offset_kernel.cc @@ -41,7 +41,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector(reinterpret_cast(input_0->GetData().data()))); // validate inputs if (static_cast(input.size()) != (N + kNumOne) || input.size() <= kConcatOffsetInputIndexOne) { - GELOGE(PARAM_INVALID, "The number of input for concat offset must be equal with %d, and must be more than one.", - (N + kNumOne)); + GELOGW("The number of input for concat offset must be equal to %d, and must be more than one.", (N + kNumOne)); return NOT_CHANGED; } @@ -59,7 +58,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetTensorDesc().GetShape(); int64_t output_size = output_shape.GetShapeSize(); if (concat_dim >= output_size) { - GELOGE(PARAM_INVALID, "Concat dim is biger than the size of output_shape."); + GELOGW("Concat dim is bigger than the size of output_shape."); return NOT_CHANGED; } GELOGI("Output shape size is %ld", output_size); @@ -79,7 +78,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memeory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } @@ -87,7 +86,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorMutableTensorDesc().SetShape(output_shape); GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(buf.get()), static_cast(sizeof(DT_INT32) * output_size)) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED); v_output.push_back(output_ptr); // caculate offset diff --git a/src/ge/host_kernels/dynamic_stitch_kernel.cc b/src/ge/host_kernels/dynamic_stitch_kernel.cc index c8a19e44..c1245535 100644 --- a/src/ge/host_kernels/dynamic_stitch_kernel.cc +++ b/src/ge/host_kernels/dynamic_stitch_kernel.cc @@ -63,11 +63,11 @@ Status DynamicStitchKernel::Compute(const OpDescPtr op_desc_ptr, const vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Input op_desc is nullptr."); + GELOGW("Input op_desc is nullptr."); return PARAM_INVALID; } if (op_desc_ptr->GetOutputsSize() == 0) { - GELOGE(PARAM_INVALID, "Current output_desc is empty."); + GELOGW("Current output_desc is 
empty."); return PARAM_INVALID; } // validate input @@ -78,7 +78,7 @@ Status DynamicStitchKernel::ValidateParams(const OpDescPtr &op_desc_ptr, const s } for (const auto &in : input) { if (in == nullptr) { - GELOGE(PARAM_INVALID, "input is nullptr."); + GELOGW("input is nullptr."); return PARAM_INVALID; } } @@ -150,7 +150,7 @@ Status DynamicStitchKernel::GenData(const vector &input, GeTen // 2.allocate memery for output std::unique_ptr buf(new (std::nothrow) uint8_t[allowance]); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buffer failed"); + GELOGW("new buffer failed"); return INTERNAL_ERROR; } // 3.copy data from input_data along with the sequence of input_indices @@ -164,7 +164,7 @@ Status DynamicStitchKernel::GenData(const vector &input, GeTen output_ptr->MutableTensorDesc().SetShape(merged_shape); Status ret = output_ptr->SetData(buf.get(), allowance); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED; } return SUCCESS; diff --git a/src/ge/host_kernels/empty_kernel.cc b/src/ge/host_kernels/empty_kernel.cc index 856caf50..a5e5fbcf 100644 --- a/src/ge/host_kernels/empty_kernel.cc +++ b/src/ge/host_kernels/empty_kernel.cc @@ -38,7 +38,7 @@ const size_t kShapeMaxDims = 1; } // namespace Status EmptyKernel::EmptyCheck(const OpDescPtr &op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); + GELOGW("Parameter's invalid, Input opDescPtr is nullptr."); return PARAM_INVALID; } // check input size @@ -46,20 +46,19 @@ Status EmptyKernel::EmptyCheck(const OpDescPtr &op_desc_ptr, const std::vectorGetAllInputsDesc().size() != kEmptyInputsSize) || (input.size() != kEmptyInputsSize) || (op_desc_ptr->GetAllOutputsDesc().size() != kEmptyOutputsSize)); if (size_check) { - GELOGE(PARAM_INVALID, "Input/Output size error. InDesc size:%zu, OutDesc size:%zu, in size:%zu ", + GELOGW("Input/Output size error. 
InDesc size:%zu, OutDesc size:%zu, in size:%zu ", op_desc_ptr->GetAllInputsDesc().size(), op_desc_ptr->GetAllOutputsDesc().size(), input.size()); return PARAM_INVALID; } if (input.at(kEmptyFirstInput) == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's invalid, first input is nullptr."); + GELOGW("Parameter's invalid, first input is nullptr."); return PARAM_INVALID; } ConstGeTensorPtr shape = input.at(kEmptyFirstInput); // Check if the dimension is 1-D if (shape->GetTensorDesc().GetShape().GetDimNum() > kShapeMaxDims) { - GELOGE(PARAM_INVALID, "Check if the dimension is 1-D failed, dims:%zu", - shape->GetTensorDesc().GetShape().GetDimNum()); + GELOGW("Check if the dimension is 1-D failed, dims:%zu", shape->GetTensorDesc().GetShape().GetDimNum()); return PARAM_INVALID; } return SUCCESS; @@ -84,7 +83,7 @@ Status EmptyKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(shape, shape_vec, total_data_size); } else { - GELOGE(PARAM_INVALID, "shape type must be DT_INT32 or DT_INT64."); + GELOGW("shape type must be DT_INT32 or DT_INT64."); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/expanddims_kernel.cc b/src/ge/host_kernels/expanddims_kernel.cc index 1d17ad48..15648573 100644 --- a/src/ge/host_kernels/expanddims_kernel.cc +++ b/src/ge/host_kernels/expanddims_kernel.cc @@ -66,7 +66,7 @@ Status ExpanddimsKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec auto output_tensor_desc = op_desc_ptr->GetOutputDesc(kExpandDimsIndexZero); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/floordiv_kernel.cc b/src/ge/host_kernels/floordiv_kernel.cc index 4175df92..05eded80 100644 --- a/src/ge/host_kernels/floordiv_kernel.cc +++ b/src/ge/host_kernels/floordiv_kernel.cc @@ -260,7 +260,7 @@ Status FloorDivKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/floormod_kernel.cc b/src/ge/host_kernels/floormod_kernel.cc index a8c16c9d..7ad746de 100644 --- a/src/ge/host_kernels/floormod_kernel.cc +++ b/src/ge/host_kernels/floormod_kernel.cc @@ -122,7 +122,7 @@ Status FloorModKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(op_desc_ptr->GetOutputDesc(kFloorModFirstOutput)); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/gather_v2_kernel.cc b/src/ge/host_kernels/gather_v2_kernel.cc index c8cc3006..7413395a 100644 --- a/src/ge/host_kernels/gather_v2_kernel.cc +++ b/src/ge/host_kernels/gather_v2_kernel.cc @@ -274,7 +274,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr auto indices_ptr = const_cast(reinterpret_cast(indices_tensor_ptr->GetData().data())); for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= 
x_shape.GetDim(axis)) { - GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); + GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); return NOT_CHANGED; } indicates_.push_back(*(indices_ptr + i)); @@ -284,7 +284,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr auto indices_ptr = const_cast(reinterpret_cast(indices_tensor_ptr->GetData().data())); for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) { - GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); + GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); return NOT_CHANGED; } indicates_.push_back(*(indices_ptr + i)); @@ -296,19 +296,19 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector &input, vector &v_output) const { if (op_desc_ptr == nullptr) { - GELOGE(NOT_CHANGED, "input opdesc is nullptr."); + GELOGW("input opdesc is nullptr."); return NOT_CHANGED; } if (input.size() != kGatherV2InpotNum) { - GELOGE(NOT_CHANGED, "The number of input for GatherV2 must be %zu.", kGatherV2InpotNum); + GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum); return NOT_CHANGED; } bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr || input[kGatherV2InputIndexTwo] == nullptr); if (is_null) { - GELOGE(NOT_CHANGED, "some input is nullptr."); + GELOGW("some input is nullptr."); return NOT_CHANGED; } ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero); @@ -318,7 +318,7 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0)); if (size_is_zero) { - GELOGE(NOT_CHANGED, "some input size is zero."); + GELOGW("some input size is zero."); return NOT_CHANGED; } @@ -326,13 +326,13 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetTensorDesc().GetShape(); // axis must be scalar if (axis_shape.GetDimNum() != 0) { - GELOGE(NOT_CHANGED, "axis must be scalar but its shape is %zu", axis_shape.GetDimNum()); + GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum()); return NOT_CHANGED; } auto axis_data_type = tensor2->GetTensorDesc().GetDataType(); bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64; if (!is_valid_axis_data_type) { - GELOGE(NOT_CHANGED, "axis datatype must be DT_INT32 or DT_INT64"); + GELOGW("axis datatype must be DT_INT32 or DT_INT64"); return NOT_CHANGED; } @@ -340,11 +340,11 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetTensorDesc().GetDataType(); bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64; if (!is_valid_indices_data_type) { - GELOGE(NOT_CHANGED, "indices datatype must be DT_INT32 or DT_INT64"); + GELOGW("indices datatype must be DT_INT32 or DT_INT64"); return NOT_CHANGED; } if (indices_shape.GetDimNum() > kMaxIndicatesDims) { - GELOGE(NOT_CHANGED, "indices input only support 0 or 1 dims"); + GELOGW("indices input only support 0 or 1 dims"); return NOT_CHANGED; } return SUCCESS; @@ -372,7 +372,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetName().c_str()); @@ -390,13 +390,13 @@ Status GatherV2Kernel::Compute(const 
OpDescPtr op_desc_ptr, const vector= 0 ? axis : axis + x_shape.GetDimNum(); // check axis value if (axis < 0 || (axis + 1) > static_cast(x_shape.GetDimNum())) { - GELOGE(NOT_CHANGED, "axis is invalid"); + GELOGW("axis is invalid"); return NOT_CHANGED; } auto indices_data_type = tensor1->GetTensorDesc().GetDataType(); ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast(axis)); if (ret != SUCCESS) { - GELOGE(NOT_CHANGED, "Save indeices by data type failed!"); + GELOGW("Save indices by data type failed!"); return ret; } @@ -420,7 +420,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector(op_desc_ptr->GetOutputDesc(0)); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape)); diff --git a/src/ge/host_kernels/identity_kernel.cc b/src/ge/host_kernels/identity_kernel.cc new file mode 100644 index 00000000..16bd3138 --- /dev/null +++ b/src/ge/host_kernels/identity_kernel.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "identity_kernel.h" +#include "inc/kernel_factory.h" + +namespace { +constexpr uint32_t kInputDescIndex = 0; +constexpr uint32_t kOutputDescIndex = 0; +} // namespace + +namespace ge { +Status IdentityKernel::Compute(const ge::OpDescPtr op_desc, const std::vector &input, + std::vector &v_output) { + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "IdentityKernel op_desc is null."); + return NOT_CHANGED; + } + if (input.empty()) { + GELOGE(PARAM_INVALID, "Node [%s] inputs is empty.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + if (op_desc->GetOutputsSize() < 1) { + GELOGE(PARAM_INVALID, "Node [%s] output is empty.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + GELOGD("IdentityKernel in: node[%s]", op_desc->GetName().c_str()); + + auto out_tensor_desc = op_desc->GetOutputDesc(kOutputDescIndex); + GeTensorPtr output_ptr = MakeShared(out_tensor_desc); + if (output_ptr == nullptr) { + GELOGE(OUT_OF_MEMORY, "Node [%s] make shared failed.", op_desc->GetName().c_str()); + return OUT_OF_MEMORY; + } + auto input_tensor_ptr = input.at(kInputDescIndex); + if (input_tensor_ptr == nullptr) { + GELOGE(PARAM_INVALID, "Node [%s] get input failed.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + if (output_ptr->SetData(input_tensor_ptr->GetData()) != GRAPH_SUCCESS) { + GELOGW("Compute: SetData failed"); + return NOT_CHANGED; + } + v_output.emplace_back(output_ptr); + GELOGD("IdentityKernel success: node[%s]", op_desc->GetName().c_str()); + + return SUCCESS; +} +REGISTER_KERNEL(IDENTITY, IdentityKernel); +} // namespace ge
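Note: a hedged sketch of how the new IdentityKernel is reached. REGISTER_KERNEL(IDENTITY, IdentityKernel) registers it with the kernel factory that constant folding consults (matching the inc/kernel_factory.h include above), so an Identity node whose input tensor is already known at compile time can be folded into a Const. The factory call below is my reading of that API; the op desc and input tensor names are invented for the example:

  // Sketch only: how a folding pass would exercise the kernel.
  auto kernel = ge::KernelFactory::Instance().Create(IDENTITY);   // registered above
  std::vector<ge::ConstGeTensorPtr> inputs{known_const_tensor};   // hypothetical compile-time weight
  std::vector<ge::GeTensorPtr> outputs;
  if (kernel != nullptr && kernel->Compute(identity_op_desc, inputs, outputs) == SUCCESS) {
    // outputs[0] shares the input's data buffer (SetData(input->GetData())),
    // so the Identity node can be replaced by a Const holding that tensor.
  }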
diff --git a/src/ge/host_kernels/identity_kernel.h b/src/ge/host_kernels/identity_kernel.h new file mode 100644 index 00000000..2164d880 --- /dev/null +++ b/src/ge/host_kernels/identity_kernel.h @@ -0,0 +1,31 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ +#define GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ + +#include "inc/kernel.h" +#include + +namespace ge { +class IdentityKernel : public Kernel { + public: + Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) override; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ diff --git a/src/ge/host_kernels/pack_kernel.cc b/src/ge/host_kernels/pack_kernel.cc index f3f64a6c..9b62a582 100644 --- a/src/ge/host_kernels/pack_kernel.cc +++ b/src/ge/host_kernels/pack_kernel.cc @@ -63,7 +63,7 @@ Status PackKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "input opdesc is nullptr."); + GELOGW("input opdesc is nullptr."); return PARAM_INVALID; } if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { @@ -71,16 +71,15 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_); } if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { - GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); + GELOGW("Attr %s does not exist.", ATTR_NAME_AXIS.c_str()); return PARAM_INVALID; } if (input.empty()) { - GELOGE(PARAM_INVALID, "The number of input for Pack should be %ld, in fact it is %zu ", n_, input.size()); + GELOGW("The number of input for Pack should be %ld, in fact it is %zu ", n_, input.size()); return NOT_CHANGED; } if (input.size() != static_cast(n_)) { - GELOGE(PARAM_INVALID, "The number of input for Pack should be %d, in fact it is %ld ", static_cast(n_), - input.size()); + GELOGW("The number of input for Pack should be %d, in fact it is %ld ", static_cast(n_), input.size()); return PARAM_INVALID; } data_type_ = op_desc_ptr->GetInputDesc(0).GetDataType(); diff --git a/src/ge/host_kernels/permute_kernel.cc b/src/ge/host_kernels/permute_kernel.cc index 8263d19f..24bed54d 100644 --- a/src/ge/host_kernels/permute_kernel.cc +++ b/src/ge/host_kernels/permute_kernel.cc @@ -110,14 +110,14 @@ Status PermuteKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); formats::TransResult trans_result; auto ret = formats::TransposeWithShapeCheck(src_data, src_shape, data_shape, src_data_type, perm_list, trans_result); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + GELOGW("Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); diff --git a/src/ge/host_kernels/rank_kernel.cc b/src/ge/host_kernels/rank_kernel.cc index faaf16b8..c8763aef 100644 --- a/src/ge/host_kernels/rank_kernel.cc +++ b/src/ge/host_kernels/rank_kernel.cc @@ -49,7 +49,7 @@ Status RankKernel::Compute(const NodePtr &node, std::vector &v_outp auto ndims = input_shape->GetShape().GetDimNum(); GeTensorDesc tensor_desc(op_desc->GetOutputDesc(0)); GeTensorPtr output_ptr; - output_ptr = MakeShared(tensor_desc, reinterpret_cast(&ndims), sizeof(ndims)); + output_ptr = MakeShared(tensor_desc, reinterpret_cast(&ndims), GetSizeByDataType(DT_INT32)); if (output_ptr == nullptr) { 
GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed"); return MEMALLOC_FAILED; diff --git a/src/ge/host_kernels/reduce_prod_kernel.cc b/src/ge/host_kernels/reduce_prod_kernel.cc index 479b50ab..739d4b9f 100644 --- a/src/ge/host_kernels/reduce_prod_kernel.cc +++ b/src/ge/host_kernels/reduce_prod_kernel.cc @@ -51,7 +51,7 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } - GELOGE(PARAM_INVALID, "Unexpected ReduceProd node, node input size: %zu, node name: %s", input.size(), + GELOGW("Unexpected ReduceProd node, node input size: %zu, node name: %s", input.size(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -60,13 +60,13 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, GE_CHECK_NOTNULL(data_tensor); GE_CHECK_NOTNULL(axis_tensor); if (axis_tensor->GetTensorDesc().GetShape().GetDimNum() > kReduceProdMaxAxisRank) { - GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); + GELOGW("Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", + GELOGW("ReduceProdKernel data type %s not support, node name: %s", TypeUtils::DataTypeToSerialString(data_type).c_str(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -83,7 +83,7 @@ Status ReduceProdKernel::AxisCal(const std::vector &input) int32_t *axis = const_cast(reinterpret_cast(axis_tensor->GetData().GetData())); GE_CHECK_NOTNULL(axis); if (static_cast(*axis) >= data_dim_size) { - GELOGE(PARAM_INVALID, "axis is out of rank of data_dims, axis is %d.", *axis); + GELOGW("axis is out of rank of data_dims, axis is %d.", *axis); return PARAM_INVALID; } axis_dim_ = data_dims[static_cast(*axis)]; @@ -98,13 +98,13 @@ Status ReduceProdKernel::AxisCal(const std::vector &input) // data_dims is the vector of dims, element in data_dims isn't negative. if (axis_appear) { if (data_dims[i] != 0 && end_dim_ > (INT64_MAX / data_dims[i])) { - GELOGE(INTERNAL_ERROR, "Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", end_dim_, data_dims[i]); + GELOGW("Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", end_dim_, data_dims[i]); return INTERNAL_ERROR; } end_dim_ *= data_dims[i]; } else { if (data_dims[i] != 0 && head_dim_ > (INT64_MAX / data_dims[i])) { - GELOGE(INTERNAL_ERROR, "Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", head_dim_, data_dims[i]); + GELOGW("Product is overflow. multiplier 1: %ld. 
multiplier 2: %ld.", head_dim_, data_dims[i]); return INTERNAL_ERROR; } head_dim_ *= data_dims[i]; @@ -122,7 +122,7 @@ Status ReduceProdKernel::DataCal(const std::vector &input, size_t data_num = data_tensor->GetData().size() / sizeof(int32_t); unique_ptr buf(new (std::nothrow) int32_t[data_num]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return INTERNAL_ERROR; } @@ -190,12 +190,12 @@ Status ReduceProdKernel::ComputeNoAxis(const ge::OpDescPtr &op_desc_ptr, const s ConstGeTensorPtr data_tensor = input.at(kReduceProdDataIndex); GE_CHECK_NOTNULL(data_tensor); if (data_tensor->GetData().size() == 0) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); + GELOGW("ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", + GELOGW("ReduceProdKernel data type %s not support, node name: %s", TypeUtils::DataTypeToSerialString(data_type).c_str(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -206,7 +206,7 @@ Status ReduceProdKernel::ComputeNoAxis(const ge::OpDescPtr &op_desc_ptr, const s size_t data_num = data_tensor->GetData().size() / sizeof(int32_t); unique_ptr buf(new (std::nothrow) int32_t[data_num]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return INTERNAL_ERROR; } @@ -235,7 +235,7 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec GELOGI("ReduceProdKernel in."); Status ret = ReduceProdCheck(op_desc_ptr, input); if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(PARAM_INVALID, "ReduceProdKernel input is invalid, failed to fold node."); + GELOGW("ReduceProdKernel input is invalid, failed to fold node."); return NOT_CHANGED; } @@ -243,7 +243,7 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/reformat_kernel.cc b/src/ge/host_kernels/reformat_kernel.cc index 33a13599..c2dd1e17 100644 --- a/src/ge/host_kernels/reformat_kernel.cc +++ b/src/ge/host_kernels/reformat_kernel.cc @@ -56,7 +56,7 @@ Status ReFormatKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetTensorDesc().GetShape()).c_str()); return NOT_CHANGED; } GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(kReformatFirstOutput)); if (output_ptr == nullptr) { - GELOGE(INTERNAL_ERROR, "Create shared ptr for GeTensor failed"); + GELOGW("Create shared ptr for GeTensor failed"); return NOT_CHANGED; } - GE_IF_BOOL_EXEC(output_ptr->SetData(input.at(0)->GetData()) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GE_IF_BOOL_EXEC(output_ptr->SetData(input.at(0)->GetData()) != GRAPH_SUCCESS, GELOGW("set data failed"); return NOT_CHANGED); v_output.emplace_back(output_ptr); GELOGD("ReFormatKernel success."); diff --git a/src/ge/host_kernels/reshape_kernel.cc 
diff --git a/src/ge/host_kernels/reshape_kernel.cc b/src/ge/host_kernels/reshape_kernel.cc index 906624d2..dc7e4bb8 100644 --- a/src/ge/host_kernels/reshape_kernel.cc +++ b/src/ge/host_kernels/reshape_kernel.cc @@ -67,7 +67,7 @@ Status ReshapeKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector auto output_tensor_desc = op_desc_ptr->GetOutputDesc(kOutputDescFirstIndex); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/rsqrt_kernel.cc b/src/ge/host_kernels/rsqrt_kernel.cc index 3e14fd5f..56972d23 100644 --- a/src/ge/host_kernels/rsqrt_kernel.cc +++ b/src/ge/host_kernels/rsqrt_kernel.cc @@ -64,7 +64,7 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector 0) { unique_ptr buf(new (std::nothrow) float[data_count]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return NOT_CHANGED; } @@ -81,13 +81,13 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } output_ptr->MutableTensorDesc().SetDataType(DT_FLOAT); GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(buf.get()), data_size) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED); output_ptr->MutableTensorDesc().SetShape(x_shape); v_output.push_back(output_ptr); diff --git a/src/ge/host_kernels/slice_d_kernel.cc b/src/ge/host_kernels/slice_d_kernel.cc index ad0a1675..3b8fd0a0 100644 --- a/src/ge/host_kernels/slice_d_kernel.cc +++ b/src/ge/host_kernels/slice_d_kernel.cc @@ -129,7 +129,7 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } @@ -143,8 +143,14 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(const_cast(x_tensor->GetData().data())); int64_t x_data_size = x_tensor->GetTensorDesc().GetShape().GetShapeSize(); - Status ret = OpUtils::SetOutputSliceData(data, x_data_size, x_data_type, x_dims, begin_list, size_list, - output_ptr.get(), stride_list); + + Status ret = CheckOutputDims(size_list, op_desc_ptr); + if (ret != SUCCESS) { + return ret; + } + + ret = OpUtils::SetOutputSliceData(data, x_data_size, x_data_type, x_dims, begin_list, size_list, output_ptr.get(), + stride_list); if (ret != SUCCESS) { GELOGW("Set output data of SliceD failed."); return NOT_CHANGED; @@ -155,5 +161,16 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector &output_dims, const OpDescPtr attr) { + // check that the output dims are not all <= 0 + for (auto dim : output_dims) { + if (dim > 0) { + return SUCCESS; + } + } + GELOGW("all output dims <= 0, can't be processed. 
op_name : %s", attr->GetName().c_str()); + return NOT_CHANGED; +} + REGISTER_KERNEL(SLICED, SliceDKernel); } // namespace ge diff --git a/src/ge/host_kernels/slice_d_kernel.h b/src/ge/host_kernels/slice_d_kernel.h index 9fe35352..90ef9b8b 100644 --- a/src/ge/host_kernels/slice_d_kernel.h +++ b/src/ge/host_kernels/slice_d_kernel.h @@ -29,6 +29,7 @@ class SliceDKernel : public Kernel { private: Status SliceDCheck(const OpDescPtr &op_desc_ptr, const std::vector &input, std::vector &begin_list, std::vector &size_list); + Status CheckOutputDims(const std::vector &output_dims, const OpDescPtr attr); }; } // namespace ge diff --git a/src/ge/host_kernels/slice_kernel.cc b/src/ge/host_kernels/slice_kernel.cc index 1d7d90c2..5f72fc49 100644 --- a/src/ge/host_kernels/slice_kernel.cc +++ b/src/ge/host_kernels/slice_kernel.cc @@ -21,8 +21,8 @@ #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace ge { diff --git a/src/ge/host_kernels/ssd_prior_box_kernel.cc b/src/ge/host_kernels/ssd_prior_box_kernel.cc index c874d732..9de5a08d 100644 --- a/src/ge/host_kernels/ssd_prior_box_kernel.cc +++ b/src/ge/host_kernels/ssd_prior_box_kernel.cc @@ -365,7 +365,7 @@ Status SsdPriorboxKernel::Compute(const NodePtr &node, std::vector // make TensorDesc GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(INTERNAL_ERROR, "Create shared ptr for GeTensor failed"); + GELOGW("Create shared ptr for GeTensor failed"); return NOT_CHANGED; } GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(output_data.get()), diff --git a/src/ge/host_kernels/strided_slice_kernel.cc b/src/ge/host_kernels/strided_slice_kernel.cc index 0d70a36a..6a9a558c 100644 --- a/src/ge/host_kernels/strided_slice_kernel.cc +++ b/src/ge/host_kernels/strided_slice_kernel.cc @@ -46,31 +46,31 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vec int64_t shrink_axis_mask = 0; if (attr == nullptr) { - GELOGE(PARAM_INVALID, "input opdescptr is nullptr."); + GELOGW("input opdescptr is nullptr."); return PARAM_INVALID; } if (input.size() != kStridedSliceInputSize) { - GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize); + GELOGW("The number of input for strided slice must be %zu.", kStridedSliceInputSize); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_BEGIN_MASK, begin_mask)) { - GELOGE(PARAM_INVALID, "get begin_mask attr failed."); + GELOGW("get begin_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_END_MASK, end_mask)) { - GELOGE(PARAM_INVALID, "get end_mask attr failed."); + GELOGW("get end_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_ELLIPSIS_MASK, ellipsis_mask)) { - GELOGE(PARAM_INVALID, "get ellipsis_mask attr failed."); + GELOGW("get ellipsis_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_NEW_AXIS_MASK, new_axis_mask)) { - GELOGE(PARAM_INVALID, "get new_axis_mask attr failed."); + GELOGW("get new_axis_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK, shrink_axis_mask)) { - GELOGE(PARAM_INVALID, "get shrink_axis_mask attr failed."); + GELOGW("get shrink_axis_mask attr failed."); return PARAM_INVALID; } if ((ellipsis_mask != 0) || 
(new_axis_mask != 0)) { @@ -98,7 +98,7 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vec ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2]; ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3]; if (CheckWeight(weight0, weight1, weight2, weight3) != SUCCESS) { - GELOGE(PARAM_INVALID, "Check And Get Attr failed."); + GELOGW("Check And Get Attr failed."); return PARAM_INVALID; } @@ -168,6 +168,17 @@ void StridedSliceKernel::GetOutputDims(uint32_t dims_size, const std::vector &output_dims, const OpDescPtr attr) { + // check dim not all less than 0 + for (auto dim : output_dims) { + if (dim > 0) { + return SUCCESS; + } + } + GELOGW("all output dim <=0, can't be processed. op_name : %s", attr->GetName().c_str()); + return NOT_CHANGED; +} + Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector &input, vector &v_output) { GELOGI("StridedSliceKernel in."); @@ -191,7 +202,7 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector(weight2->GetData().data()); const int32_t *stride = reinterpret_cast(weight3->GetData().data()); if ((begin == nullptr) || (end == nullptr) || (stride == nullptr)) { - GELOGE(PARAM_INVALID, "input weight tensor is nullptr."); + GELOGW("input weight tensor is nullptr."); return NOT_CHANGED; } @@ -237,16 +248,22 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str()); + GELOGW("MakeShared GeTensor failed, node name %s.", attr->GetName().c_str()); return NOT_CHANGED; } void *data = reinterpret_cast(const_cast(weight0->GetData().data())); GE_CHECK_NOTNULL(data); + + ret = CheckOutputDims(output_dims, attr); + if (ret != SUCCESS) { + return ret; + } + ret = OpUtils::SetOutputSliceData(data, static_cast(data_size), args.data_type, input_dims, begin_vec, output_dims, output_ptr.get(), stride_vec); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed."); + GELOGW("SetOutputSliceData failed."); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/strided_slice_kernel.h b/src/ge/host_kernels/strided_slice_kernel.h index e569b2d0..0ba3afbd 100644 --- a/src/ge/host_kernels/strided_slice_kernel.h +++ b/src/ge/host_kernels/strided_slice_kernel.h @@ -44,6 +44,7 @@ class StridedSliceKernel : public Kernel { int32_t &end_i, int32_t &dim_i) const; void GetOutputDims(uint32_t dims_size, const std::vector &output_dims, const Attr &args, vector &v_dims); + Status CheckOutputDims(const std::vector &output_dims, const OpDescPtr attr); }; } // namespace ge #endif // GE_GRAPH_PASSES_FOLDING_KERNEL_STRIDED_SLICE_KERNEL_H_ diff --git a/src/ge/host_kernels/sub_kernel.cc b/src/ge/host_kernels/sub_kernel.cc index ed1e5808..70a14c9f 100644 --- a/src/ge/host_kernels/sub_kernel.cc +++ b/src/ge/host_kernels/sub_kernel.cc @@ -162,7 +162,7 @@ Status SubKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(kSubFirstOutput); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/transdata_kernel.cc b/src/ge/host_kernels/transdata_kernel.cc index 5fe44fe4..c5c9da6e 
100644 --- a/src/ge/host_kernels/transdata_kernel.cc +++ b/src/ge/host_kernels/transdata_kernel.cc @@ -113,7 +113,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); formats::TransResult trans_result; auto ret = formats::TransposeWithShapeCheck(src_data, src_shape, data_shape, src_data_type, perm_list, trans_result); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + GELOGW("Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); diff --git a/src/ge/hybrid/common/npu_memory_allocator.cc b/src/ge/hybrid/common/npu_memory_allocator.cc index f432318b..1908725f 100644 --- a/src/ge/hybrid/common/npu_memory_allocator.cc +++ b/src/ge/hybrid/common/npu_memory_allocator.cc @@ -25,6 +25,11 @@ namespace hybrid { std::map<uint32_t, std::unique_ptr<NpuMemoryAllocator>> NpuMemoryAllocator::allocators_; std::mutex NpuMemoryAllocator::mu_; +AllocationAttr::AllocationAttr(int padding, void *try_reuse_addr) + : padding_(padding), try_reuse_addr_(try_reuse_addr) {} +AllocationAttr::AllocationAttr(int padding) : AllocationAttr(padding, nullptr) {} +AllocationAttr::AllocationAttr(void *try_reuse_addr) : AllocationAttr(0, try_reuse_addr) {} + NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { int32_t device_id = 0; if (rtGetDevice(&device_id) != RT_ERROR_NONE) { @@ -38,15 +43,26 @@ NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} -void *NpuMemoryAllocator::Allocate(std::size_t size, void *try_reuse_addr) { - void *buffer = - MemManager::CachingInstance(RT_MEMORY_HBM).Malloc(size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); +void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { + void *try_reuse_addr = nullptr; + size_t allocate_size = size; + if (attr != nullptr) { + try_reuse_addr = attr->try_reuse_addr_; + if (attr->padding_ != 0) { + // round size up to a multiple of padding_, then add one extra padding_ block, e.g. size = 100, padding_ = 32 gives allocate_size = 160 + allocate_size = (size + 2 * attr->padding_ - 1) / attr->padding_ * attr->padding_; + GELOGD("Padding size %ld by %d. final size = %zu.", size, attr->padding_, allocate_size); + } + } + + void *buffer = MemManager::CachingInstance(RT_MEMORY_HBM) + .Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); if (buffer == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to malloc memory, device_id = %u, size = %zu", device_id_, size); + GELOGE(MEMALLOC_FAILED, "Failed to malloc memory, device_id = %u, size = %zu", device_id_, allocate_size); return nullptr; } - GELOGI("Allocating buffer of size %u successfully. device_id = %u, address = %p", size, device_id_, buffer); + GELOGI("Allocating buffer of size %zu successfully. 
device_id = %u, address = %p", allocate_size, device_id_, buffer); return buffer; } diff --git a/src/ge/hybrid/common/npu_memory_allocator.h b/src/ge/hybrid/common/npu_memory_allocator.h index 8cfeafa6..a9744540 100644 --- a/src/ge/hybrid/common/npu_memory_allocator.h +++ b/src/ge/hybrid/common/npu_memory_allocator.h @@ -26,16 +26,35 @@ namespace ge { namespace hybrid { +class AllocationAttr { + public: + explicit AllocationAttr(int padding); + explicit AllocationAttr(void *try_reuse_addr); + AllocationAttr(int padding, void *try_reuse_addr); + ~AllocationAttr() = default; + + private: + friend class NpuMemoryAllocator; + int padding_ = 0; + void *try_reuse_addr_ = nullptr; +}; + class NpuMemoryAllocator { public: ~NpuMemoryAllocator() = default; static NpuMemoryAllocator *GetAllocator(uint32_t device_id); static NpuMemoryAllocator *GetAllocator(); static void DestroyAllocator(); + static AllocationAttr *AttrWithDefaultPadding() { + static AllocationAttr attr(kDefaultPadding, nullptr); + return &attr; + } - void *Allocate(std::size_t size, void *try_reuse_addr = nullptr); + void *Allocate(std::size_t size, AllocationAttr *attr = nullptr); void Deallocate(void *data); + static constexpr int kDefaultPadding = 32; + private: explicit NpuMemoryAllocator(uint32_t device_id); uint32_t device_id_; diff --git a/src/ge/hybrid/common/tensor_value.cc b/src/ge/hybrid/common/tensor_value.cc index 9544e03a..929d3c87 100644 --- a/src/ge/hybrid/common/tensor_value.cc +++ b/src/ge/hybrid/common/tensor_value.cc @@ -24,7 +24,7 @@ namespace hybrid { TensorBuffer::TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size) : allocator_(allocator), buffer_(buffer), size_(size) {} -std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator, size_t size) { +std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator, size_t size, AllocationAttr *attr) { void *buffer = nullptr; if (size == 0) { GELOGD("size is 0"); @@ -36,7 +36,7 @@ std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator return nullptr; } - buffer = allocator->Allocate(size); + buffer = allocator->Allocate(size, attr); if (buffer == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to allocate memory. 
size = %zu", size); return nullptr; diff --git a/src/ge/hybrid/common/tensor_value.h b/src/ge/hybrid/common/tensor_value.h index 18e67534..db8df9e5 100644 --- a/src/ge/hybrid/common/tensor_value.h +++ b/src/ge/hybrid/common/tensor_value.h @@ -24,10 +24,12 @@ namespace ge { namespace hybrid { class NpuMemoryAllocator; +class AllocationAttr; class TensorBuffer { public: - static std::unique_ptr Create(NpuMemoryAllocator *allocator, size_t size); + static std::unique_ptr Create(NpuMemoryAllocator *allocator, size_t size, + AllocationAttr *attr = nullptr); static std::unique_ptr Create(void *buffer, size_t size); diff --git a/src/ge/hybrid/executor/hybrid_execution_context.cc b/src/ge/hybrid/executor/hybrid_execution_context.cc index bb8e0195..8144ba52 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.cc +++ b/src/ge/hybrid/executor/hybrid_execution_context.cc @@ -17,34 +17,5 @@ #include "hybrid_execution_context.h" namespace ge { -namespace hybrid { -NodeStatePtr GraphExecutionContext::GetOrCreateNodeState(const NodePtr &node) { - auto &node_state = node_states[node]; - if (node_state == nullptr) { - const NodeItem *node_item = model->GetNodeItem(node); - if (node_item == nullptr) { - return nullptr; - } - node_state.reset(new (std::nothrow) NodeState(*node_item)); - } - - return node_state; -} - -void GraphExecutionContext::OnError(Status error_code) { - GELOGE(error_code, "Error occurred while executing model"); - { - std::lock_guard lk(mu_); - this->status = error_code; - } - - compile_queue.Stop(); - execution_queue.Stop(); -} - -Status GraphExecutionContext::GetStatus() { - std::lock_guard lk(mu_); - return status; -} -} // namespace hybrid +namespace hybrid {} // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/hybrid_execution_context.h b/src/ge/hybrid/executor/hybrid_execution_context.h index 07a6fabf..96722fa9 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.h +++ b/src/ge/hybrid/executor/hybrid_execution_context.h @@ -20,6 +20,7 @@ #include #include #include "common/blocking_queue.h" +#include "framework/common/debug/ge_log.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/tensor_value.h" #include "hybrid/executor/hybrid_profiler.h" @@ -33,34 +34,26 @@ namespace hybrid { struct GraphExecutionContext { uint64_t session_id = 0; const HybridModel *model = nullptr; - NodeDoneManager cv_manager; - BlockingQueue compile_queue; - BlockingQueue execution_queue; - std::vector all_inputs; - std::vector all_outputs; - std::unordered_map node_states; rtStream_t stream = nullptr; + rtContext_t rt_context = nullptr; + rtContext_t rt_gen_context = nullptr; std::unique_ptr callback_manager; NpuMemoryAllocator *allocator = nullptr; mutable std::unique_ptr profiler; bool trace_enabled = false; - int profiling_level = 0; + long profiling_level = 0; bool dump_enabled = false; - Status status = SUCCESS; - std::mutex mu_; - - NodeStatePtr GetOrCreateNodeState(const NodePtr &node); - void OnError(Status status); - Status GetStatus(); + long iteration = 0; }; -#define RECORD_PROFILING_EVENT(context, event_type, fmt, category, node_name, ...) \ +#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) 
\ do { \ if ((context)->profiler != nullptr) { \ if (node_name != nullptr) { \ - context->profiler->RecordEvent(event_type, "[%s] [%s] " fmt, node_name, category, ##__VA_ARGS__); \ + context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GetTid(), node_name, category, \ + ##__VA_ARGS__); \ } else { \ - context->profiler->RecordEvent(event_type, "[%s] " fmt, category, ##__VA_ARGS__); \ + context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GetTid(), category, ##__VA_ARGS__); \ } \ } \ } while (0) @@ -79,7 +72,6 @@ struct GraphExecutionContext { #define RECORD_CALLBACK_EVENT(context, name, fmt, ...) \ RECORD_PROFILING_EVENT((context), HybridProfiler::CALLBACK, fmt, "Callback", name, ##__VA_ARGS__) - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.cc b/src/ge/hybrid/executor/hybrid_model_async_executor.cc index bd5d77f7..7f650017 100644 --- a/src/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -77,19 +77,18 @@ Status HybridModelAsyncExecutor::Init() { GE_CHECK_NOTNULL(data_inputer_); GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); - engine_ = std::unique_ptr(new (std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); - GE_CHECK_NOTNULL(engine_); - GE_CHK_STATUS_RET(engine_->Init(), "Failed to init hybrid engine"); - + executor_ = std::unique_ptr(new (std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); + GE_CHECK_NOTNULL(executor_); + GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); GE_CHK_STATUS_RET(InitInputTensors(), "Failed to init input tensors"); return SUCCESS; } Status HybridModelAsyncExecutor::PreRun(InputData ¤t_data) { GE_CHK_STATUS_RET(SyncVarData(), "Failed to sync var data"); - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[SyncVarData] End"); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[SyncVarData] End"); GE_CHK_STATUS_RET(CopyInputData(current_data), "Failed to copy input data to model"); - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[CopyInputData] End"); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[CopyInputData] End"); return SUCCESS; } @@ -119,21 +118,21 @@ Status HybridModelAsyncExecutor::RunInternal() { args.inputs[it.first] = it.second; } - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] Start", iterator_count_); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] Start", iterator_count_); ret = PreRun(current_data); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ret != SUCCESS, (void)HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + ret != SUCCESS, (void)HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); continue, "PreRun failed."); // [No need to check value] - ret = engine_->Execute(args); - ret = HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + ret = executor_->Execute(args); + ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); if (ret != SUCCESS) { CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue; } - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); + 
RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); iterator_count_++; GELOGI("run iterator count is %lu", iterator_count_); } @@ -143,8 +142,8 @@ Status HybridModelAsyncExecutor::RunInternal() { return SUCCESS; } -Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, - const std::vector &output_tensors, OutputData *output_data) { +Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, HybridModelExecutor::ExecuteArgs &args, + OutputData *output_data) { GELOGD("Start to handle result. model id = %u, data index = %u, execution ret = %u", model_id_, data_id, exec_ret); std::vector output_tensor_info_list; if (exec_ret == END_OF_SEQUENCE) { @@ -158,7 +157,7 @@ Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, } GE_CHECK_NOTNULL(output_data); - auto ret = CopyOutputs(output_tensors, output_data, output_tensor_info_list); + auto ret = CopyOutputs(args, output_data, output_tensor_info_list); if (ret != SUCCESS) { OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list); return INTERNAL_ERROR; @@ -215,9 +214,8 @@ Status HybridModelAsyncExecutor::CopyInputData(const InputData ¤t_data) { Status HybridModelAsyncExecutor::InitInputTensors() { auto allocator = NpuMemoryAllocator::GetAllocator(device_id_); GE_CHECK_NOTNULL(allocator); - for (const auto &it : model_->input_nodes_) { - auto input_index = it.first; - auto input_node = it.second; + int input_index = 0; + for (const auto &input_node : model_->GetRootGraphItem()->GetInputNodes()) { GELOGD("Init input[%u], node = %s", input_index, input_node->NodeName().c_str()); auto output_desc = input_node->op_desc->GetOutputDescPtr(kDataOutputIndex); GE_CHECK_NOTNULL(output_desc); @@ -235,6 +233,7 @@ Status HybridModelAsyncExecutor::InitInputTensors() { TensorValue tensor(shared_ptr(buffer.release())); tensor.SetName("Input_" + input_node->NodeName()); input_tensors_.emplace(input_index, tensor); + input_index += 1; } return SUCCESS; @@ -250,35 +249,33 @@ Status HybridModelAsyncExecutor::OnComputeDone(uint32_t data_index, uint32_t res return result_code; } -Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &output_tensors, OutputData *output_data, +Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector &outputs) { // copy output data from op to designated position - NodeItem *net_output_node = model_->net_output_node_; - GE_CHECK_NOTNULL(net_output_node); - auto all_input_desc = net_output_node->op_desc->GetAllInputsDescPtr(); - - if (all_input_desc.size() != output_tensors.size()) { + std::vector &output_tensor_desc_list = args.output_desc; + std::vector &output_tensors = args.outputs; + if (output_tensor_desc_list.size() != output_tensors.size()) { GELOGE(INTERNAL_ERROR, "Output sizes mismatch. 
From op_desc = %zu, and from output tensors = %zu", - all_input_desc.size(), output_tensors.size()); + output_tensor_desc_list.size(), output_tensors.size()); return INTERNAL_ERROR; } - GELOGD("Number of outputs = %zu", all_input_desc.size()); + GELOGD("Number of outputs = %zu", output_tensor_desc_list.size()); for (size_t i = 0; i < output_tensors.size(); ++i) { GELOGD("Start to process output[%zu]", i); auto &output_tensor = output_tensors[i]; - auto &tensor_desc = all_input_desc.at(i); + auto &tensor_desc = output_tensor_desc_list.at(i); GE_CHECK_NOTNULL(tensor_desc); int64_t output_size = -1; - GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->MutableShape(), tensor_desc->GetFormat(), + GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->GetShape(), tensor_desc->GetFormat(), tensor_desc->GetDataType(), output_size), "Failed to calc tensor size for output[%zu]. shape = [%s], type = %s, format = %s", i, - tensor_desc->MutableShape().ToString().c_str(), + tensor_desc->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str()); GELOGD("Got tensor size for output[%zu] successfully. shape = [%s], type = %s, format = %s, size = %ld", i, - tensor_desc->MutableShape().ToString().c_str(), + tensor_desc->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str(), output_size); @@ -286,7 +283,7 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out GE_CHECK_LE(output_size, UINT32_MAX); if (output_tensor.GetSize() < static_cast(output_size)) { GELOGE(INTERNAL_ERROR, "output[%zu] tensor size(%zu) is not enough for output shape [%s]", i, - output_tensor.GetSize(), tensor_desc->MutableShape().ToString().c_str()); + output_tensor.GetSize(), tensor_desc->GetShape().ToString().c_str()); return INTERNAL_ERROR; } @@ -302,7 +299,7 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out output.data = std::move(data_buf); output_data->blobs.emplace_back(data_buf.get(), static_cast(output_size), false); } else { - GELOGW("Output[%zu] is empty. shape = [%s]", i, tensor_desc->MutableShape().ToString().c_str()); + GELOGW("Output[%zu] is empty. 
shape = [%s]", i, tensor_desc->GetShape().ToString().c_str()); output.data = nullptr; output_data->blobs.emplace_back(nullptr, 0U, false); } @@ -310,7 +307,53 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out outputs.emplace_back(std::move(output)); GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld", i, TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), - tensor_desc->MutableShape().ToString().c_str(), output_size); + tensor_desc->GetShape().ToString().c_str(), output_size); + } + + return SUCCESS; +} + +Status HybridModelAsyncExecutor::Execute(const vector &inputs, vector &outputs) { + GELOGD("Start to execute model."); + // prepare inputs + InputData input_data; + for (auto &tensor : inputs) { + DataBuffer buffer; + buffer.data = const_cast(tensor.GetData().GetData()); + buffer.length = tensor.GetData().size(); + input_data.blobs.emplace_back(buffer); + } + GE_CHK_STATUS_RET(CopyInputData(input_data), "Failed to copy input data to model"); + GELOGD("Done copying input data successfully."); + + HybridModelExecutor::ExecuteArgs args; + args.inputs.resize(input_tensors_.size()); + args.input_desc.resize(input_tensors_.size()); + for (auto &it : input_tensors_) { + args.inputs[it.first] = it.second; + args.input_desc[it.first] = MakeShared(inputs[it.first].GetTensorDesc()); + } + + GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); + + std::vector output_tensor_info_list; + OutputData output_data; + GE_CHK_STATUS_RET(CopyOutputs(args, &output_data, output_tensor_info_list), "Failed to copy outputs."); + GELOGD("Done copying output data successfully. output count = %zu", output_tensor_info_list.size()); + + int out_index = 0; + outputs.resize(output_tensor_info_list.size()); + for (auto &out_tensor_info : output_tensor_info_list) { + auto &ge_tensor = outputs[out_index]; + if (out_tensor_info.length > 0) { + GE_CHK_GRAPH_STATUS_RET(ge_tensor.SetData(out_tensor_info.data.get(), out_tensor_info.length), + "Failed to set output[%d].", out_index); + } + + ge_tensor.MutableTensorDesc() = *args.output_desc[out_index]; + GELOGD("Set output[%d], tensor size = %ld, shape = [%s]", out_index, out_tensor_info.length, + ge_tensor.MutableTensorDesc().MutableShape().ToString().c_str()); + ++out_index; } return SUCCESS; diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.h b/src/ge/hybrid/executor/hybrid_model_async_executor.h index cb440ba7..195f79a9 100644 --- a/src/ge/hybrid/executor/hybrid_model_async_executor.h +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.h @@ -35,6 +35,8 @@ class HybridModelAsyncExecutor { Status Init(); + Status Execute(const vector &inputs, vector &outputs); + Status Start(const std::shared_ptr &listener); void SetDeviceId(uint32_t device_id); @@ -52,10 +54,10 @@ class HybridModelAsyncExecutor { Status SyncVarData(); - Status HandleResult(Status exec_ret, uint32_t data_id, const std::vector &output_tensors, + Status HandleResult(Status exec_ret, uint32_t data_id, HybridModelExecutor::ExecuteArgs &args, OutputData *output_data); - Status CopyOutputs(const std::vector &output_tensors, OutputData *output_data, + Status CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector &outputs); Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector &outputs); @@ -70,7 +72,7 @@ class HybridModelAsyncExecutor { uint32_t model_id_ = 0U; std::atomic_bool run_flag_; std::unique_ptr data_inputer_; - std::unique_ptr engine_; + std::unique_ptr 
executor_; std::future future_; uint64_t iterator_count_ = 0; diff --git a/src/ge/hybrid/executor/hybrid_model_executor.cc b/src/ge/hybrid/executor/hybrid_model_executor.cc index 856b4483..d62d7be3 100644 --- a/src/ge/hybrid/executor/hybrid_model_executor.cc +++ b/src/ge/hybrid/executor/hybrid_model_executor.cc @@ -26,17 +26,17 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, Status HybridModelExecutor::Init() { GELOGD("Start to init HybridGraphEngine."); GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); - infer_shape_engine_.reset(new (std::nothrow) ShapeInferenceEngine(&context_)); - compile_engine_.reset(new (std::nothrow) TaskCompileEngine(&context_)); - execute_engine_.reset(new (std::nothrow) ExecutionEngine(&context_, context_.callback_manager.get())); - GE_CHK_STATUS_RET_NOLOG(compile_engine_->Init()); GELOGD("HybridGraphEngine initialized successfully."); return SUCCESS; } Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { GELOGD("Start to execute model."); - auto ret = ExecuteGraphInternal(args); + auto root_graph_item = model_->GetRootGraphItem(); + GE_CHECK_NOTNULL(root_graph_item); + + SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); + auto ret = ExecuteGraphInternal(executor, args); Cleanup(); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); GE_CHK_STATUS_RET(ret, "Failed to execute model"); @@ -46,24 +46,22 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { context_.profiler->Reset(); } + context_.iteration += 1; return SUCCESS; } -Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { +Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, HybridModelExecutor::ExecuteArgs &args) { RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); - GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs(args, context_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitInputsAndOutputs] End"); - GE_CHK_STATUS_RET_NOLOG(compile_engine_->Start(pool_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[CompileProcess] Started"); - GE_CHK_STATUS_RET_NOLOG(infer_shape_engine_->Start(pool_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[InferShapeProcess] Started"); - GE_CHK_STATUS_RET(execute_engine_->Start(), "Run execution engine failed."); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecutionProcess] End"); - GE_CHK_STATUS_RET_NOLOG(Synchronize()); + + GE_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc), "Failed to execute partitioned call."); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); + + GE_CHK_STATUS_RET(executor.Synchronize(), "Failed to sync root graph."); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); - GE_CHK_STATUS_RET_NOLOG(GetOutput(args)); + + GE_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); return SUCCESS; } @@ -71,18 +69,16 @@ Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArg Status HybridModelExecutor::Cleanup() { GELOGD("Start to cleanup."); context_.callback_manager->Destroy(); - context_.cv_manager.Reset(); - context_.node_states.clear(); - context_.all_inputs.clear(); - context_.all_outputs.clear(); - context_.compile_queue.Clear(); - context_.execution_queue.Clear(); 
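Taken together, the executor changes above replace the per-model compile/infer-shape/execution worker engines with a SubgraphExecutor built per iteration over the root GraphItem, and ExecuteArgs now carries tensor descriptions in both directions. A minimal caller-side sketch, assuming a loaded HybridModel and a valid rtStream_t; RunOneIteration is a hypothetical helper, not part of this patch, and the desc element type follows the ExecuteArgs fields restored above:

Status RunOneIteration(HybridModel *model, uint32_t device_id, rtStream_t stream,
                       std::vector<TensorValue> inputs, std::vector<ConstGeTensorDescPtr> input_desc) {
  HybridModelExecutor executor(model, device_id, stream);
  GE_CHK_STATUS_RET(executor.Init(), "Failed to init executor");
  HybridModelExecutor::ExecuteArgs args;
  args.inputs = std::move(inputs);          // one TensorValue per root-graph input
  args.input_desc = std::move(input_desc);  // matching descs, consumed by shape propagation
  // Execute() builds a SubgraphExecutor over the root GraphItem, runs it,
  // then fills args.outputs and args.output_desc for the caller.
  GE_CHK_STATUS_RET(executor.Execute(args), "Failed to execute model");
  return SUCCESS;
}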
RuntimeInferenceContext::DestroyContext(to_string(context_.session_id)); GELOGD("Cleanup successfully."); return SUCCESS; } Status HybridModelExecutor::InitExecutionContext() { + GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); + GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); + GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); + context_.stream = stream_; context_.model = model_; context_.session_id = ::ge::GetContext().SessionId(); @@ -94,78 +90,15 @@ Status HybridModelExecutor::InitExecutionContext() { if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { context_.trace_enabled = true; } - return SUCCESS; } Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { - auto &model = *context.model; - context.all_inputs.resize(model.TotalInputs()); - context.all_outputs.resize(model.TotalOutputs()); - context.compile_queue.Restart(); - context.execution_queue.Restart(); GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); - - for (auto const_node : model.GetConstNodes()) { - auto weight_tensor = model.GetWeight(const_node); - GE_CHECK_NOTNULL(weight_tensor); - for (auto &dst_aid_and_nid : const_node->outputs[0]) { - auto *dst_node_item = dst_aid_and_nid.second; - auto input_offset = dst_node_item->input_start + dst_aid_and_nid.first; - context.all_inputs[input_offset] = *weight_tensor; - } - } - string ctx_id = std::to_string(context.session_id); RuntimeInferenceContext::DestroyContext(ctx_id); GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); return SUCCESS; } - -Status HybridModelExecutor::InitInputsAndOutputs(HybridModelExecutor::ExecuteArgs &args, - GraphExecutionContext &context) { - for (const auto &it : model_->GetInputNodes()) { - uint32_t input_index = it.first; - if (input_index >= args.inputs.size()) { - GELOGE(PARAM_INVALID, "Not enough inputs. NumInputs = %zu, but input index = %u", args.inputs.size(), - input_index); - return PARAM_INVALID; - } - - auto node_item = it.second; - auto &input_tensor = args.inputs[input_index]; - GELOGD("Set input tensor[%u] to inputs with index = %d, addr = %p, size = %zu", input_index, node_item->input_start, - input_tensor.GetData(), input_tensor.GetSize()); - context.all_inputs[node_item->input_start] = input_tensor; - } - - for (size_t i = 0; i < model_->GetOutputOffsets().size(); ++i) { - auto offset = model_->GetOutputOffsets()[i]; - if (i < args.outputs.size() && args.outputs[i].GetData() != nullptr) { - GELOGD("Use user allocated output memory. 
output index = %zu, output offset = %d", i, offset); - context.all_outputs[offset] = args.outputs[i]; - } - } - - return SUCCESS; -} - -Status HybridModelExecutor::Synchronize() { - GE_CHK_RT_RET(rtStreamSynchronize(stream_)); - return SUCCESS; -} - -Status HybridModelExecutor::GetOutput(HybridModelExecutor::ExecuteArgs &args) { - auto &net_output_input_offsets = model_->GetNetOutputInputOffsets(); - auto num_outputs = net_output_input_offsets.size(); - args.outputs.resize(num_outputs); - for (size_t i = 0; i < num_outputs; ++i) { - auto offset = net_output_input_offsets[i]; - GELOGI("Get output[%zu] from offset %d", i, offset); - args.outputs[i] = context_.all_inputs[offset]; - } - - return SUCCESS; -} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_model_executor.h b/src/ge/hybrid/executor/hybrid_model_executor.h index 2bda6331..9996dbe0 100644 --- a/src/ge/hybrid/executor/hybrid_model_executor.h +++ b/src/ge/hybrid/executor/hybrid_model_executor.h @@ -20,9 +20,7 @@ #include "graph/load/new_model_manager/data_inputer.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/rt_callback_manager.h" -#include "hybrid/executor/worker/execution_engine.h" -#include "hybrid/executor/worker/shape_inference_engine.h" -#include "hybrid/executor/worker/task_compile_engine.h" +#include "hybrid/executor/subgraph_executor.h" namespace ge { namespace hybrid { @@ -30,7 +28,9 @@ class HybridModelExecutor { public: struct ExecuteArgs { std::vector inputs; + std::vector input_desc; std::vector outputs; + std::vector output_desc; }; HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); @@ -44,24 +44,15 @@ class HybridModelExecutor { Status Execute(ExecuteArgs &args); private: - Status ExecuteGraphInternal(ExecuteArgs &args); + Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); Status Cleanup(); Status InitExecutionContext(); static Status ResetExecutionContext(GraphExecutionContext &context); - Status InitInputsAndOutputs(ExecuteArgs &args, GraphExecutionContext &context); - Status GetOutput(ExecuteArgs &args); - - Status Synchronize(); - - ThreadPool pool_; HybridModel *model_; uint32_t device_id_; rtStream_t stream_; GraphExecutionContext context_; - std::unique_ptr infer_shape_engine_; - std::unique_ptr compile_engine_; - std::unique_ptr execute_engine_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_profiler.cc b/src/ge/hybrid/executor/hybrid_profiler.cc index 1081a144..4c70e043 100644 --- a/src/ge/hybrid/executor/hybrid_profiler.cc +++ b/src/ge/hybrid/executor/hybrid_profiler.cc @@ -59,11 +59,10 @@ void HybridProfiler::Dump(std::ostream &output_stream) { auto first_evt = events_[0]; auto start = first_evt.timestamp; - output_stream << "Start " << first_evt.desc << std::endl; std::vector prev_timestamps; prev_timestamps.resize(kMaxEventTypes, start); - for (int i = 1; i < counter_; ++i) { + for (int i = 0; i < counter_; ++i) { auto &evt = events_[i]; auto elapsed = std::chrono::duration_cast(evt.timestamp - start).count(); auto &prev_ts = prev_timestamps[evt.event_type]; diff --git a/src/ge/hybrid/executor/node_done_manager.cc b/src/ge/hybrid/executor/node_done_manager.cc index dfeddb5b..3ec45339 100644 --- a/src/ge/hybrid/executor/node_done_manager.cc +++ b/src/ge/hybrid/executor/node_done_manager.cc @@ -15,35 +15,49 @@ */ #include "hybrid/executor/node_done_manager.h" +#include #include "framework/common/debug/ge_log.h" namespace ge { namespace hybrid { 
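The hunk below replaces the unbounded condition-variable wait in NodeDoneManager with a timed one, so a node whose producer never signals fails after the timeout instead of hanging the executor. A self-contained sketch of the same pattern, with placeholder names and no claim to match the real class layout:

#include <chrono>
#include <condition_variable>
#include <mutex>

class Latch {
 public:
  // Returns true only if released; false on cancellation or timeout.
  bool AwaitFor(std::chrono::seconds timeout) {
    std::unique_lock<std::mutex> lk(mu_);
    if (!cv_.wait_for(lk, timeout, [&] { return released_ || cancelled_; })) {
      return false;  // timed out: neither released nor cancelled
    }
    return released_;
  }
  void Release() {
    { std::lock_guard<std::mutex> lk(mu_); released_ = true; }
    cv_.notify_all();
  }
  void Cancel() {
    { std::lock_guard<std::mutex> lk(mu_); cancelled_ = true; }
    cv_.notify_all();
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool released_ = false;
  bool cancelled_ = false;
};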
+namespace { +constexpr int kDefaultWaitTimeoutInSec = 10; +} bool NodeDoneManager::Cond::Await() { - std::unique_lock<std::mutex> lk(mu_); - cv_.wait(lk, [&]() { return is_released_ || is_cancelled_; }); + std::unique_lock<std::mutex> lk(cond_mu_); + if (!cv_.wait_for(lk, std::chrono::seconds(kDefaultWaitTimeoutInSec), + [&]() { return is_released_ || is_cancelled_; })) { + GELOGE(INTERNAL_ERROR, "Wait timed out."); + return false; + } + + return is_released_; } void NodeDoneManager::Cond::Release() { - std::unique_lock<std::mutex> lk(mu_); + std::unique_lock<std::mutex> lk(cond_mu_); is_released_ = true; cv_.notify_all(); } void NodeDoneManager::Cond::Cancel() { - std::unique_lock<std::mutex> lk(mu_); + std::unique_lock<std::mutex> lk(cond_mu_); is_cancelled_ = true; cv_.notify_all(); } bool NodeDoneManager::Cond::IsRelease() { - std::unique_lock<std::mutex> lk(mu_); + std::unique_lock<std::mutex> lk(cond_mu_); return is_released_; } NodeDoneManager::Cond *NodeDoneManager::GetSubject(const NodePtr &node) { std::lock_guard<std::mutex> lk(mu_); + if (destroyed_) { + GELOGD("Already destroyed."); + return nullptr; + } + auto it = subjects_.find(node); if (it == subjects_.end()) { return &subjects_[node]; @@ -52,8 +66,10 @@ NodeDoneManager::Cond *NodeDoneManager::GetSubject(const NodePtr &node) { return &it->second; } -void NodeDoneManager::Reset() { +void NodeDoneManager::Destroy() { + GELOGD("Start to destroy NodeDoneManager."); std::lock_guard<std::mutex> lk(mu_); + GELOGD("Cond size = %zu.", subjects_.size()); for (auto &sub : subjects_) { if (!sub.second.IsRelease()) { sub.second.Cancel(); @@ -62,15 +78,24 @@ void NodeDoneManager::Reset() { subjects_.clear(); + destroyed_ = true; + GELOGD("Done destroying NodeDoneManager successfully."); } void NodeDoneManager::NodeDone(const NodePtr &node) { - GetSubject(node)->Release(); - GELOGD("[%s] Node released.", node->GetName().c_str()); + auto sub = GetSubject(node); + if (sub != nullptr) { + sub->Release(); + GELOGD("[%s] Node released.", node->GetName().c_str()); + } } bool NodeDoneManager::Await(const NodePtr &node) { auto sub = GetSubject(node); + if (sub == nullptr) { + return false; + } + GELOGD("[%s] Await start. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false"); bool ret = sub->Await(); GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? 
"true" : "false"); diff --git a/src/ge/hybrid/executor/node_done_manager.h b/src/ge/hybrid/executor/node_done_manager.h index ccf263d1..f1fdfbec 100644 --- a/src/ge/hybrid/executor/node_done_manager.h +++ b/src/ge/hybrid/executor/node_done_manager.h @@ -31,7 +31,7 @@ class NodeDoneManager { bool Await(const NodePtr &node); - void Reset(); + void Destroy(); private: class Cond { @@ -42,7 +42,7 @@ class NodeDoneManager { bool Await(); private: - std::mutex mu_; + std::mutex cond_mu_; std::condition_variable cv_; bool is_released_ = false; bool is_cancelled_ = false; @@ -51,6 +51,7 @@ class NodeDoneManager { Cond *GetSubject(const NodePtr &node); std::mutex mu_; std::unordered_map subjects_; + bool destroyed_ = false; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/node_state.cc b/src/ge/hybrid/executor/node_state.cc index 6895f158..5368597d 100644 --- a/src/ge/hybrid/executor/node_state.cc +++ b/src/ge/hybrid/executor/node_state.cc @@ -15,13 +15,133 @@ */ #include "hybrid/executor/node_state.h" +#include +#include "framework/common/debug/log.h" #include "graph/compute_graph.h" +#include "hybrid_execution_context.h" +#include "subgraph_context.h" namespace ge { namespace hybrid { -NodeState::NodeState(const NodeItem &node_item) { - this->node_item = &node_item; - this->op_desc = node_item.node->GetOpDesc(); +namespace { +constexpr auto kMaxWaitTimeInSec = 10; +} +ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(node_item) { + this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes; + GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", node_item.NodeName().c_str(), + this->num_pending_shapes_); +} + +void ShapeInferenceState::UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape) { + if (node_item.is_input_shape_static[idx]) { + GELOGD("[%s] Trying to update static shape, idx = %u. 
old shape = [%s], new shape = [%s]", + node_item.NodeName().c_str(), idx, node_item.op_desc->MutableInputDesc(idx)->GetShape().ToString().c_str(), + shape.ToString().c_str()); + return; + } + + GELOGD("[%s] Update input shape [%u] with Shape: [%s] and OriginalShape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + + std::lock_guard lk(mu_); + node_item.op_desc->MutableInputDesc(idx)->SetShape(shape); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); + if (--num_pending_shapes_ == 0) { + ready_cv_.notify_all(); + } +} + +void ShapeInferenceState::UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future) { + if (node_item.is_input_shape_static[idx]) { + GELOGD("[%s] Trying to update constant shape, idx = %u", node_item.NodeName().c_str(), idx); + return; + } + + GELOGD("[%s] Update input shape [%u] with ShapeFuture.", node_item.NodeName().c_str(), idx); + std::lock_guard lk(mu_); + shape_futures.emplace_back(idx, std::move(future)); + if (--num_pending_shapes_ == 0) { + ready_cv_.notify_all(); + } +} + +Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { + std::unique_lock lk(mu_); + if (num_pending_shapes_ > 0) { + GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); + if (!ready_cv_.wait_for(lk, std::chrono::seconds(kMaxWaitTimeInSec), [&]() { return num_pending_shapes_ == 0; })) { + GELOGE(INTERNAL_ERROR, "[%s] Wait for shape timeout.", node_item.NodeName().c_str()); + return INTERNAL_ERROR; + } + GELOGD("[%s] Await pending shape or shape future end.", node_item.NodeName().c_str()); + } + + for (auto &p : shape_futures) { + auto idx = p.first; + auto &future = p.second; + GeShape shape; + GeShape ori_shape; + RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); + GE_CHK_STATUS_RET(future.Get(ori_shape, shape), "[%s] Get shape failed. 
index = %u", node_item.NodeName().c_str(), + idx); + RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); + + GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + node_item.op_desc->MutableInputDesc(idx)->SetShape(std::move(shape)); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); + } + + return SUCCESS; +} + +ShapeFuture::ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context) + : src_node_(std::move(src_node)), src_index_(src_index), subgraph_context_(subgraph_context) {} + +NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) + : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { + this->op_desc_ = node_item.node->GetOpDesc(); +} + +Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { + for (auto &src_node : node_item_->dependents_for_execution) { + GELOGI("[%s] Start to wait for data dependent node: [%s]", node_item_->NodeName().c_str(), + src_node->GetName().c_str()); + RECORD_EXECUTION_EVENT(&context, node_item_->NodeName().c_str(), "[AwaitNodeDone] [%s] Start", + src_node->GetName().c_str()); + if (!subgraph_context_->Await(src_node)) { + GELOGE(INTERNAL_ERROR, "[%s] Await node [%s] failed.", GetName().c_str(), src_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + RECORD_EXECUTION_EVENT(&context, node_item_->NodeName().c_str(), "[AwaitNodeDone] [%s] End", + src_node->GetName().c_str()); + GELOGI("[%s] Done waiting node.", src_node->GetName().c_str()); + } + + return SUCCESS; +} + +Status NodeState::WaitForPrepareDone() { + if (prepare_future_.valid()) { + GELOGD("[%s] Start to wait for prepare future.", GetName().c_str()); + GE_CHK_STATUS_RET(prepare_future_.get(), "[%s] PreRun failed.", GetName().c_str()); + } + + return SUCCESS; +} + +Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { + GELOGI("Start to wait for node %s to get shape", src_node_->GetName().c_str()); + if (!subgraph_context_->Await(src_node_)) { + GELOGE(INTERNAL_ERROR, "cancelled"); + return INTERNAL_ERROR; + } + + shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); + ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); + GELOGI("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); + return SUCCESS; } } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/node_state.h b/src/ge/hybrid/executor/node_state.h index b2811bcb..73e0f75c 100644 --- a/src/ge/hybrid/executor/node_state.h +++ b/src/ge/hybrid/executor/node_state.h @@ -17,38 +17,83 @@ #ifndef GE_HYBRID_EXECUTOR_NODE_STATE_H_ #define GE_HYBRID_EXECUTOR_NODE_STATE_H_ +#include <condition_variable> +#include <future> +#include <mutex> +#include "external/ge/ge_api_error_codes.h" #include "hybrid/model/node_item.h" +#include "node_done_manager.h" namespace ge { namespace hybrid { - class NodeTask; +class GraphExecutionContext; +class SubgraphContext; -// Stores some state that changes during execution... 
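ShapeInferenceState, added above, is at heart a counting latch over a node's dynamic inputs: construction takes the number of inputs whose shapes are still unknown, each arriving shape or ShapeFuture decrements the counter, and AwaitShapesReady blocks until it reaches zero or times out. A stripped-down illustration with placeholder names; the real class also records the shapes and futures themselves and applies the timeout shown in the hunk:

#include <condition_variable>
#include <mutex>

class PendingShapeLatch {
 public:
  explicit PendingShapeLatch(int num_dynamic_inputs) : pending_(num_dynamic_inputs) {}
  // Called once per dynamic input when its shape becomes known.
  void OnShapeReady() {
    std::lock_guard<std::mutex> lk(mu_);
    if (--pending_ == 0) {
      ready_cv_.notify_all();
    }
  }
  // Blocks until every dynamic input has reported a shape.
  void Await() {
    std::unique_lock<std::mutex> lk(mu_);
    ready_cv_.wait(lk, [&] { return pending_ <= 0; });
  }
 private:
  int pending_;
  std::mutex mu_;
  std::condition_variable ready_cv_;
};

The removal that continues below deletes the old NodeState fields this state object replaces.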
-class NodeState { +class ShapeFuture { public: - NodeState() = default; - explicit NodeState(const NodeItem &node_item); + ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); + ~ShapeFuture() = default; + Status Get(GeShape &ori_shape, GeShape &shape); + + private: + NodePtr src_node_; + uint32_t src_index_; + SubgraphContext *subgraph_context_; +}; + +struct ShapeInferenceState { + explicit ShapeInferenceState(const NodeItem &node_item); + + void UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape); + + void UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future); + + Status AwaitShapesReady(const GraphExecutionContext &context); + + const NodeItem &node_item; + + private: + std::vector> shape_futures; + int num_pending_shapes_ = 0; + std::condition_variable ready_cv_; + std::mutex mu_; +}; + +// saving sth. dynamic during execution +struct NodeState { + public: + NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); ~NodeState() = default; - inline int NodeId() const { return node_item->node_id; } + OpDesc *GetOpDesc() const { return op_desc_.get(); } + + inline const NodeItem *GetNodeItem() const { return node_item_; } + + inline const string &GetName() const { return node_item_->NodeName(); } + + inline const string &GetType() const { return node_item_->NodeType(); } - inline Node *GetNode() const { return node_item->node.get(); } + ShapeInferenceState &GetShapeInferenceState() { return shape_inference_state_; } - OpDesc *GetOpDesc() const { return op_desc.get(); } + const shared_ptr &GetKernelTask() const { return kernel_task_; } - inline const NodeItem *GetNodeItem() const { return node_item; } + void SetKernelTask(const shared_ptr &kernel_task) { kernel_task_ = kernel_task; } - inline const string &GetName() const { return node_item->NodeName(); } + Status WaitForPrepareDone(); - inline const string &GetType() const { return node_item->NodeType(); } + void SetPrepareFuture(std::future &&prepare_future) { this->prepare_future_ = std::move(prepare_future); } - // private: - const NodeItem *node_item = nullptr; - std::shared_ptr kernel_task = nullptr; + Status AwaitInputTensors(GraphExecutionContext &context) const; - bool is_compiled = false; - OpDescPtr op_desc; + private: + const NodeItem *node_item_ = nullptr; + std::shared_ptr kernel_task_ = nullptr; + std::future prepare_future_; + OpDescPtr op_desc_; + ShapeInferenceState shape_inference_state_; + SubgraphContext *subgraph_context_; + std::mutex mu_; }; using NodeStatePtr = std::shared_ptr; diff --git a/src/ge/hybrid/executor/rt_callback_manager.cc b/src/ge/hybrid/executor/rt_callback_manager.cc index 6be8da31..c1c98f73 100644 --- a/src/ge/hybrid/executor/rt_callback_manager.cc +++ b/src/ge/hybrid/executor/rt_callback_manager.cc @@ -42,7 +42,6 @@ Status CallbackManager::Init() { rtContext_t ctx = nullptr; GE_CHK_RT_RET(rtCtxGetCurrent(&ctx)); ret_future_ = std::async([&](rtContext_t context) -> Status { return CallbackProcess(context); }, ctx); - if (!ret_future_.valid()) { GELOGE(INTERNAL_ERROR, "Failed to init callback manager."); return INTERNAL_ERROR; diff --git a/src/ge/hybrid/executor/subgraph_context.cc b/src/ge/hybrid/executor/subgraph_context.cc new file mode 100644 index 00000000..d5d9075d --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_context.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "subgraph_context.h" + +#include "common/debug/log.h" + +namespace ge { +namespace hybrid { +SubgraphContext::SubgraphContext(const GraphItem *graph_item) : graph_item_(graph_item) {} + +Status SubgraphContext::Init() { + GE_CHECK_NOTNULL(graph_item_); + GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d", graph_item_->GetName().c_str(), + graph_item_->TotalInputs(), graph_item_->TotalOutputs()); + all_inputs_.resize(graph_item_->TotalInputs()); + all_outputs_.resize(graph_item_->TotalOutputs()); + + return SUCCESS; +} + +NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { + std::lock_guard<std::mutex> lk(mu_); + auto &node_state = node_states_[node_item]; + if (node_state == nullptr) { + node_state.reset(new (std::nothrow) NodeState(*node_item, this)); + } + + return node_state; +} + +Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { + if (static_cast<size_t>(index) >= all_inputs_.size()) { + GELOGE(INTERNAL_ERROR, "input index out of range. all input num = %zu, input index = %d", all_inputs_.size(), + index); + return INTERNAL_ERROR; + } + + all_inputs_[index] = tensor; + return SUCCESS; +} + +Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor) { + auto index = node_item.input_start + input_index; + return SetInput(index, tensor); +} + +Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) { + auto index = node_item.output_start + output_index; + if (output_index >= node_item.num_outputs || static_cast<size_t>(index) >= all_outputs_.size()) { + GELOGE(INTERNAL_ERROR, "output index out of range. 
all output num = %zu, node_item = %s, output index = %d", + all_outputs_.size(), node_item.DebugString().c_str(), output_index); + return INTERNAL_ERROR; + } + + all_outputs_[index] = tensor; + return SUCCESS; +} + +Status SubgraphContext::GetInput(int index, TensorValue &tensor) { + GE_CHECK_GE(all_inputs_.size(), index + 1U); + tensor = all_inputs_[index]; + return SUCCESS; +} + +Status SubgraphContext::GetOutputs(std::vector &outputs) { + if (graph_item_->IsDynamic()) { + GELOGD("[%s] graph is dynamic, get outputs from net output input tensors", graph_item_->GetName().c_str()); + // get from net output inputs + auto output_node = graph_item_->GetOutputNode(); + GE_CHECK_NOTNULL(output_node); + for (int i = 0; i < output_node->num_inputs; ++i) { + TensorValue tensor; + GE_CHK_STATUS_RET_NOLOG(GetInput(output_node->input_start + i, tensor)); + GELOGD("[%s] Adding output tensor by input index [%d], tensor = %s", graph_item_->GetName().c_str(), + output_node->input_start + i, tensor.DebugString().c_str()); + outputs.emplace_back(std::move(tensor)); + } + } else { + GELOGD("[%s] graph is non-dynamic, get outputs from subgraph outputs", graph_item_->GetName().c_str()); + for (auto &tensor : all_outputs_) { + GELOGD("[%s] Adding output tensor: %s", graph_item_->GetName().c_str(), tensor.DebugString().c_str()); + outputs.emplace_back(tensor); + } + } + + return SUCCESS; +} + +bool SubgraphContext::Await(const NodePtr &node) { return node_done_manager_.Await(node); } + +void SubgraphContext::OnError(Status error) { + GELOGE(error, "[%s] Error occurred while executing graph.", graph_item_->GetName().c_str()); + node_done_manager_.Destroy(); +} + +void SubgraphContext::NodeDone(const NodePtr &node) { node_done_manager_.NodeDone(node); } +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/executor/subgraph_context.h b/src/ge/hybrid/executor/subgraph_context.h new file mode 100644 index 00000000..fd934d80 --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_context.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ +#define GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ + +#include + +#include "hybrid/common/tensor_value.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/node_done_manager.h" +#include "hybrid/model/graph_item.h" +#include "hybrid/model/node_item.h" + +namespace ge { +namespace hybrid { +class SubgraphContext { + public: + explicit SubgraphContext(const GraphItem *graph_item); + ~SubgraphContext() = default; + + Status Init(); + NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); + + void OnError(Status error); + + Status SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor); + Status SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor); + Status SetInput(int index, const TensorValue &tensor); + Status GetInput(int index, TensorValue &tensor); + Status GetOutputs(std::vector &outputs); + + bool Await(const NodePtr &node); + void NodeDone(const NodePtr &node); + + private: + friend class TaskContext; + const GraphItem *graph_item_; + std::mutex mu_; + std::vector all_inputs_; + std::vector all_outputs_; + NodeDoneManager node_done_manager_; + std::unordered_map node_states_; +}; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ diff --git a/src/ge/hybrid/executor/subgraph_executor.cc b/src/ge/hybrid/executor/subgraph_executor.cc new file mode 100644 index 00000000..3d699970 --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_executor.cc @@ -0,0 +1,373 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "hybrid/executor/subgraph_executor.h" +#include "hybrid/executor/worker/task_compile_engine.h" +#include "hybrid/executor/worker/execution_engine.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +namespace { +constexpr int kDefaultThreadNum = 4; +constexpr int kDataInputIndex = 0; +} // namespace + +SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape) + : graph_item_(graph_item), + context_(context), + force_infer_shape_(force_infer_shape), + pre_run_pool_(kDefaultThreadNum) {} + +SubgraphExecutor::~SubgraphExecutor() { GELOGD("[%s] SubgraphExecutor destroyed.", graph_item_->GetName().c_str()); } + +Status SubgraphExecutor::Init(const std::vector &inputs, + const std::vector &input_desc) { + subgraph_context_.reset(new (std::nothrow) SubgraphContext(graph_item_)); + GE_CHECK_NOTNULL(subgraph_context_); + GE_CHK_STATUS_RET(subgraph_context_->Init(), "[%s] Failed to init subgraph context.", graph_item_->GetName().c_str()); + + shape_inference_engine_.reset(new (std::nothrow) ShapeInferenceEngine(context_, subgraph_context_.get())); + GE_CHECK_NOTNULL(shape_inference_engine_); + + if (graph_item_->IsDynamic()) { + GE_CHK_STATUS_RET(InitInputsForUnknownShape(inputs, input_desc), "[%s] Failed to set inputs.", + graph_item_->GetName().c_str()); + } else { + GE_CHK_STATUS_RET(InitInputsForKnownShape(inputs), + "[%s] Failed to init subgraph executor for known shape subgraph.", + graph_item_->GetName().c_str()); + } + + return SUCCESS; +} + +Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector &inputs, + const std::vector &input_desc) { + // Number of inputs of parent node should be greater or equal than that of subgraph + auto input_nodes = graph_item_->GetInputNodes(); + if (inputs.size() < input_nodes.size()) { + GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for subgraph which needs [%zu] inputs.", + graph_item_->GetName().c_str(), inputs.size(), input_nodes.size()); + return INTERNAL_ERROR; + } + + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto &input_node = input_nodes[i]; + if (input_node == nullptr) { + GELOGD("[%s] Input[%zu] is not needed by subgraph, skip it.", graph_item_->GetName().c_str(), i); + continue; + } + + auto &input_tensor = inputs[i]; + GELOGD("[%s] Set input tensor[%zu] to inputs with index = %d, tensor = %s", graph_item_->GetName().c_str(), i, + input_node->input_start, input_tensor.DebugString().c_str()); + + GE_CHK_STATUS_RET(subgraph_context_->SetInput(*input_node, kDataInputIndex, input_tensor), + "[%s] Failed to set input tensor[%zu]", graph_item_->GetName().c_str(), i); + + if (force_infer_shape_ || input_node->is_dynamic) { + GELOGD("[%s] Start to update input[%zu] for subgraph data node.", graph_item_->GetName().c_str(), i); + GE_CHECK_LE(i + 1, input_desc.size()); + const auto &tensor_desc = input_desc[i]; + auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); + GE_CHECK_NOTNULL(node_state); + node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape()); + } + } + + GELOGD("[%s] Done setting inputs.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::InitInputsForKnownShape(const std::vector &inputs) { + auto &input_index_mapping = graph_item_->GetInputIndexMapping(); + for (size_t i = 0; i < input_index_mapping.size(); ++i) { + auto &parent_input_index = input_index_mapping[i]; + if 
(static_cast<size_t>(parent_input_index) >= inputs.size()) {
+      GELOGE(INTERNAL_ERROR,
+             "[%s] Number of inputs [%zu] is not sufficient for subgraph which needs at least [%d] inputs",
+             graph_item_->GetName().c_str(), inputs.size(), parent_input_index + 1);
+
+      return INTERNAL_ERROR;
+    }
+
+    auto &input_tensor = inputs[parent_input_index];
+    GE_CHK_STATUS_RET(subgraph_context_->SetInput(i, input_tensor), "[%s] Failed to set input tensor[%zu]",
+                      graph_item_->GetName().c_str(), i);
+    GELOGD("[%s] Set input tensor[%zu] to input with index = %d, tensor = %s", graph_item_->GetName().c_str(), i,
+           parent_input_index, input_tensor.DebugString().c_str());
+  }
+
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::ExecuteAsync(const std::vector<TensorValue> &inputs,
+                                      const std::vector<ConstGeTensorDescPtr> &input_desc) {
+  GELOGD("[%s] is dynamic = %s", graph_item_->GetName().c_str(), graph_item_->IsDynamic() ? "true" : "false");
+  GE_CHK_STATUS_RET(Init(inputs, input_desc), "[%s] Failed to init executor.", graph_item_->GetName().c_str());
+
+  if (!graph_item_->IsDynamic()) {
+    return ExecuteAsyncForKnownShape(inputs);
+  }
+
+  GE_CHK_STATUS_RET(ScheduleTasks(), "[%s] Failed to execute tasks.", graph_item_->GetName().c_str());
+  GELOGD("[%s] Done executing subgraph successfully.", graph_item_->GetName().c_str());
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs) {
+  GELOGD("[%s] subgraph is not dynamic.", graph_item_->GetName().c_str());
+  if (graph_item_->GetAllNodes().size() != 1) {
+    GELOGE(INTERNAL_ERROR, "[%s] Invalid known shape subgraph. node size = %zu", graph_item_->GetName().c_str(),
+           graph_item_->GetAllNodes().size());
+    return INTERNAL_ERROR;
+  }
+
+  auto node_item = graph_item_->GetAllNodes()[0];
+  GE_CHECK_NOTNULL(node_item);
+  auto node_state = subgraph_context_->GetOrCreateNodeState(node_item);
+  GE_CHECK_NOTNULL(node_state);
+  node_state->SetKernelTask(node_item->kernel_task);
+
+  known_shape_task_context_ = TaskContext::Create(*node_item, context_, subgraph_context_.get());
+  GE_CHECK_NOTNULL(known_shape_task_context_);
+
+  GE_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_),
+                    "[%s] Failed to execute node [%s] for known subgraph.", graph_item_->GetName().c_str(),
+                    known_shape_task_context_->GetNodeName());
+
+  GELOGD("[%s] Done executing non-dynamic subgraph successfully.", graph_item_->GetName().c_str());
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) {
+  std::vector<TensorValue> inputs;
+  std::vector<ConstGeTensorDescPtr> input_desc;
+  for (int i = 0; i < task_context.NumInputs(); ++i) {
+    auto tensor = task_context.GetInput(i);
+    GE_CHECK_NOTNULL(tensor);
+    inputs.emplace_back(*tensor);
+    input_desc.emplace_back(task_context.GetInputDesc(i));
+  }
+
+  GE_CHK_STATUS_RET(ExecuteAsync(inputs, input_desc), "[%s] Failed to execute subgraph.",
+                    graph_item_->GetName().c_str());
+
+  GE_CHK_STATUS_RET(SetOutputsToParentNode(task_context), "[%s] Failed to set output shapes to parent node.",
+                    graph_item_->GetName().c_str());
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::PrepareNodes() {
+  GELOGD("[%s] Start to prepare nodes. force infer shape = %s.", graph_item_->GetName().c_str(),
+         force_infer_shape_ ?
"true" : "false"); + auto &all_nodes = graph_item_->GetAllNodes(); + for (size_t i = 0; i < all_nodes.size(); ++i) { + auto &node_item = *all_nodes[i]; + // for while op + if (force_infer_shape_ && !node_item.is_dynamic) { + GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str()); + auto &mutable_node_item = const_cast(node_item); + mutable_node_item.SetToDynamic(); + } + + GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); + auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); + GE_CHECK_NOTNULL(node_state); + auto p_node_state = node_state.get(); + + if (node_item.node_type == NETOUTPUT) { + // Wait for all inputs become valid + // after PrepareNodes returned. all output tensors and shapes are valid + GE_CHK_STATUS_RET_NOLOG(p_node_state->GetShapeInferenceState().AwaitShapesReady(*context_)); + GE_CHK_STATUS_RET_NOLOG(p_node_state->AwaitInputTensors(*context_)); + continue; + } + + // only do shape inference and compilation for nodes with dynamic shapes. + if (node_item.is_dynamic) { + auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { + GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); + return PrepareForExecution(context_, *p_node_state); + }); + + p_node_state->SetPrepareFuture(std::move(prepare_future)); + } else { + GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", node_item.NodeName().c_str()); + if (node_item.kernel_task == nullptr) { + GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); + GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), "[%s] Failed to create task.", + p_node_state->GetName().c_str()); + } else { + node_state->SetKernelTask(node_item.kernel_task); + } + } + + if (!ready_queue_.Push(p_node_state)) { + GELOGE(INTERNAL_ERROR, "[%s] Error occurs while launching tasks. 
quit from preparing nodes.", + graph_item_->GetName().c_str()); + return INTERNAL_ERROR; + } + + GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); + } + + return SUCCESS; +} + +Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) { + const auto &node_item = *node_state.GetNodeItem(); + GE_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), "[%s] Failed to InferShape.", + node_state.GetName().c_str()); + GE_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_item), "[%s] Failed to PropagateOutputShapes.", + node_state.GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state) { + auto &node_item = *node_state.GetNodeItem(); + if (node_item.kernel_task == nullptr) { + GE_CHK_STATUS_RET(TaskCompileEngine::Compile(node_state, ctx), "Failed to create task for node[%s]", + node_state.GetName().c_str()); + } else { + node_state.SetKernelTask(node_item.kernel_task); + } + + GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), + "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); + GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::LaunchTasks() { + while (true) { + NodeState *node_state = nullptr; + if (!ready_queue_.Pop(node_state)) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to pop node.", graph_item_->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (node_state == nullptr) { + GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str()); + return SUCCESS; + } + + GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone()); + + GELOGD("[%s] Start to execute.", node_state->GetName().c_str()); + auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); + GE_CHECK_NOTNULL(task_context); + task_context->SetForceInferShape(force_infer_shape_); + auto shared_task_context = std::shared_ptr(task_context.release()); + GE_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_), + "[%s] Execute node failed.", node_state->GetName().c_str()); + + GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str()); + } +} + +Status SubgraphExecutor::ScheduleTasks() { + GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); + auto prepare_future = std::async([&]() -> Status { + auto ret = PrepareNodes(); + ready_queue_.Push(nullptr); + return ret; + }); + + GELOGD("[%s] Start to execute subgraph.", graph_item_->GetName().c_str()); + auto ret = LaunchTasks(); + if (ret != SUCCESS) { + GELOGE(ret, "[%s] Failed to execute subgraph.", graph_item_->GetName().c_str()); + subgraph_context_->OnError(ret); + ready_queue_.Stop(); + prepare_future.wait(); + return ret; + } + + GE_CHK_STATUS_RET(prepare_future.get(), "[%s] Error occurred in task preparation.", graph_item_->GetName().c_str()); + + GELOGD("[%s] Done launching all tasks successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::GetOutputs(vector &outputs) { return 
subgraph_context_->GetOutputs(outputs); }
+
+Status SubgraphExecutor::GetOutputs(vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc) {
+  GE_CHK_STATUS_RET(GetOutputs(outputs), "[%s] Failed to get output tensors.", graph_item_->GetName().c_str());
+
+  // fetch the output tensor descriptions along with the output tensors
+  GE_CHK_STATUS_RET(graph_item_->GetOutputDescList(output_desc), "[%s] Failed to get output tensor desc.",
+                    graph_item_->GetName().c_str());
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::Synchronize() {
+  GELOGD("[%s] Synchronize start.", graph_item_->GetName().c_str());
+  GE_CHK_RT_RET(rtStreamSynchronize(context_->stream));
+  GELOGD("[%s] Done synchronizing successfully.", graph_item_->GetName().c_str());
+  return SUCCESS;
+}
+
+Status SubgraphExecutor::SetOutputsToParentNode(TaskContext &task_context) {
+  // get output tensors and tensor desc list
+  std::vector<TensorValue> outputs;
+  std::vector<ConstGeTensorDescPtr> output_desc_list;
+  GE_CHK_STATUS_RET(subgraph_context_->GetOutputs(outputs), "[%s] Failed to get output tensors.",
+                    graph_item_->GetName().c_str());
+  GE_CHK_STATUS_RET(graph_item_->GetOutputDescList(output_desc_list), "[%s] Failed to get output tensor desc.",
+                    graph_item_->GetName().c_str());
+
+  if (outputs.size() != output_desc_list.size()) {
+    GELOGE(INTERNAL_ERROR, "[%s] num output tensors = %zu, num output tensor desc = %zu",
+           graph_item_->GetName().c_str(), outputs.size(), output_desc_list.size());
+    return INTERNAL_ERROR;
+  }
+
+  // map each output to the corresponding output of the parent task context
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    int parent_output_index = graph_item_->GetParentOutputIndex(i);
+    GE_CHECK_GE(parent_output_index, 0);
+    // update tensor
+    GELOGD("[%s] Updating output[%zu] to parent output[%d], tensor = %s", graph_item_->GetName().c_str(), i,
+           parent_output_index, outputs[i].DebugString().c_str());
+    GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(parent_output_index, outputs[i]));
+
+    // updating shapes. dynamic format/dtype is not supported.
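+    // Only Shape and OriginShape are refreshed below; the format and data type of the
+    // parent output desc are left as-is.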
+    // It should be noted that even if the subgraph is of known shape, it is still necessary to update the
+    // parent output desc; for instance, IfOp may have two known-shaped subgraphs with different output shapes.
+    const auto &output_desc = output_desc_list[i];
+    auto parent_output_desc = task_context.MutableOutputDesc(parent_output_index);
+    GE_CHECK_NOTNULL(parent_output_desc);
+    GELOGD("[%s] Updating output shape[%d] from [%s] to [%s]", graph_item_->GetName().c_str(), parent_output_index,
+           parent_output_desc->MutableShape().ToString().c_str(), output_desc->GetShape().ToString().c_str());
+    parent_output_desc->SetShape(output_desc->GetShape());
+
+    GELOGD("[%s] Updating output original shape[%d] from [%s] to [%s]", graph_item_->GetName().c_str(),
+           parent_output_index, parent_output_desc->GetOriginShape().ToString().c_str(),
+           output_desc->GetOriginShape().ToString().c_str());
+    parent_output_desc->SetOriginShape(output_desc->GetOriginShape());
+  }
+
+  return SUCCESS;
+}
+}  // namespace hybrid
+}  // namespace ge
\ No newline at end of file
diff --git a/src/ge/hybrid/executor/subgraph_executor.h b/src/ge/hybrid/executor/subgraph_executor.h
new file mode 100644
index 00000000..7cdb2070
--- /dev/null
+++ b/src/ge/hybrid/executor/subgraph_executor.h
@@ -0,0 +1,101 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_
+#define GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_
+
+#include
+
+#include "common/blocking_queue.h"
+#include "common/thread_pool.h"
+#include "hybrid/executor/subgraph_context.h"
+#include "hybrid/executor/node_state.h"
+#include "hybrid/executor/hybrid_execution_context.h"
+#include "hybrid/executor/worker/shape_inference_engine.h"
+#include "hybrid/model/graph_item.h"
+#include "hybrid/node_executor/task_context.h"
+
+namespace ge {
+namespace hybrid {
+// Executor for executing a subgraph
+class SubgraphExecutor {
+ public:
+  SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false);
+  ~SubgraphExecutor();
+
+  /**
+   * Execute subgraph asynchronously; output tensor addresses (not data) and output tensor descriptions are
+   * valid after this method returns
+   * @param inputs input tensors
+   * @param input_desc input tensor descriptions
+   * @return SUCCESS on success, error code otherwise
+   */
+  Status ExecuteAsync(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc);
+
+  /**
+   * Execute subgraph asynchronously; output tensor addresses (not data) and output tensor descriptions are
+   * valid after this method returns
+   * @param task_context instance of TaskContext
+   * @return SUCCESS on success, error code otherwise
+   */
+  Status ExecuteAsync(TaskContext &task_context);
+
+  /**
+   * Synchronize all tasks in the subgraph.
Output tensor data is valid after this method returns
+   * @return SUCCESS on success, error code otherwise
+   */
+  Status Synchronize();
+
+  /**
+   * Get output tensors
+   * @param outputs output tensors
+   * @return SUCCESS on success, error code otherwise
+   */
+  Status GetOutputs(std::vector<TensorValue> &outputs);
+
+  /**
+   * Get output tensors and output tensor descriptions
+   * @param outputs output tensors
+   * @param output_desc output tensor descriptions
+   * @return SUCCESS on success, error code otherwise
+   */
+  Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc);
+
+ private:
+  static Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state);
+  static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state);
+  Status Init(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc);
+  Status InitInputsForUnknownShape(const std::vector<TensorValue> &inputs,
+                                   const std::vector<ConstGeTensorDescPtr> &input_desc);
+  Status InitInputsForKnownShape(const std::vector<TensorValue> &inputs);
+  Status ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs);
+  Status ScheduleTasks();
+  Status PrepareNodes();
+  Status LaunchTasks();
+  Status SetOutputsToParentNode(TaskContext &task_context);
+
+  const GraphItem *graph_item_;
+  GraphExecutionContext *context_;
+  std::unique_ptr<SubgraphContext> subgraph_context_;
+  bool force_infer_shape_;
+  ThreadPool pre_run_pool_;
+  BlockingQueue<NodeState *> ready_queue_;
+  std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_;
+  std::shared_ptr<TaskContext> known_shape_task_context_;
+};
+}  // namespace hybrid
+}  // namespace ge
+#endif  // GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_
diff --git a/src/ge/hybrid/executor/worker/execution_engine.cc b/src/ge/hybrid/executor/worker/execution_engine.cc
index 9e656139..20da6378 100644
--- a/src/ge/hybrid/executor/worker/execution_engine.cc
+++ b/src/ge/hybrid/executor/worker/execution_engine.cc
@@ -15,7 +15,6 @@
  */
 
 #include "hybrid/executor/worker/execution_engine.h"
-#include
 #include "graph/runtime_inference_context.h"
 #include "graph/utils/tensor_utils.h"
 #include "graph/utils/tensor_adapter.h"
@@ -23,9 +22,38 @@
 namespace ge {
 namespace hybrid {
+namespace {
+constexpr int64_t kMaxPadding = 63;
+
+Status LogInputs(const NodeItem &node_item, const TaskContext &task_context) {
+  for (auto i = 0; i < task_context.NumInputs(); ++i) {
+    const auto &input_tensor = task_context.GetInput(i);
+    GE_CHECK_NOTNULL(input_tensor);
+    const auto &tensor_desc = node_item.op_desc->MutableInputDesc(i);
+    GE_CHECK_NOTNULL(tensor_desc);
+    GELOGD("[%s] Print task args. input[%d] = %s, shape = [%s]", node_item.NodeName().c_str(), i,
+           input_tensor->DebugString().c_str(), tensor_desc->MutableShape().ToString().c_str());
+  }
+
+  return SUCCESS;
+}
+
+Status LogOutputs(const NodeItem &node_item, const TaskContext &task_context) {
+  for (auto i = 0; i < task_context.NumOutputs(); ++i) {
+    const auto &output_tensor = task_context.GetOutput(i);
+    GE_CHECK_NOTNULL(output_tensor);
+    const auto &tensor_desc = node_item.op_desc->MutableOutputDesc(i);
+    GE_CHECK_NOTNULL(tensor_desc);
+    GELOGD("[%s] Print task args.
output[%d] = %s, shape = [%s]", node_item.NodeName().c_str(), i,
+           output_tensor->DebugString().c_str(), tensor_desc->MutableShape().ToString().c_str());
+  }
+
+  return SUCCESS;
+}
+}  // namespace
 
 class NodeDoneCallback {
  public:
-  NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> &task_context);
+  NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context);
   ~NodeDoneCallback() = default;
   Status OnNodeDone();
@@ -35,8 +63,8 @@ class NodeDoneCallback {
   std::shared_ptr<TaskContext> context_;
 };
 
-NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> &task_context)
-    : graph_context_(graph_context), context_(task_context) {}
+NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr<TaskContext> task_context)
+    : graph_context_(graph_context), context_(std::move(task_context)) {}
 
 Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) {
   for (auto output_idx : node_item.to_const_output_id_list) {
@@ -46,17 +74,28 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) {
     auto output_tensor = context_->GetOutput(output_idx);
     GE_CHECK_NOTNULL(output_tensor);
 
-    vector<uint8_t> host_buffer(output_tensor->GetSize());
-    GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx,
-           output_tensor->GetSize());
-    GE_CHK_RT_RET(rtMemcpy(host_buffer.data(), host_buffer.size(), output_tensor->GetData(), output_tensor->GetSize(),
-                           RT_MEMCPY_HOST_TO_DEVICE));
     Tensor tensor;
-    tensor.SetData(host_buffer);
     auto ge_tensor_desc = node_item.op_desc->MutableOutputDesc(output_idx);
     GE_CHECK_NOTNULL(ge_tensor_desc);
     tensor.SetTensorDesc(TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc));
+    int64_t tensor_size;
+    GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*ge_tensor_desc, tensor_size),
+                            "Failed to invoke GetTensorSizeInBytes");
+
+    if (output_tensor->GetSize() < static_cast<size_t>(tensor_size)) {
+      GELOGE(INTERNAL_ERROR, "[%s] Tensor size is not enough. output index = %d, required size = %ld, tensor = %s",
+             node_item.NodeName().c_str(), output_idx, tensor_size, output_tensor->DebugString().c_str());
+      return INTERNAL_ERROR;
+    }
+
+    vector<uint8_t> host_buffer(static_cast<size_t>(tensor_size));
+    GELOGD("[%s] To cache output[%d] to host, size = %ld", node_item.NodeName().c_str(), output_idx, tensor_size);
+    GE_CHK_RT_RET(
+      rtMemcpy(host_buffer.data(), tensor_size, output_tensor->GetData(), tensor_size, RT_MEMCPY_DEVICE_TO_HOST));
+    tensor.SetData(host_buffer);
+
     string session_id = std::to_string(context_->GetSessionId());
     RuntimeInferenceContext *runtime_infer_ctx = nullptr;
     GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx),
@@ -87,115 +126,118 @@ Status NodeDoneCallback::OnNodeDone() {
   GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item));
   // PropagateOutputs for type == DEPEND_COMPUTE
   if (node_item.shape_inference_type == DEPEND_COMPUTE) {
+    if (graph_context_->trace_enabled) {
+      (void)LogOutputs(node_item, *context_);
+    }
+
     GE_CHK_STATUS_RET(context_->PropagateOutputs(), "[%s] Failed to propagate outputs failed",
                       node_item.NodeName().c_str());
 
     RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[PropagateOutputs] End");
   }
 
-  // release
+  // release condition variable
   if (node_item.has_observer) {
     GELOGI("[%s] Notify observer.
node_id = %d", node_item.NodeName().c_str(), node_item.node_id); - graph_context_->cv_manager.NodeDone(node_item.node); + context_->NodeDone(); } RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[Callback] End"); return SUCCESS; } -ExecutionEngine::ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager) - : context_(context), callback_manager_(callback_manager) {} - -Status ExecutionEngine::Start() { - GE_CHK_STATUS_RET_NOLOG(ExecutionProcess()); - return SUCCESS; -} - -Status ExecutionEngine::ExecutionProcess() { - GELOGI("ExecutorEngine worker started"); - auto &ready_queue = context_->execution_queue; - while (true) { - NodeStatePtr node_state = nullptr; - if (!ready_queue.Pop(node_state)) { - GELOGE(FAILED, "Pop task failed"); - return FAILED; - } - - // EOF - if (node_state == nullptr) { - break; +Status ExecutionEngine::ExecuteAsync(NodeState &node_state, const std::shared_ptr &task_context, + GraphExecutionContext &execution_context) { + GELOGI("[%s] Node is ready for execution", task_context->GetNodeName()); + RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); + auto cb = std::shared_ptr(new (std::nothrow) NodeDoneCallback(&execution_context, task_context)); + GE_CHECK_NOTNULL(cb); + auto callback = [&, cb]() { + auto ret = cb->OnNodeDone(); + if (ret != SUCCESS) { + task_context->OnError(ret); } + }; - RECORD_EXECUTION_EVENT(context_, node_state->GetName().c_str(), "Start"); - GELOGI("[%s] Node is ready for execution", node_state->GetName().c_str()); - auto *node_item = node_state->node_item; - auto task_context = TaskContext::Create(*node_item, context_); - GE_CHECK_NOTNULL(task_context); - auto shared_task_context = shared_ptr(task_context.release()); - - auto cb = std::shared_ptr(new (std::nothrow) NodeDoneCallback(context_, shared_task_context)); - GE_CHECK_NOTNULL(cb); - auto callback = [&, cb]() { - auto ret = cb->OnNodeDone(); - if (ret != SUCCESS) { - context_->OnError(ret); - } - }; - - GE_CHK_STATUS_RET_NOLOG(ExecuteAsync(*node_state, *shared_task_context, callback)); - GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_item, *shared_task_context)); - } - - GELOGI("ExecutorEngine worker ended."); + GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback)); + GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context)); return SUCCESS; } -Status ExecutionEngine::ExecuteAsync(NodeState &node_state, TaskContext &task_context, - const std::function &callback) { - const auto &task = node_state.kernel_task; +Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, TaskContext &task_context, GraphExecutionContext &context, + const std::function &callback) { + const auto &task = node_state.GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str()); return INTERNAL_ERROR; } - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] Start"); - auto executor = node_state.node_item->node_executor; + // Wait for dependent nodes(DEPEND_COMPUTE), so that the input tensors are valid. 
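+  // A node of type DEPEND_COMPUTE publishes its output tensors (and shapes) only after its
+  // kernel has finished, so execution blocks here until every depended-on node reports done.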
+ RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[AwaitDependents] Start"); + GE_CHK_STATUS_RET(node_state.AwaitInputTensors(context), "[%s] Failed to wait for dependent nodes.", + node_state.GetName().c_str()); + + const auto &node_item = *node_state.GetNodeItem(); + auto executor = node_item.node_executor; + GE_CHECK_NOTNULL(executor); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[%s] Failed to prepare task", node_state.GetName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str()); - if (context_->trace_enabled) { - for (auto i = 0; i < task_context.NumInputs(); ++i) { - const auto &input_tensor = task_context.GetInput(i); - GE_CHECK_NOTNULL(input_tensor); - GELOGD("[%s] Tensor of input[%d] = %s", node_state.GetName().c_str(), i, input_tensor->DebugString().c_str()); - } - - for (auto i = 0; i < task_context.NumOutputs(); ++i) { - const auto &output_tensor = task_context.GetOutput(i); - GE_CHECK_NOTNULL(output_tensor); - GELOGD("[%s] Tensor of output[%d] = %s", node_state.GetName().c_str(), i, output_tensor->DebugString().c_str()); + if (context.trace_enabled) { + LogInputs(node_item, task_context); + if (node_item.shape_inference_type != DEPEND_COMPUTE) { + LogOutputs(node_item, task_context); } } - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] Start"); + GE_CHK_STATUS_RET(ValidateInputTensors(node_state, task_context), "Failed to validate input tensors."); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End"); + GE_CHK_STATUS_RET(executor->ExecuteTask(*task, task_context, callback), "[%s] Failed to execute task", node_state.GetName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ExecuteTask] End"); GELOGD("[%s] Done task launch successfully.", node_state.GetName().c_str()); return SUCCESS; } -Status ExecutionEngine::PropagateOutputs(const NodeItem &node_item, TaskContext &task_context) { +Status ExecutionEngine::ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context) { + for (auto i = 0; i < task_context.NumInputs(); ++i) { + const auto &input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + const auto &tensor_desc = node_state.GetOpDesc()->MutableInputDesc(i); + GE_CHECK_NOTNULL(tensor_desc); + int64_t expected_size; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, expected_size)); + GELOGD("[%s] Input[%d] expects [%ld] bytes.", task_context.GetNodeName(), i, expected_size); + auto size_diff = expected_size - static_cast(input_tensor->GetSize()); + if (size_diff > 0) { + if (size_diff <= kMaxPadding) { + GELOGW("[%s] Input[%d]: tensor size mismatches. expected: %ld, but given %zu", task_context.GetNodeName(), i, + expected_size, input_tensor->GetSize()); + } else { + GELOGE(INTERNAL_ERROR, "[%s] Input[%d]: tensor size mismatches. 
expected: %ld, but given %zu", + task_context.GetNodeName(), i, expected_size, input_tensor->GetSize()); + return INTERNAL_ERROR; + } + } + } + + return SUCCESS; +} + +Status ExecutionEngine::PropagateOutputs(const NodeItem &node_item, TaskContext &task_context, + GraphExecutionContext &context) { if (node_item.shape_inference_type != DEPEND_COMPUTE) { GE_CHK_STATUS_RET(task_context.PropagateOutputs(), "[%s] Failed to propagate outputs.", node_item.NodeName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PropagateOutputs] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PropagateOutputs] End"); + GELOGD("[%s] Done propagating outputs successfully.", node_item.NodeName().c_str()); } - GELOGD("[%s] Done propagating outputs successfully.", node_item.NodeName().c_str()); return SUCCESS; } } // namespace hybrid diff --git a/src/ge/hybrid/executor/worker/execution_engine.h b/src/ge/hybrid/executor/worker/execution_engine.h index f5f317af..56f1557d 100644 --- a/src/ge/hybrid/executor/worker/execution_engine.h +++ b/src/ge/hybrid/executor/worker/execution_engine.h @@ -17,30 +17,21 @@ #ifndef GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ #define GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ -#include "common/thread_pool.h" -#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/executor/hybrid_execution_context.h" -#include "hybrid/executor/rt_callback_manager.h" #include "hybrid/node_executor/task_context.h" namespace ge { namespace hybrid { class ExecutionEngine { public: - explicit ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager); - ~ExecutionEngine() = default; - - Status Start(); + static Status ExecuteAsync(NodeState &node_state, const std::shared_ptr &task_context, + GraphExecutionContext &execution_context); private: - Status PropagateOutputs(const NodeItem &node_item, TaskContext &task_context); - - Status ExecutionProcess(); - - Status ExecuteAsync(NodeState &node_state, TaskContext &task_context, const std::function &callback); - - GraphExecutionContext *context_; - CallbackManager *callback_manager_; + static Status ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context); + static Status PropagateOutputs(const NodeItem &node_item, TaskContext &task_context, GraphExecutionContext &context); + static Status DoExecuteAsync(NodeState &node_state, TaskContext &task_context, GraphExecutionContext &context, + const std::function &callback); }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.cc b/src/ge/hybrid/executor/worker/shape_inference_engine.cc index 90082fff..f600e94a 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -15,117 +15,27 @@ */ #include "hybrid/executor/worker/shape_inference_engine.h" - #include "graph/shape_refiner.h" -#include "graph/runtime_inference_context.h" #include "graph/utils/node_utils.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { +ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) + : execution_context_(execution_context), subgraph_context_(subgraph_context) {} -ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *context) : context_(context) {} - -Status ShapeInferenceEngine::Start(ThreadPool &pool) { - GELOGI("RuntimeShapeInferenceEngine start."); - pool.commit([&]() { - auto ret 
= this->InferShapeProcess(); - InferenceDone(ret); - }); - - return SUCCESS; -} - -Status ShapeInferenceEngine::InferShapeProcess() { - GELOGI("RuntimeShapeInferenceEngine worker start."); - const auto &root_nodes = context_->model->RootNodes(); - auto &complete_queue = context_->compile_queue; - std::queue ready_nodes; - for (auto &node_item : root_nodes) { - auto infer_state = GetOrCreateEntry(*node_item); - GE_CHECK_NOTNULL(infer_state); - ready_nodes.emplace(infer_state); - } - - while (!ready_nodes.empty()) { - InferenceState *infer_state = ready_nodes.front(); - ready_nodes.pop(); - auto node_item = infer_state->node_item; - // even for non-dynamic shape node, it is still necessary to wait for pending shapes if got any. - // which indicates that the parent node is of type 4, in which case the inputs will be valid only - // when computing is done. - GE_CHK_STATUS_RET(infer_state->AwaitShapeFutures(context_), "Await shape failed."); - GELOGI("[%s] Node is ready for shape inference.", node_item.NodeName().c_str()); - if (node_item.is_dynamic) { - // may block - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "Start"); - GELOGI("[%s] Start to invoke InferShape", node_item.NodeName().c_str()); - auto ret = InferShape(*infer_state); - if (ret != SUCCESS) { - return ret; - } - - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); - GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), - "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); - } else { - GELOGD("[%s] Skip static shape node", node_item.NodeName().c_str()); - } - - if (node_item.node_type != NETOUTPUT) { - GELOGI("[%s] Push to compile queue", node_item.NodeName().c_str()); - // may block if full - auto node_state = context_->GetOrCreateNodeState(node_item.node); - complete_queue.Push(node_state); - } - - // Propagate - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); - PropagateOutputShapes(*infer_state, ready_nodes); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] End"); - } - - return SUCCESS; -} - -void ShapeInferenceEngine::InferenceDone(Status status) { - if (status != SUCCESS) { - GELOGE(status, "Error occurred while shape inference"); - context_->OnError(status); - } else { - context_->compile_queue.Push(nullptr); - } - inference_states_.clear(); - GELOGI("RuntimeShapeInferenceEngine worker END"); -} - -Status ShapeInferenceEngine::InferShape(InferenceState &entry) { - // input shapes are ready, wait for dependent data if has any - const auto &node_item = entry.node_item; - if (!node_item.dependent_node_list.empty()) { - for (auto &src_node : node_item.dependent_node_list) { - auto *src_node_item = context_->model->GetNodeItem(src_node); - GELOGI("[%s] Start to wait for data dependent node: %s", node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] Start", - src_node->GetName().c_str()); - if (!context_->cv_manager.Await(src_node)) { - GELOGE(INTERNAL_ERROR, "[%s] Await node failed.", src_node_item->NodeName().c_str()); - return INTERNAL_ERROR; - } - - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] End", - src_node->GetName().c_str()); - GELOGI("[%s] 
Done waiting node.", src_node_item->NodeName().c_str()); - } - } +Status ShapeInferenceEngine::InferShape(NodeState &node_state) { + // Wait for all input shape become valid + GE_CHK_STATUS_RET_NOLOG(node_state.GetShapeInferenceState().AwaitShapesReady(*execution_context_)); + auto &node_item = *node_state.GetNodeItem(); + // Skip shape inference for node of type DEPEND_COMPUTE if (node_item.shape_inference_type == DEPEND_COMPUTE) { - GELOGD("[%s] Skip node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); + GELOGD("[%s] Skipping node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); return SUCCESS; } + // Clear shape range in case shape inference func forgot to do it if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE) { // in case InferFunc forgot to reset output shape for (auto &output_desc : node_item.op_desc->GetAllOutputsDescPtr()) { @@ -133,13 +43,16 @@ Status ShapeInferenceEngine::InferShape(InferenceState &entry) { } } - // do shape inference - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[InferShape] Start"); + // Wait for "const input nodes" if node's shape inference function requires any. + GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); + + // Do shape inference GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node_item.node), "Invoke InferShapeAndType failed."); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[InferShape] End"); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); - // Check shape + // Check again to make sure shape is valid after shape inference if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) { bool is_unknown_shape = false; GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape), @@ -149,12 +62,33 @@ Status ShapeInferenceEngine::InferShape(InferenceState &entry) { node_item.NodeName().c_str()); } + GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", node_item.NodeName().c_str(), + node_item.DebugString().c_str()); + GELOGD("[%s] InferShapeAndType finished successfully.", node_item.NodeName().c_str()); return SUCCESS; } -void ShapeInferenceEngine::PropagateOutputShapes(InferenceState &entry, std::queue &queue) { - auto &node_item = entry.node_item; +Status ShapeInferenceEngine::AwaitDependentNodes(NodeState &node_state) { + auto &node_item = *node_state.GetNodeItem(); + for (auto &src_node : node_item.dependents_for_shape_inference) { + GELOGI("[%s] Start to wait for data dependent node: %s", node_item.NodeName().c_str(), src_node->GetName().c_str()); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] Start", + src_node->GetName().c_str()); + if (!subgraph_context_->Await(src_node)) { + GELOGE(INTERNAL_ERROR, "[%s] Await node failed.", src_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] End", + src_node->GetName().c_str()); + GELOGI("[%s] Done waiting node.", src_node->GetName().c_str()); + } + + return SUCCESS; +} + +Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { // output shape will not be valid until compute is done. 
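+  // For such nodes the shape is propagated as a ShapeFuture and resolved by the consumer
+  // once computing is done; otherwise the shape is pushed to consumers immediately.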
bool shape_is_future = node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE; @@ -171,88 +105,25 @@ void ShapeInferenceEngine::PropagateOutputShapes(InferenceState &entry, std::que // propagate output to all sub-inputs for (auto &dst_input_index_and_node : output_nodes) { auto &dst_node_item = dst_input_index_and_node.second; - auto inference_state = GetOrCreateEntry(*dst_node_item); + auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); + GE_CHECK_NOTNULL(dst_node_state); + GELOGI("[%s] Update dst node [%s], input index = %d", node_item.NodeName().c_str(), dst_node_item->NodeName().c_str(), dst_input_index_and_node.first); - // in case type 3/4, shape will be valid after computing is done + // in case type 3 and 4, shape will be valid after computing is done if (shape_is_future) { - ShapeFuture future(node_item.node, i, &context_->cv_manager); - inference_state->UpdateInputShapeFuture(dst_input_index_and_node.first, std::move(future)); + ShapeFuture future(node_item.node, i, subgraph_context_); + dst_node_state->GetShapeInferenceState().UpdateInputShapeFuture(dst_input_index_and_node.first, + std::move(future)); } else { - inference_state->UpdateInputShape(dst_input_index_and_node.first, ori_shape, shape); - } - - if (inference_state->IsInputShapesReady()) { - GELOGI("[%s] Node input shape is ready, add to queue.", inference_state->node_item.NodeName().c_str()); - queue.emplace(inference_state); + dst_node_state->GetShapeInferenceState().UpdateInputShape(dst_input_index_and_node.first, ori_shape, shape); } } } GELOGD("[%s] Propagating output shapes finished successfully.", node_item.NodeName().c_str()); -} - -ShapeInferenceEngine::InferenceState *ShapeInferenceEngine::GetOrCreateEntry(const NodeItem &node_item) { - auto &node_state = inference_states_[node_item.node_id]; - if (node_state == nullptr) { - node_state.reset(new (std::nothrow) InferenceState(node_item)); - } - - return node_state.get(); -} - -ShapeInferenceEngine::InferenceState::InferenceState(const NodeItem &node_item) : node_item(node_item) { - this->num_pending_shapes = node_item.num_inputs; -} - -void ShapeInferenceEngine::InferenceState::UpdateInputShape(uint32_t idx, const GeShape &ori_shape, - const GeShape &shape) { - if (node_item.const_input_shapes.count(idx) != 0) { - GELOGD("[%s] Trying to update constant shape, idx = %u. 
old shape = [%s], new shape = [%s]", - node_item.NodeName().c_str(), idx, node_item.op_desc->MutableInputDesc(idx)->GetShape().ToString().c_str(), - shape.ToString().c_str()); - } - - GELOGD("[%s] Update input shape [%u] with Shape: [%s] and OriginalShape: [%s]", node_item.NodeName().c_str(), idx, - shape.ToString().c_str(), ori_shape.ToString().c_str()); - num_pending_shapes -= 1; - node_item.op_desc->MutableInputDesc(idx)->SetShape(shape); - node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); -} - -void ShapeInferenceEngine::InferenceState::UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future) { - if (node_item.const_input_shapes.count(idx) != 0) { - GELOGE(INTERNAL_ERROR, "[%s] Trying to update constant shape, idx = %u", node_item.NodeName().c_str(), idx); - return; - } - - GELOGD("[%s] Update input shape [%u] with ShapeFuture.", node_item.NodeName().c_str(), idx); - num_pending_shapes -= 1; - shape_futures.emplace_back(idx, std::move(future)); -} - -Status ShapeInferenceEngine::InferenceState::AwaitShapeFutures(GraphExecutionContext *context) { - for (auto &p : shape_futures) { - auto idx = p.first; - auto &future = p.second; - GeShape shape; - GeShape ori_shape; - RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); - GE_CHK_STATUS_RET(future.Get(ori_shape, shape), "[%s] Get shape failed. index = %u", node_item.NodeName().c_str(), - idx); - RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); - - GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", node_item.NodeName().c_str(), idx, - shape.ToString().c_str(), ori_shape.ToString().c_str()); - node_item.op_desc->MutableInputDesc(idx)->SetShape(std::move(shape)); - node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); - } - return SUCCESS; } - -ShapeInferenceEngine::ShapeFuture::ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager) - : src_node_(std::move(src_node)), src_index_(src_index), node_done_manager_(node_done_manager) {} } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.h b/src/ge/hybrid/executor/worker/shape_inference_engine.h index b1e1c879..972f8ee1 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.h +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.h @@ -17,75 +17,25 @@ #ifndef GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ #define GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ -#include -#include -#include -#include "common/thread_pool.h" #include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_context.h" namespace ge { namespace hybrid { class ShapeInferenceEngine { public: - explicit ShapeInferenceEngine(GraphExecutionContext *context); - + ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context); ~ShapeInferenceEngine() = default; - Status Start(ThreadPool &pool); - - private: - class ShapeFuture { - public: - ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager); - ~ShapeFuture() = default; - Status Get(GeShape &ori_shape, GeShape &shape) { - GELOGI("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); - if (!node_done_manager_->Await(src_node_)) { - GELOGE(INTERNAL_ERROR, "cancelled"); - return INTERNAL_ERROR; - } - - shape = 
src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); - ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); - GELOGI("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); - return SUCCESS; - } - - private: - NodePtr src_node_; - uint32_t src_index_; - NodeDoneManager *node_done_manager_; - }; - - struct InferenceState { - explicit InferenceState(const NodeItem &node_item); - inline bool IsInputShapesReady() const { return num_pending_shapes == 0; } - - void UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape); - - Status AwaitShapeFutures(GraphExecutionContext *context); + Status InferShape(NodeState &node_state); - void UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future); + Status PropagateOutputShapes(const NodeItem &node_item); - const NodeItem &node_item; - - private: - std::vector> shape_futures; - int num_pending_shapes = 0; - }; - - InferenceState *GetOrCreateEntry(const NodeItem &node_item); - - Status InferShapeProcess(); - - void InferenceDone(Status status); - - Status InferShape(InferenceState &entry); - - void PropagateOutputShapes(InferenceState &entry, std::queue &queue); + private: + Status AwaitDependentNodes(NodeState &node_state); - GraphExecutionContext *context_; - std::unordered_map> inference_states_; + GraphExecutionContext *execution_context_; + SubgraphContext *subgraph_context_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.cc b/src/ge/hybrid/executor/worker/task_compile_engine.cc index f6434ffa..57b19f5f 100644 --- a/src/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/src/ge/hybrid/executor/worker/task_compile_engine.cc @@ -16,172 +16,22 @@ #include "hybrid/executor/worker/task_compile_engine.h" #include "init/gelib.h" -#include "framework/common/debug/log.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { -namespace { -uint32_t kDefaultWorkerCnt = 4; -uint32_t kDefaultDeviceId = 0; -} // namespace -TaskCompileEngine::TaskCompileEngine(GraphExecutionContext *context) : context_(context), pool_(kDefaultWorkerCnt) {} - -TaskCompileEngine::~TaskCompileEngine() { - if (rt_context_ != nullptr) { - GELOGD("To destroy compile context: %p.", rt_context_); - GE_CHK_RT(rtCtxDestroy(rt_context_)); - } -} - -Status TaskCompileEngine::Init() { - GELOGD("Start to init CompileEngine"); - rtContext_t current_ctx = nullptr; - GE_CHK_RT(rtCtxGetCurrent(¤t_ctx)); - GE_CHK_RT_RET(rtCtxCreate(&rt_context_, RT_CTX_GEN_MODE, kDefaultDeviceId)); - GELOGD("Context created for compiling. 
ctx = %p", rt_context_); - GE_CHK_RT_RET(rtCtxSetCurrent(current_ctx)); - return SUCCESS; -} - -void TaskCompileEngine::Reset() { - complete_queue_.Push(nullptr); // ensure iteration can stop - unique_ptr entry; - while (true) { - complete_queue_.Pop(entry); - if (entry == nullptr) { - break; - } - - if (entry->future != nullptr) { - entry->future->wait(); - } - } - - complete_queue_.Clear(); -} - -Status TaskCompileEngine::Start(ThreadPool &pool) { - pool.commit([&]() { (void)this->CompileProcess(); }); - - worker_future_ = pool_.commit([&]() -> Status { return this->DistributeCompiledTasks(); }); - - if (!worker_future_.valid()) { - GELOGE(INTERNAL_ERROR, "Failed to start worker thread"); - return INTERNAL_ERROR; - } - - return SUCCESS; -} - -Status TaskCompileEngine::CompileProcess() { - auto &compile_queue = context_->compile_queue; - while (true) { - NodeStatePtr node_state; - // Stop() will not be invoked, Pop won't failed - (void)compile_queue.Pop(node_state); - - // EOF - if (node_state == nullptr) { - GELOGD("Got EOF"); - complete_queue_.Push(unique_ptr()); - break; - } - - auto entry = unique_ptr(new (std::nothrow) ResultQueueEntry()); - GE_CHECK_NOTNULL(entry); - entry->node_state = node_state; - - auto node_item = *node_state->node_item; - if (node_item.kernel_task != nullptr) { - GELOGD("use precompiled task. node name = %s", node_item.NodeName().c_str()); - node_state->kernel_task = node_item.kernel_task; - complete_queue_.Push(std::move(entry)); - continue; - } - - auto ret = CompileAsync(*node_state->node_item, *entry); - if (ret == SUCCESS) { - complete_queue_.Push(std::move(entry)); - continue; - } - - // On Error - worker_future_.wait(); - Reset(); - return CompileDone(ret); - } - - Status ret = worker_future_.get(); - Reset(); - return CompileDone(ret); -} - -Status TaskCompileEngine::CompileDone(Status status) { - if (status != SUCCESS) { - GELOGE(status, "Error occurred while compiling node."); - context_->OnError(status); - } else { - context_->execution_queue.Push(nullptr); - } - GELOGI("CompileEngine worker END. 
ret = %u", status); - return status; -} - -Status TaskCompileEngine::DoCompile(const NodeItem &node_item, NodeState &node_state) { - RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "Start"); - GE_CHK_RT_RET(rtCtxSetCurrent(rt_context_)); - auto ret = node_item.node_executor->CompileTask(*context_->model, node_item.node, node_state.kernel_task); - RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "End"); +Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { + const auto &node_item = *node_state.GetNodeItem(); + RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "Start"); + GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); + + shared_ptr kernel_task; + auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); + RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "End"); GE_CHK_STATUS_RET(ret, "Failed to create task for node: %s", node_item.NodeName().c_str()); + node_state.SetKernelTask(kernel_task); GELOGI("Compiling node %s successfully", node_state.GetName().c_str()); return SUCCESS; } - -Status TaskCompileEngine::CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry) { - auto node_state = entry.node_state; - auto f = pool_.commit([this, node_item, node_state]() -> Status { return DoCompile(node_item, *node_state); }); - - if (!f.valid()) { - GELOGE(INTERNAL_ERROR, "Failed to commit compile task"); - return INTERNAL_ERROR; - } - - entry.future = unique_ptr>(new (std::nothrow) std::future(std::move(f))); - GE_CHECK_NOTNULL(entry.future); - return SUCCESS; -} - -Status TaskCompileEngine::DistributeCompiledTasks() { - GELOGD("DistributeCompiledTasks start."); - auto &execute_queue = context_->execution_queue; - unique_ptr entry; - bool ret = SUCCESS; - while (true) { - if (!complete_queue_.Pop(entry)) { - GELOGE(INTERNAL_ERROR, "Failed to pop item from queue"); - ret = INTERNAL_ERROR; - break; - } - - // EOF - if (entry == nullptr) { - break; - } - - // if has compile future - if (entry->future != nullptr) { - ret = entry->future->get(); - if (ret != SUCCESS) { - break; - } - } - - execute_queue.Push(entry->node_state); - } - - GELOGD("DistributeCompiledTasks out. 
ret = %u.", ret); - return ret; -} } // namespace hybrid -} // namespace ge +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.h b/src/ge/hybrid/executor/worker/task_compile_engine.h index 828a1d8c..a677cb2e 100644 --- a/src/ge/hybrid/executor/worker/task_compile_engine.h +++ b/src/ge/hybrid/executor/worker/task_compile_engine.h @@ -17,44 +17,13 @@ #ifndef GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ #define GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ -#include -#include -#include "common/thread_pool.h" #include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { class TaskCompileEngine { public: - explicit TaskCompileEngine(GraphExecutionContext *context); - - ~TaskCompileEngine(); - - Status Init(); - - Status Start(ThreadPool &pool); - - private: - struct ResultQueueEntry { - NodeStatePtr node_state; - std::unique_ptr> future; - }; - - Status CompileProcess(); - - Status CompileDone(Status status); - - private: - Status DoCompile(const NodeItem &node_item, NodeState &node_state); - Status CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry); - Status DistributeCompiledTasks(); - void Reset(); - - rtContext_t rt_context_ = nullptr; - GraphExecutionContext *context_; - BlockingQueue> complete_queue_; - ThreadPool pool_; - std::future worker_future_; + static Status Compile(NodeState &node_state, GraphExecutionContext *context); }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/hybrid_davinci_model.cc b/src/ge/hybrid/hybrid_davinci_model.cc index 58c7d0e3..0454fa72 100644 --- a/src/ge/hybrid/hybrid_davinci_model.cc +++ b/src/ge/hybrid/hybrid_davinci_model.cc @@ -18,6 +18,7 @@ #include "hybrid_davinci_model.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/executor/hybrid_model_async_executor.h" +#include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { @@ -25,14 +26,19 @@ class HybridDavinciModel::Impl { public: explicit Impl(GeRootModelPtr ge_model) : model_(std::move(ge_model)), executor_(&model_) {} - ~Impl() = default; + ~Impl() { NodeExecutorManager::GetInstance().FinalizeExecutors(); } Status Init() { + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(), "Failed to initialize executors"); GE_CHK_STATUS_RET(model_.Init(), "Failed to init model.") GE_CHK_STATUS_RET(executor_.Init(), "Failed to init model executor.") return SUCCESS; } + Status Execute(const vector &inputs, vector &outputs) { + return executor_.Execute(inputs, outputs); + } + Status ModelRunStart() { return executor_.Start(listener_); } Status ModelRunStop() { return executor_.Stop(); } @@ -76,6 +82,11 @@ Status HybridDavinciModel::Init() { return impl_->Init(); } +Status HybridDavinciModel::Execute(const vector &inputs, vector &outputs) { + GE_CHECK_NOTNULL(impl_); + return impl_->Execute(inputs, outputs); +} + Status HybridDavinciModel::ModelRunStart() { GE_CHECK_NOTNULL(impl_); return impl_->ModelRunStart(); diff --git a/src/ge/hybrid/hybrid_davinci_model.h b/src/ge/hybrid/hybrid_davinci_model.h index 866b756b..c286a222 100644 --- a/src/ge/hybrid/hybrid_davinci_model.h +++ b/src/ge/hybrid/hybrid_davinci_model.h @@ -37,6 +37,8 @@ class HybridDavinciModel { Status Init(); + Status Execute(const vector &inputs, vector &outputs); + Status ModelRunStart(); Status ModelRunStop(); diff --git a/src/ge/hybrid/hybrid_davinci_model_stub.cc b/src/ge/hybrid/hybrid_davinci_model_stub.cc index bca118f8..7bde98a3 100644 --- 
a/src/ge/hybrid/hybrid_davinci_model_stub.cc +++ b/src/ge/hybrid/hybrid_davinci_model_stub.cc @@ -26,6 +26,8 @@ std::unique_ptr HybridDavinciModel::Create(const GeRootModel Status HybridDavinciModel::Init() { return UNSUPPORTED; } +Status HybridDavinciModel::Execute(const vector &inputs, vector &outputs) { return UNSUPPORTED; } + Status HybridDavinciModel::ModelRunStart() { return UNSUPPORTED; } Status HybridDavinciModel::ModelRunStop() { return UNSUPPORTED; } diff --git a/src/ge/hybrid/model/graph_item.cc b/src/ge/hybrid/model/graph_item.cc new file mode 100644 index 00000000..528fc4ee --- /dev/null +++ b/src/ge/hybrid/model/graph_item.cc @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "framework/common/util.h" +#include "graph_item.h" + +namespace ge { +namespace hybrid { +namespace { +constexpr int kInvalidIndex = -1; +} // namespace +GraphItem::~GraphItem() { GELOGD("[%s] GraphItem destroyed.", name_.c_str()); } + +const vector<NodeItem *> &hybrid::GraphItem::GetAllNodes() const { return node_items_; } + +const vector<NodeItem *> &GraphItem::GetInputNodes() const { return input_nodes_; } + +Status GraphItem::GetOutputDescList(vector<ConstGeTensorDescPtr> &output_desc_list) const { + if (is_dynamic_) { + for (auto &node_and_idx : output_edges_) { + const auto &tensor_desc = node_and_idx.first->op_desc->MutableOutputDesc(node_and_idx.second); + GE_CHECK_NOTNULL(tensor_desc); + output_desc_list.emplace_back(tensor_desc); + } + } else { + auto all_output_desc = output_node_->op_desc->GetAllOutputsDescPtr(); + for (auto &tensor_desc : all_output_desc) { + output_desc_list.emplace_back(tensor_desc); + } + } + + return SUCCESS; +} + +bool GraphItem::IsDynamic() const { return is_dynamic_; } + +const vector<int> &GraphItem::GetInputIndexMapping() const { return input_index_mapping_; } + +int GraphItem::GetParentOutputIndex(size_t index) const { + if (index >= output_index_mapping_.size()) { + return kInvalidIndex; + } + + return output_index_mapping_[index]; } + +const NodeItem *GraphItem::GetOutputNode() const { return output_node_; } } // namespace hybrid } // namespace ge \ No newline at end of file
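A short usage sketch of the new GraphItem output mapping (illustrative only; `graph_item` is assumed to come from HybridModel::GetSubgraphItem(), and index 2 is an arbitrary example):

// Translate a subgraph output index back to the parent node's output index.
// GetParentOutputIndex() returns kInvalidIndex (-1) when the index is out of range.
int parent_index = graph_item->GetParentOutputIndex(2);
if (parent_index < 0) {
  GELOGW("[%s] Output[2] has no parent index mapping", graph_item->GetName().c_str());
}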
diff --git a/src/ge/hybrid/model/graph_item.h b/src/ge/hybrid/model/graph_item.h new file mode 100644 index 00000000..cb0fbbed --- /dev/null +++ b/src/ge/hybrid/model/graph_item.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ +#define GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "hybrid/model/node_item.h" + +namespace ge { +namespace hybrid { +class GraphItem { + public: + GraphItem() = default; + ~GraphItem(); + const vector<NodeItem *> &GetAllNodes() const; + const vector<NodeItem *> &GetInputNodes() const; + Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; + + int TotalInputs() const { return total_inputs_; } + + int TotalOutputs() const { return total_outputs_; } + + const std::string &GetName() const { return name_; } + + void SetName(const string &name) { name_ = name; } + + const NodeItem *GetOutputNode() const; + + bool IsDynamic() const; + int GetParentOutputIndex(size_t index) const; + const vector<int> &GetInputIndexMapping() const; + + private: + friend class HybridModelBuilder; + std::string name_; + std::vector<NodeItem *> node_items_; + std::vector<NodeItem *> input_nodes_; + const NodeItem *output_node_ = nullptr; + // (src_node_item, src_output_index) of each model output + std::vector<std::pair<const NodeItem *, int>> output_edges_; + int total_inputs_ = 0; + int total_outputs_ = 0; + + bool is_dynamic_ = true; + std::vector<int> input_index_mapping_; + std::vector<int> output_index_mapping_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ diff --git a/src/ge/hybrid/model/hybrid_model.cc b/src/ge/hybrid/model/hybrid_model.cc index e3726aec..0cb81aa3 100644 --- a/src/ge/hybrid/model/hybrid_model.cc +++ b/src/ge/hybrid/model/hybrid_model.cc @@ -29,6 +29,8 @@ namespace ge { namespace hybrid { HybridModel::HybridModel(GeRootModelPtr ge_model) : ge_root_model_(std::move(ge_model)) {} +HybridModel::~HybridModel() { GELOGD("[%s] HybridModel destroyed.", model_name_.c_str()); } + Status HybridModel::Init() { GELOGD("Start to init hybrid model."); GE_CHK_STATUS_RET(HybridModelBuilder(*this).Build(), "Failed to build hybrid model."); @@ -36,22 +38,6 @@ Status HybridModel::Init() { return SUCCESS; } -void HybridModel::Print() const { - for (const auto &node : node_items_) { - GELOGD("%s", node->DebugString().c_str()); - } -} - -TensorValue *HybridModel::GetWeight(const NodeItem *const_node) const { - auto it = weights_.find(const_node->node_id); - if (it == weights_.end() || it->second == nullptr) { - GELOGE(INTERNAL_ERROR, "[%s] Failed to get weight", const_node->NodeName().c_str()); - return nullptr; - } - - return it->second.get(); -} - TensorValue *HybridModel::GetVariable(const string &name) const { auto it = variable_tensors_.find(name); if (it == variable_tensors_.end()) { @@ -83,26 +69,26 @@ const std::vector<domi::TaskDef> *HybridModel::GetTaskDefs(const NodePtr &node) } NodeItem *HybridModel::MutableNodeItem(const NodePtr &node) { - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast<size_t>(node_id) > node_items_.size()) { - GELOGE(INTERNAL_ERROR, "index out of range. node_id = %ld, num_nodes = %zu", node_id, node_items_.size()); + auto it = node_items_.find(node); + if (it == node_items_.end()) { return nullptr; } - return node_items_[node_id].get(); + + return it->second.get(); } const NodeItem *HybridModel::GetNodeItem(const NodePtr &node) const { - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast<size_t>(node_id) > node_items_.size()) { - GELOGE(INTERNAL_ERROR, "Index out of range. 
node_id = %ld, num_nodes = %zu.", node_id, node_items_.size()); + auto it = node_items_.find(node); + if (it == node_items_.end()) { return nullptr; } - return node_items_[node_id].get(); + + return it->second.get(); } GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { - auto it = known_shape_sub_graphs_.find(node); - if (it == known_shape_sub_graphs_.end()) { + auto it = known_shape_sub_models_.find(node); + if (it == known_shape_sub_models_.end()) { GELOGE(INTERNAL_ERROR, "[%s] Failed to get GeModel for subgraph node.", node->GetName().c_str()); return nullptr; } @@ -110,8 +96,27 @@ GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { return it->second; } -const vector &HybridModel::GetNetOutputInputOffsets() const { return net_output_input_offsets_; } +const GraphItem *HybridModel::GetRootGraphItem() const { return root_graph_item_.get(); } + +const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { + GELOGD("To find subgraph item by name = %s", graph_name.c_str()); + auto it = subgraph_items_.find(graph_name); + if (it == subgraph_items_.end()) { + GELOGD("Subgraph item not found by node = %s", graph_name.c_str()); + return nullptr; + } + + return it->second.get(); +} + +const GraphItem *HybridModel::GetSubgraphItem(const ComputeGraphPtr &subgraph) const { + if (subgraph == nullptr) { + GELOGE(PARAM_INVALID, "subgraph is nullptr"); + return nullptr; + } -void HybridModel::SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + auto subgraph_name = subgraph->GetName(); + return GetSubgraphItem(subgraph_name); +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model.h b/src/ge/hybrid/model/hybrid_model.h index 007f76c6..f554752e 100644 --- a/src/ge/hybrid/model/hybrid_model.h +++ b/src/ge/hybrid/model/hybrid_model.h @@ -26,39 +26,23 @@ #include "graph/node.h" #include "hybrid/common/tensor_value.h" #include "hybrid/model/node_item.h" +#include "hybrid/model/graph_item.h" #include "model/ge_root_model.h" namespace ge { namespace hybrid { -class HybridModelAsyncExecutor; class HybridModel { public: explicit HybridModel(GeRootModelPtr ge_model); - ~HybridModel() = default; + ~HybridModel(); Status Init(); - const std::vector &RootNodes() const { return root_nodes_; } - const NodeItem *GetNodeItem(const NodePtr &node) const; - size_t NumNodes() const { return node_items_.size(); } - uint64_t GetSessionId() const { return root_runtime_param_.session_id; } - int TotalInputs() const { return total_inputs_; } - - const map &GetInputNodes() const { return input_nodes_; } - - const std::map> &GetInputOffsets() const { return input_offsets_; } - - const vector &GetNetOutputInputOffsets() const; - - const std::vector &GetOutputOffsets() const { return output_offsets_; } - - const std::vector &GetConstNodes() const { return const_nodes_; } - GeModelPtr GetGeModel(const NodePtr &node) const; NodeItem *MutableNodeItem(const NodePtr &node); @@ -67,46 +51,40 @@ class HybridModel { const uint8_t *GetVarMemBase() const { return var_mem_base_; } - void SetDeviceId(uint32_t device_id); + void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } void SetModelId(uint32_t model_id) { model_id_ = model_id; } uint32_t GetModelId() const { return model_id_; } - TensorValue *GetWeight(const NodeItem *const_node) const; - TensorValue *GetVariable(const string &name) const; NodePtr GetVariableNode(const string &name) const; const std::vector *GetTaskDefs(const NodePtr &node) const; - int TotalOutputs() const { return 
total_outputs_; } + const GraphItem *GetRootGraphItem() const; - GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } - void Print() const; + const GraphItem *GetSubgraphItem(const std::string &graph_name) const; + + const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; private: friend class HybridModelBuilder; friend class HybridModelAsyncExecutor; + std::string model_name_; GeRootModelPtr ge_root_model_; - std::vector<NodeItem *> root_nodes_; std::map<uint32_t, NodeItem *> input_nodes_; - std::map<uint32_t, std::vector<int>> input_offsets_; - std::vector<int> output_offsets_; - std::vector<int> net_output_input_offsets_; - NodeItem *net_output_node_ = nullptr; - std::vector<std::unique_ptr<NodeItem>> node_items_; - std::vector<NodeItem *> const_nodes_; std::map<std::string, NodePtr> constant_op_nodes_; std::map<std::string, NodePtr> variable_nodes_; std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; - std::map<int64_t, std::unique_ptr<TensorValue>> weights_; std::map<NodePtr, std::vector<domi::TaskDef>> task_defs_; - std::map<NodePtr, GeModelPtr> known_shape_sub_graphs_; - int total_inputs_ = 0; - int total_outputs_ = 0; + std::map<NodePtr, GeModelPtr> known_shape_sub_models_; + + std::unique_ptr<GraphItem> root_graph_item_; + std::map<std::string, std::unique_ptr<GraphItem>> subgraph_items_; + std::map<NodePtr, std::unique_ptr<NodeItem>> node_items_; // runtime fields uint32_t device_id_ = 0; diff --git a/src/ge/hybrid/model/hybrid_model_builder.cc b/src/ge/hybrid/model/hybrid_model_builder.cc index 190890b7..841f1f15 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.cc +++ b/src/ge/hybrid/model/hybrid_model_builder.cc @@ -23,7 +23,6 @@ #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" -#include "framework/common/debug/log.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/node_executor/node_executor.h" @@ -32,6 +31,7 @@ namespace hybrid { namespace { const uint32_t kSubgraphIndex = 0U; const uint32_t kVarOutputIndex = 0U; +const uint32_t kAlignment = 32; const int kBytes = 8; int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { @@ -46,6 +46,9 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { var_size *= shape.GetDim(dim_index); } + + // pad up to a multiple of kAlignment, and add one extra kAlignment as headroom + var_size = (var_size + kAlignment * 2 - 1) / kAlignment * kAlignment; return var_size; } } // namespace @@ -56,20 +59,19 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); - graph_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); - GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(), "Failed to initialize executors"); GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); + GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); 
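// Worked example for the CalcVarSizeInBytes() padding above (illustrative, not part of the patch):
// with kAlignment = 32 and a raw size of 33 bytes,
//   (33 + 32 * 2 - 1) / 32 * 32 = 96 / 32 * 32 = 96,
// i.e. sizes are rounded up to a 32-byte boundary with one extra 32-byte block of headroom;
// the formula is equivalent to round_up(var_size + kAlignment, kAlignment).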
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); - GE_CHK_STATUS_RET(ResolveRootNodes(), "[%s] Failed to resolve root nodes", GetGraphName()); GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); GELOGI("[%s] Done building hybrid model successfully.", GetGraphName()); return SUCCESS; @@ -81,45 +83,17 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } -Status HybridModelBuilder::ResolveRootNodes() { - for (auto &node : hybrid_model_.node_items_) { - if (node->node->GetInDataNodes().empty()) { - hybrid_model_.root_nodes_.emplace_back(node.get()); - GELOGI("[%s] Root node added. node name = %s", GetGraphName(), node->NodeName().c_str()); - } - } - - if (hybrid_model_.root_nodes_.empty()) { - GELOGE(PARAM_INVALID, "[%s] Root nodes is empty.", GetGraphName()); - return PARAM_INVALID; - } - - return SUCCESS; -} - -Status HybridModelBuilder::BuildNoteItem(const NodePtr &node, NodeItem &node_item) { - GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, node_item.is_dynamic), - "[%s] Failed to get shape status.", node->GetName().c_str()); - +Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); vector dependencies = node->GetOpDesc()->GetOpInferDepends(); GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", node_item.NodeName().c_str()); - auto it = node_ref_inputs_.find(node); - if (it != node_ref_inputs_.end()) { - for (auto &idx_and_node : it->second) { - // var and constant only have one output - node_item.const_input_shapes[idx_and_node.first] = - idx_and_node.second->GetOpDesc()->MutableOutputDesc(kVarOutputIndex); - } - } - node_item.outputs.resize(node_item.num_outputs); for (int i = 0; i < node_item.num_outputs; ++i) { auto out_data_anchor = node->GetOutDataAnchor(i); if (out_data_anchor == nullptr) { - GELOGE(INTERNAL_ERROR, "out anchor[%zu] of node %s is nullptr", i, node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "out anchor[%d] of node %s is nullptr", i, node->GetName().c_str()); return INTERNAL_ERROR; } @@ -137,27 +111,46 @@ Status HybridModelBuilder::BuildNoteItem(const NodePtr &node, NodeItem &node_ite } } + GE_CHK_STATUS_RET_NOLOG(ResolveRefIo(node_item)); return SUCCESS; } -Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item) { - auto &node_items = hybrid_model_.node_items_; - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast(node_id) > node_items.size()) { - GELOGE(INTERNAL_ERROR, "[%s] Index out of range. 
node_id = %ld, num_nodes = %zu", node->GetName().c_str(), node_id, - node_items.size()); - return INTERNAL_ERROR; +Status HybridModelBuilder::ResolveRefIo(NodeItem &node_item) { + bool is_ref = false; + auto &op_desc = *node_item.op_desc; + (void)AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + if (!is_ref) { + return SUCCESS; + } + + auto inputs = op_desc.GetAllInputName(); + auto outputs = op_desc.GetAllOutputName(); + for (auto &output : outputs) { + for (auto &input : inputs) { + if (input.first == output.first) { + auto input_idx = static_cast(input.second); + auto output_idx = static_cast(output.second); + node_item.reuse_inputs[output_idx] = input_idx; + GELOGD("[%s] Output[%d] reuse input[%d]", node_item.NodeName().c_str(), output_idx, input_idx); + } + } } - auto &node_ptr = node_items[node_id]; - if (node_ptr != nullptr) { - *node_item = node_ptr.get(); + return SUCCESS; +} + +Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item) { + auto &node_items = hybrid_model_.node_items_; + auto it = node_items.find(node); + if (it != node_items.end()) { + *node_item = it->second.get(); return SUCCESS; } auto new_node = std::unique_ptr(new (std::nothrow) NodeItem(node)); GE_CHECK_NOTNULL(new_node); GE_CHECK_NOTNULL(new_node->op_desc); + GE_CHK_STATUS_RET(new_node->Init(), "Failed to init NodeItem [%s] .", node->GetName().c_str()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); // we do not need L2 Buffer @@ -169,18 +162,58 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n int32_t unknown_shape_type_val = 0; (void)AttrUtils::GetInt(new_node->op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); new_node->shape_inference_type = static_cast(unknown_shape_type_val); - if (new_node->shape_inference_type == DEPEND_SHAPE_RANGE || new_node->shape_inference_type == DEPEND_COMPUTE) { - new_node->has_observer = true; + + GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, new_node->is_dynamic), + "[%s] Failed to get shape status.", node->GetName().c_str()); + + if (new_node->is_dynamic && (new_node->IsControlOp() || new_node->NodeType() == PARTITIONEDCALL)) { + new_node->shape_inference_type = DEPEND_COMPUTE; } + new_node->node_id = node_index; + new_node->op_desc->SetId(node_index); + node_index += 1; + *node_item = new_node.get(); - node_items[node_id] = std::move(new_node); + node_items[node] = std::move(new_node); return SUCCESS; } Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_input_nodes; auto &ge_node = node_item.node; + + // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. + // Wait for these parent nodes before execution. 
+ + // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. + // Wait for these parent nodes before execution. + for (const auto &in_anchor : ge_node->GetAllInDataAnchors()) { + const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + GELOGD("[%s] Input[%d] does not have peer anchor", node_item.NodeName().c_str(), in_anchor->GetIdx()); + continue; + } + + auto src_node = peer_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + + if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { + GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", + node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); + + src_node_item->has_observer = true; + node_item.dependents_for_execution.emplace_back(src_node); + } + + if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { + GELOGD("[%s] Add input shape dependent node [%s] due to inference type = DEPEND_SHAPE_RANGE", + node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); + src_node_item->has_observer = true; + dependent_input_nodes.emplace(src_node); + } + } + for (const auto &input_name : dependencies) { int input_index = node_item.op_desc->GetInputIndexByName(input_name); if (input_index < 0) { @@ -205,7 +238,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } for (const auto &dep_node : dependent_input_nodes) { - node_item.dependent_node_list.emplace_back(dep_node); + node_item.dependents_for_shape_inference.emplace_back(dep_node); } return SUCCESS; @@ -262,9 +295,14 @@ Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { const auto &wrapped_node = graph.GetParentNode(); + std::set<NodePtr> root_nodes; for (const auto &node : graph.GetDirectNode()) { GE_CHECK_NOTNULL(node); if (node->GetType() != DATA_TYPE) { + if (node->GetInDataNodes().empty()) { + root_nodes.emplace(node); + } + + continue; } @@ -291,12 +329,28 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { GE_CHECK_NOTNULL(out_data_anchor); for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + auto dst_node = peer_in_data_anchor->GetOwnerNode(); + root_nodes.emplace(dst_node); GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); } } } + // transfer in control edges to all root nodes + for (auto &root_node : root_nodes) { + auto in_nodes = root_node->GetInAllNodes(); + std::set<NodePtr> in_node_set(in_nodes.begin(), in_nodes.end()); + for (auto &in_control_node : wrapped_node->GetInControlNodes()) { + if (in_node_set.count(in_control_node) == 0) { + GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); + GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); + (void)in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); + } + } + } + + wrapped_node->GetInControlAnchor()->UnlinkAll(); return SUCCESS; } @@ -307,6 +361,11 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { const auto &net_output_desc = net_output_node->GetOpDesc(); GE_CHECK_NOTNULL(net_output_desc); + auto all_in_nodes = net_output_node->GetInAllNodes(); + auto all_out_nodes = parent_node->GetOutAllNodes(); + net_output_node->GetInControlAnchor()->UnlinkAll(); + parent_node->GetOutControlAnchor()->UnlinkAll(); + for (const auto 
&in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(src_out_anchor); @@ -338,10 +397,25 @@ } } + // transfer out control edges + std::set<NodePtr> in_node_set(all_in_nodes.begin(), all_in_nodes.end()); + std::set<NodePtr> out_node_set(all_out_nodes.begin(), all_out_nodes.end()); + for (auto &src_node : in_node_set) { + GELOGD("[%s] process in node.", src_node->GetName().c_str()); + auto out_nodes = src_node->GetOutAllNodes(); + std::set<NodePtr> node_set(out_nodes.begin(), out_nodes.end()); + for (auto &dst_node : out_node_set) { + if (node_set.count(dst_node) == 0) { + src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); + GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); + } + } + } + return SUCCESS; } -Status HybridModelBuilder::MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared<ComputeGraph>("MergedGraph"); for (const auto &node : root_graph.GetDirectNode()) { GE_CHECK_NOTNULL(node); @@ -371,32 +445,74 @@ } auto subgraph = NodeUtils::GetSubgraph(*node, kSubgraphIndex); - GE_CHK_STATUS_RET(MergeInputNodes(*subgraph), "Failed to merge data nodes for subgraph: %s", - subgraph->GetName().c_str()); - GE_CHK_STATUS_RET(MergeNetOutputNode(*subgraph), "Failed to merge net output nodes for subgraph: %s", - subgraph->GetName().c_str()); - GELOGD("Merging subgraph %s successfully.", subgraph->GetName().c_str()); - for (auto &sub_node : subgraph->GetAllNodes()) { - auto sub_op_type = sub_node->GetType(); - if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { - continue; - } + GE_CHECK_NOTNULL(subgraph); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), "[%s] Failed to merge subgraph.", + subgraph->GetName().c_str()); + } - if (sub_op_type == CONSTANT || sub_op_type == CONSTANTOP || sub_op_type == VARIABLE) { - GELOGE(INTERNAL_ERROR, "Unexpected node in unknown subgraph. type = %s, node = %s::%s", sub_op_type.c_str(), - subgraph->GetName().c_str(), sub_node->GetName().c_str()); - return INTERNAL_ERROR; - }
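// Shape of the recursion below (illustrative summary): UnfoldSubgraphs() walks the root
// graph's unknown-shape PARTITIONEDCALL nodes and calls UnfoldSubgraph(), which lifts each
// dynamic subgraph's nodes into the parent graph, recursing into nested unknown-shape
// partitioned calls, so that the merged graph contains no dynamic PARTITIONEDCALL wrappers.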
+ // Invoke TopologicalSorting before adding subgraphs; otherwise node ids in known-shaped subgraphs may be modified. + GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); + + for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); + GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", + remained_subgraph->GetName().c_str()); + } + + return SUCCESS; +} + +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, + ComputeGraph &sub_graph) { + auto parent_node = sub_graph.GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + GE_CHK_STATUS_RET(MergeInputNodes(sub_graph), "[%s] Failed to merge data nodes for subgraph", + sub_graph.GetName().c_str()); + GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), "[%s] Failed to merge net output nodes for subgraph", + sub_graph.GetName().c_str()); + GELOGD("[%s] Done merging subgraph inputs and outputs successfully.", sub_graph.GetName().c_str()); + + for (auto &sub_node : sub_graph.GetDirectNode()) { + auto sub_op_type = sub_node->GetType(); + if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { + continue; + } + + if (sub_op_type == CONSTANT || sub_op_type == VARIABLE) { + GELOGE(INTERNAL_ERROR, "Unexpected node in unknown subgraph. type = %s, node = %s::%s", sub_op_type.c_str(), + sub_graph.GetName().c_str(), sub_node->GetName().c_str()); + return INTERNAL_ERROR; + } - merged_graph->AddNode(sub_node); - GELOGD("%s::%s added to merged graph.", subgraph->GetName().c_str(), sub_node->GetName().c_str()); + if (sub_op_type == PARTITIONEDCALL) { + bool is_unknown_shape = false; + GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*sub_node, is_unknown_shape), + "[%s] Failed to invoke GetNodeUnknownShapeStatus.", sub_node->GetName().c_str()); + if (is_unknown_shape) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex); + GE_CHECK_NOTNULL(sub_sub_graph); + GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph), "[%s] Failed to merge subgraph", + sub_sub_graph->GetName().c_str()); + continue; + } } + + parent_graph.AddNode(sub_node); + GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), + parent_graph.GetName().c_str()); } + GELOGD("[%s] Done merging subgraph. 
remove it from root graph.", sub_graph.GetName().c_str()); + root_graph.RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } -Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { +Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, const NodeItem &node_item, bool is_root_graph) { + auto output_size = node_item.op_desc->GetAllInputsSize(); + GE_CHECK_LE(output_size, UINT32_MAX); + graph_item.output_edges_.resize(output_size); + for (auto &in_data_anchor : node_item.node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -408,11 +524,20 @@ Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { auto output_offset = src_node_item->output_start + peer_out_anchor->GetIdx(); GELOGI("Output[%d], node = %s, output_index = %d, output_offset = %d ", in_data_anchor->GetIdx(), src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx(), output_offset); - hybrid_model_.output_offsets_.emplace_back(output_offset); + + graph_item.output_edges_[in_data_anchor->GetIdx()] = {src_node_item, peer_out_anchor->GetIdx()}; } - for (int i = 0; i < node_item.num_inputs; ++i) { - hybrid_model_.net_output_input_offsets_.emplace_back(node_item.input_start + i); + if (!is_root_graph) { + for (uint32_t i = 0; i < static_cast(output_size); ++i) { + uint32_t p_index = i; + // Net output of Subgraph of while do not have parent index + if (AttrUtils::GetInt(node_item.op_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { + GELOGD("[%s] Parent index not set for input[%u].", node_item.NodeName().c_str(), i); + } + + graph_item.output_index_mapping_.emplace_back(p_index); + } } return SUCCESS; @@ -420,82 +545,37 @@ Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { Status HybridModelBuilder::LoadGraph() { auto root_graph = ge_root_model_->GetRootGraph(); - GELOGI("Before merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), - root_graph->GetAllNodesSize()); - ComputeGraphPtr merged_graph; - GE_CHK_STATUS_RET_NOLOG(MergeSubgraphs(*root_graph, merged_graph)); - GELOGI("After merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", merged_graph->GetDirectNodesSize(), - merged_graph->GetAllNodesSize()); - - merged_graph->SetGraphID(runtime_param_.graph_id); - GE_DUMP(merged_graph, "hybrid_merged_graph"); - int input_start = 0; - int output_start = 0; - uint32_t data_op_index = 0; - hybrid_model_.node_items_.resize(merged_graph->GetDirectNodesSize()); - - int64_t node_index = 0; - for (auto &node : merged_graph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - op_desc->SetId(node_index++); - } - - for (const auto &node : merged_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - const auto &op_type = node->GetType(); - - NodeItem *node_item = nullptr; - GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); - GE_CHK_STATUS_RET_NOLOG(BuildNoteItem(node, *node_item)); - GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task - - node_item->input_start = input_start; - node_item->output_start = output_start; - input_start += node_item->num_inputs; - output_start += node_item->num_outputs; - - if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { - auto data_index = data_op_index; - if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, data_index)) { - GELOGI("ge_train: get new index %u, old %u", data_index, 
data_op_index); - } - hybrid_model_.input_nodes_.emplace(data_index, node_item); - data_op_index++; - } else if (op_type == NETOUTPUT) { - hybrid_model_.net_output_node_ = node_item; - GE_CHK_STATUS_RET_NOLOG(ParseNetOutput(*node_item)); - } else if (op_type == PARTITIONEDCALL) { // known graph - GE_CHK_STATUS_RET_NOLOG(ParsePartitionedCall(*node_item)); + GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); + GELOGD("Done loading root graph successfully."); + + for (auto &sub_graph : root_graph->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(sub_graph); + GELOGD("Start to load subgraph [%s]", sub_graph->GetName().c_str()); + auto parent_node = sub_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + auto parent_node_item = MutableNodeItem(parent_node); + // parent node is in another known subgraph + if (parent_node_item == nullptr) { + GELOGD("[%s] Subgraph is in another known shaped subgraph, skip it.", sub_graph->GetName().c_str()); + continue; } - GELOGI("NodeItem created: %s", node_item->DebugString().c_str()); - } - - for (auto &it : hybrid_model_.input_nodes_) { - auto input_index = it.first; - auto input_node = it.second; - - if (input_node->outputs.empty()) { - GELOGE(INTERNAL_ERROR, "data output anchor is empty"); - return INTERNAL_ERROR; - } + if (sub_graph->GetGraphUnknownFlag()) { + GE_CHK_STATUS_RET(LoadDynamicSubgraph(*sub_graph, false), "Failed to load subgraph: [%s]", + sub_graph->GetName().c_str()); + } else { + GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), "[%s] Failed to identify ref outputs.", + parent_node_item->NodeName().c_str()); - for (auto &out : input_node->outputs) { - std::vector offsets; - for (auto &dst_anchor_and_node : out) { - auto dst_node_item = dst_anchor_and_node.second; - offsets.emplace_back(dst_node_item->input_start + dst_anchor_and_node.first); + // if parent is function control op. need add a virtual partitioned call + if (parent_node_item->IsControlOp()) { + GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), + "Failed to load function control op subgraph [%s]", sub_graph->GetName().c_str()); } - - hybrid_model_.input_offsets_.emplace(input_index, std::move(offsets)); } } - hybrid_model_.total_inputs_ = input_start; - hybrid_model_.total_outputs_ = output_start; - GELOGI("HybridGraph::LoadGraph OUT"); + GELOGI("Done loading all subgraphs successfully."); return SUCCESS; } @@ -507,7 +587,6 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ string var_name = var_node->GetName(); auto tensor_desc = var_node->GetOpDesc()->MutableOutputDesc(0); uint8_t *var_logic = nullptr; - GE_CHK_STATUS_RET(var_manager_->GetVarAddr(var_name, *tensor_desc, &var_logic), "Failed to get var addr. var_name = %s, session_id = %ld", var_name.c_str(), hybrid_model_.GetSessionId()); @@ -559,10 +638,26 @@ Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr return SUCCESS; } +Status HybridModelBuilder::AssignUninitializedConstantOps() { + for (auto &it : hybrid_model_.constant_op_nodes_) { + const string &var_name = it.first; + const NodePtr &var_node = it.second; + auto tensor_desc = var_node->GetOpDesc()->MutableOutputDesc(0); + if (!var_manager_->IsVarExist(var_name, *tensor_desc)) { + // allocate constant + GELOGD("[%s] Constant not allocated during graph building. 
now allocate it.", var_name.c_str()); + GE_CHK_STATUS_RET(var_manager_->AssignVarMem(var_name, *tensor_desc, RT_MEMORY_HBM)); + GE_CHK_STATUS_RET(var_manager_->SetAllocatedGraphId(var_name, runtime_param_.graph_id)); + } + } + + return SUCCESS; +} + Status HybridModelBuilder::InitConstantOps() { for (auto &it : hybrid_model_.constant_op_nodes_) { - string var_name = it.first; - NodePtr &var_node = it.second; + const string &var_name = it.first; + const NodePtr &var_node = it.second; std::unique_ptr var_tensor; GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); @@ -578,7 +673,7 @@ Status HybridModelBuilder::InitConstantOps() { if (ge_tensor->GetData().size() > 0) { GE_CHK_STATUS_RET_NOLOG(HandleDtString(*ge_tensor, v_output_addr)); - GELOGI("[IMAS]InitConstant memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%zu]", + GELOGI("[IMAS]InitConstant memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%zu] datasize[%zu]", runtime_param_.graph_id, op_desc->GetName().c_str(), 0, v_output_addr, v_output_size, ge_tensor->GetData().size()); GE_CHK_RT_RET(rtMemcpy(v_output_addr, v_output_size, ge_tensor->GetData().data(), ge_tensor->GetData().size(), @@ -614,7 +709,8 @@ Status HybridModelBuilder::InitWeights() { } Status HybridModelBuilder::LoadTasks() { - for (auto &node_item : hybrid_model_.node_items_) { + for (auto &it : hybrid_model_.node_items_) { + auto &node_item = it.second; auto &node_ptr = node_item->node; if (node_item->node_type == NETOUTPUT) { continue; @@ -622,7 +718,6 @@ Status HybridModelBuilder::LoadTasks() { GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, node_ptr, node_item->kernel_task); - if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); return load_ret; @@ -634,6 +729,23 @@ Status HybridModelBuilder::LoadTasks() { return SUCCESS; } +Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr &ge_model) { + auto parent_node = sub_graph.GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + auto op_type = parent_node->GetType(); + if (op_type == IF || op_type == CASE || op_type == WHILE) { + GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), + ge_model->GetModelTaskDefPtr()->task_size()); + subgraph_models_.emplace(sub_graph.GetName(), ge_model); + } else { + GELOGD("Set ge_model for subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), + ge_model->GetModelTaskDefPtr()->task_size()); + hybrid_model_.known_shape_sub_models_.emplace(sub_graph.GetParentNode(), ge_model); + } + + return SUCCESS; +} + Status HybridModelBuilder::IndexTaskDefs() { const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { @@ -646,12 +758,9 @@ Status HybridModelBuilder::IndexTaskDefs() { continue; } - bool is_unknown_shape = false; - GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*sub_graph->GetParentNode(), is_unknown_shape), - "Failed to invoke GetNodeUnknownShapeStatus."); + bool is_unknown_shape = sub_graph->GetGraphUnknownFlag(); if (!is_unknown_shape) { - GELOGD("Set ge_model for subgraph: %s", sub_graph->GetName().c_str()); - hybrid_model_.known_shape_sub_graphs_.emplace(sub_graph->GetParentNode(), ge_model); + GE_CHK_STATUS_RET_NOLOG(LoadGeModel(*sub_graph, ge_model)); continue; } @@ -676,6 +785,8 @@ Status 
HybridModelBuilder::IndexTaskDefs() { op_index = task_def.kernel().context().op_index(); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { op_index = task_def.kernel_ex().op_index(); + } else if (task_type == RT_MODEL_TASK_HCCL) { + op_index = task_def.kernel_hccl().op_index(); } else { GELOGD("Skip task type: %d", static_cast(task_type)); continue; @@ -790,12 +901,12 @@ Status HybridModelBuilder::GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, for (uint32_t i = 0; i < static_cast(input_size); ++i) { uint32_t p_index = 0; if (!AttrUtils::GetInt(net_output_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { - GELOGW("SubGraph: %s input tensor %zu attr %s not found.", src_graph->GetName().c_str(), i, + GELOGW("SubGraph: %s input tensor %u attr %s not found.", src_graph->GetName().c_str(), i, ATTR_NAME_PARENT_NODE_INDEX.c_str()); continue; } - GELOGD("NetOutput's input[%zu], parent_node_index = %u", i, p_index); + GELOGD("NetOutput's input[%u], parent_node_index = %u", i, p_index); if (p_index == out_index) { auto in_anchor = src_net_output_node->GetInDataAnchor(i); GE_CHECK_NOTNULL(in_anchor); @@ -830,7 +941,7 @@ Status HybridModelBuilder::InitRuntimeParams() { ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_VAR_SIZE, value); runtime_param_.var_size = ret ? (uint64_t)value : 0; runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); - GELOGI("InitRuntimeParams(), session_id:%u, var_size:%lu. graph_id = %u", runtime_param_.session_id, + GELOGI("InitRuntimeParams(), session_id:%lu, var_size:%lu. graph_id = %u", runtime_param_.session_id, runtime_param_.var_size, runtime_param_.graph_id); var_manager_ = VarManager::Instance(runtime_param_.session_id); @@ -838,7 +949,7 @@ Status HybridModelBuilder::InitRuntimeParams() { return SUCCESS; } -Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { +Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); GE_CHECK_NOTNULL(subgraph); @@ -847,6 +958,7 @@ Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { auto net_output_desc = net_output_node->GetOpDesc(); GE_CHECK_NOTNULL(net_output_desc); + // constant/variable connected to net output for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto src_node = GetPeerNode(in_data_anchor); GE_CHECK_NOTNULL(src_node); @@ -864,6 +976,8 @@ Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { node_item.ref_outputs.emplace(parent_index, src_node); } + // Data nodes marked with REF_VAR_SRC_VAR_NAME + // Using variable tensor as data's output for (auto &node : subgraph->GetDirectNode()) { if (node->GetType() != DATA) { continue; @@ -912,6 +1026,11 @@ Status HybridModelBuilder::GetParentNodeOutputIndex(const OpDesc &op_desc, int i Status HybridModelBuilder::InitModelMem() { hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); auto total_var_size = hybrid_model_.TotalVarMemSize(); + if (total_var_size == 0 && !hybrid_model_.constant_op_nodes_.empty()) { + total_var_size = var_manager_->GetVarMemSize(RT_MEMORY_HBM) > 0 ? var_manager_->GetVarMemMaxSize() : 0; + GELOGD("Model var size = 0. but got uninitialized constant. 
set var size to %zu.", total_var_size); + } + if (total_var_size > 0 && hybrid_model_.var_mem_base_ == nullptr) { GE_CHK_STATUS_RET(var_manager_->MallocVarMemory(total_var_size), "Malloc Var Memory Fail."); hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); @@ -951,5 +1070,154 @@ Status HybridModelBuilder::CopyVarData() { GELOGI("CopyVarData success."); return SUCCESS; } + +Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item) { + GELOGD("Start to load known shaped subgraph [%s]", graph.GetName().c_str()); + auto graph_item = std::unique_ptr(new (std::nothrow) GraphItem()); + GE_CHECK_NOTNULL(graph_item); + graph_item->is_dynamic_ = false; + auto subgraph_name = graph.GetName(); + auto wrapper_op_desc = MakeShared(subgraph_name + "_partitioned_call", PARTITIONEDCALL); + GE_CHECK_NOTNULL(wrapper_op_desc); + + for (auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &op_type = node->GetType(); + + if (op_type == DATA) { + int32_t data_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + + (void)wrapper_op_desc->AddInputDesc(op_desc->GetInputDesc(0)); + graph_item->input_index_mapping_.emplace_back(data_index); + } else if (op_type == NETOUTPUT) { + int output_index = 0; + for (const auto &output_desc : op_desc->GetAllInputsDescPtr()) { + int32_t data_index = output_index++; + if (!AttrUtils::GetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGI("[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + } + + GE_CHK_GRAPH_STATUS_RET(wrapper_op_desc->AddOutputDesc(*output_desc), + "[%s] Failed to add output desc. output index = %d", graph.GetName().c_str(), + output_index); + + graph_item->output_index_mapping_.emplace_back(data_index); + } + } + } + + auto temp_graph = MakeShared("temp"); + GE_CHECK_NOTNULL(temp_graph); + auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); + GeModelPtr ge_model = subgraph_models_[subgraph_name]; + GE_CHECK_NOTNULL(ge_model); + hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); + + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(wrapper_node, &node_item)); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->outputs.resize(node_item->num_outputs); + graph_item->node_items_.emplace_back(node_item); + graph_item->output_node_ = node_item; + graph_item->total_inputs_ = node_item->num_inputs; + graph_item->total_outputs_ = node_item->num_outputs; + + GELOGD("NodeItem create for known shape subgraph [%s], NodeItem = %s", graph.GetName().c_str(), + node_item->DebugString().c_str()); + + GELOGD("Done parse known shape subgraph successfully. 
graph = [%s]", graph.GetName().c_str()); + graph_item->SetName(graph.GetName()); + GELOGD("Done loading known shape subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.subgraph_items_.emplace(graph.GetName(), std::move(graph_item)); + return SUCCESS; +} + +Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph) { + GELOGD("Start to load subgraph [%s]", graph.GetName().c_str()); + // for known partitioned call, load all nodes + auto graph_item = std::unique_ptr(new (std::nothrow) GraphItem()); + GE_CHECK_NOTNULL(graph_item); + + graph_item->is_dynamic_ = true; + graph_item->node_items_.reserve(graph.GetDirectNodesSize()); + int input_start = 0; + int output_start = 0; + std::vector data_nodes; + for (auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + const auto &op_type = node->GetType(); + + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); + GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); + GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task + + node_item->input_start = input_start; + node_item->output_start = output_start; + input_start += node_item->num_inputs; + output_start += node_item->num_outputs; + + if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { + data_nodes.emplace_back(node_item); + } else if (op_type == NETOUTPUT) { + graph_item->output_node_ = node_item; + GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); + } + + graph_item->node_items_.emplace_back(node_item); + GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); + } + + graph_item->total_inputs_ = input_start; + graph_item->total_outputs_ = output_start; + GE_CHK_STATUS_RET_NOLOG(BuildInputMapping(*graph_item, data_nodes, is_root_graph)); + if (is_root_graph) { + graph_item->SetName("Root-Graph"); + GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.root_graph_item_ = std::move(graph_item); + } else { + graph_item->SetName(graph.GetName()); + GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.subgraph_items_.emplace(graph.GetName(), std::move(graph_item)); + } + + return SUCCESS; +} + +Status HybridModelBuilder::BuildInputMapping(GraphItem &graph_item, vector &data_nodes, + bool is_root_graph) { + uint32_t data_op_index = 0; + for (auto &node_item : data_nodes) { + auto node = node_item->node; + int data_index = data_op_index; + if (is_root_graph) { + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, data_index)) { + GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); + } + data_op_index++; + } else { + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + } + + if (graph_item.input_nodes_.size() <= static_cast(data_index)) { + graph_item.input_nodes_.resize(data_index + 1); + } + + graph_item.input_nodes_[data_index] = node_item; + } + + return SUCCESS; +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model_builder.h b/src/ge/hybrid/model/hybrid_model_builder.h index 33cd1f03..1103aa1c 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.h +++ b/src/ge/hybrid/model/hybrid_model_builder.h @@ -46,18 +46,20 @@ class HybridModelBuilder { static Status HandleDtString(const 
GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); static Status InitWeights(); - + static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); + static Status ResolveRefIo(NodeItem &node_item); + Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); + Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadTasks(); - Status ParsePartitionedCall(NodeItem &node_item); - Status ParseNetOutput(const NodeItem &node_item); - Status BuildNoteItem(const NodePtr &node, NodeItem &node_item); + Status IdentifyVariableOutputs(NodeItem &node_item); + Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); - Status ResolveRootNodes(); Status IndexTaskDefs(); Status IndexSpecialNodes(); Status InitRuntimeParams(); @@ -65,19 +67,23 @@ class HybridModelBuilder { Status TransAllVarData(); Status CopyVarData(); Status VarNodeToTensor(const NodePtr &var_node, std::unique_ptr &tensor); + Status AssignUninitializedConstantOps(); Status InitConstantOps(); Status InitVariableTensors(); + Status LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph); + Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item); - const char *GetGraphName() const { return graph_name_.c_str(); } + const char *GetGraphName() const { return hybrid_model_.model_name_.c_str(); } const NodeItem *GetNodeItem(const NodePtr &node) const; NodeItem *MutableNodeItem(const NodePtr &node); GeRootModelPtr ge_root_model_; - std::string graph_name_; std::map> weights_; + std::map subgraph_models_; HybridModel &hybrid_model_; std::map>> node_ref_inputs_; + int node_index = 0; RuntimeParam &runtime_param_; VarManager *var_manager_ = nullptr; diff --git a/src/ge/hybrid/model/node_item.cc b/src/ge/hybrid/model/node_item.cc index b5d4fbda..e1cd7f64 100644 --- a/src/ge/hybrid/model/node_item.cc +++ b/src/ge/hybrid/model/node_item.cc @@ -16,6 +16,8 @@ #include "node_item.h" #include +#include "common/debug/log.h" +#include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { @@ -28,12 +30,34 @@ NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { this->node_type = this->node->GetType(); } +Status NodeItem::Init() { + for (int i = 0; i < num_inputs; ++i) { + const auto &input_desc = op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(input_desc); + if (input_desc->MutableShape().IsUnknownShape()) { + is_input_shape_static.push_back(false); + } else { + num_static_input_shapes++; + is_input_shape_static.push_back(true); + GELOGD("[%s] The shape of input[%d] is static. 
shape = [%s]", NodeName().c_str(), i, + input_desc->MutableShape().ToString().c_str()); + } + } + + return SUCCESS; +} + +bool NodeItem::IsControlOp() const { + auto op_type = op_desc->GetType(); + return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; +} + std::string NodeItem::DebugString() const { std::stringstream ss; ss << "Node: "; ss << "id = " << node_id; - ss << ", name = " << node->GetName(); - ss << ", type = " << node->GetType(); + ss << ", name = [" << node->GetName(); + ss << "], type = " << node->GetType(); ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); ss << ", unknown_shape_op_type = " << shape_inference_type; ss << ", input_start = " << input_start; @@ -41,7 +65,7 @@ std::string NodeItem::DebugString() const { ss << ", output_start = " << output_start; ss << ", num_outputs = " << num_outputs; ss << ", dependent_nodes = ["; - for (const auto &dep_node : dependent_node_list) { + for (const auto &dep_node : dependents_for_shape_inference) { ss << dep_node->GetName() << ", "; } ss << "]"; @@ -55,5 +79,18 @@ std::string NodeItem::DebugString() const { return ss.str(); } + +void NodeItem::SetToDynamic() { + num_static_input_shapes = 0; + is_dynamic = true; + for (size_t i = 0; i < is_input_shape_static.size(); ++i) { + is_input_shape_static[i] = false; + } + if (kernel_task != nullptr && !kernel_task->IsSupportDynamicShape()) { + GELOGD("[%s] Dynamic shape is not supported, clear node task.", node_name.c_str()); + kernel_task = nullptr; + } +} + } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/model/node_item.h b/src/ge/hybrid/model/node_item.h index b12d100b..4e6d770b 100644 --- a/src/ge/hybrid/model/node_item.h +++ b/src/ge/hybrid/model/node_item.h @@ -18,6 +18,7 @@ #define GE_HYBRID_MODEL_NODE_ITEM_H_ #include +#include "external/ge/ge_api_error_codes.h" #include "graph/node.h" #include "graph/op_desc.h" #include "framework/common/types.h" @@ -33,10 +34,16 @@ struct NodeItem { explicit NodeItem(NodePtr node); ~NodeItem() = default; + Status Init(); + const std::string &NodeName() const { return node_name; } const std::string &NodeType() const { return node_type; } + bool IsControlOp() const; + + void SetToDynamic(); + std::string DebugString() const; NodePtr node; @@ -52,17 +59,21 @@ struct NodeItem { UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; std::string node_name; std::string node_type; - std::vector dependent_node_list; + std::vector dependents_for_shape_inference; + std::vector dependents_for_execution; std::set to_const_output_id_list; - // src_output_id, dst_anchor_id, dst_node vector inputs; + // src_output_id, dst_anchor_id, dst_node vector>> outputs; std::shared_ptr kernel_task; const NodeExecutor *node_executor = nullptr; - std::map const_input_shapes; std::map ref_outputs; + std::map reuse_inputs; + + std::vector is_input_shape_static; + int num_static_input_shapes = 0; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 3f198bba..50c8e899 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -16,10 +16,8 @@ #include "aicore_node_executor.h" #include "cce/taskdown_common.hpp" -#include "graph/debug/ge_attr_define.h" -#include "hybrid/model/hybrid_model.h" +#include "hybrid/executor/hybrid_execution_context.h" #include "init/gelib.h" -#include 
"framework/common/debug/log.h" namespace ge { namespace hybrid { @@ -27,16 +25,47 @@ REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, AiCore AiCoreNodeTask::AiCoreNodeTask(std::vector> &&tasks) : tasks_(std::move(tasks)) {} +Status AiCoreNodeExecutor::Initialize() { + auto ge_lib = GELib::GetInstance(); + GE_CHECK_NOTNULL(ge_lib); + if (!ge_lib->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + + auto &kernel_manager = ge_lib->OpsKernelManagerObj(); + auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); + GE_CHECK_NOTNULL(aic_ops_store); + + compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); + GE_CHECK_NOTNULL(compiler_); + return SUCCESS; +} + Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreNodeExecutor[%s] LoadTask Start.", node->GetName().c_str()); + GELOGI("AiCoreNodeExecutor(%s) LoadTask Start.", node->GetName().c_str()); auto *task_defs = model.GetTaskDefs(node); - Status ret = SUCCESS; - GE_IF_BOOL_EXEC(task_defs != nullptr && !task_defs->empty(), ret = CreateTask(model, *task_defs, node, task)); + if (task_defs == nullptr || task_defs->empty()) { + bool dynamic_flag = false; + if (!AttrUtils::GetBool(node->GetOpDesc(), "support_dynamicshape", dynamic_flag) || !dynamic_flag) { + GELOGD("Skip create task of node (%s) as 'support_dynamicshape' is false and cann't get task_defs.", + node->GetName().c_str()); + return SUCCESS; + } else { + GELOGE(FAILED, "Task_defs is empty for node (%s) which 'support_dynamicshape' is true, failed.", + node->GetName().c_str()); + return FAILED; + } + } - GELOGI("AiCoreNodeExecutor[%s] LoadTask End, ret[%u].", node->GetName().c_str(), ret); - return ret; + AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs); + std::unique_ptr node_task; + GE_CHK_STATUS_RET(builder.BuildTask(node_task, true), "[%s] Failed to build op tasks.", node->GetName().c_str()); + task = std::move(node_task); + GELOGI("AiCoreNodeExecutor(%s) LoadTask End.", node->GetName().c_str()); + return SUCCESS; } Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key) { @@ -47,16 +76,19 @@ Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key // make sure unique, (op_id + input_shape) is unique node_key = std::to_string(op_desc->GetId()) + "/"; node_key.append(std::to_string(op_desc->GetInputsSize())); - auto input_descs = op_desc->GetAllInputsDesc(); - for (auto input_desc : input_descs) { + auto input_descs = op_desc->GetAllInputsDescPtr(); + for (auto &input_desc : input_descs) { node_key.push_back('/'); - std::vector dims = input_desc.GetShape().GetDims(); - GE_IF_BOOL_EXEC(dims.size() == 0, continue); // scalar - for (std::size_t i = 0; i < dims.size() - 1; i++) { - node_key.append(std::to_string(dims[i])); + auto &shape = input_desc->MutableShape(); + auto num_dims = shape.GetDimNum(); + if (num_dims == 0) { + continue; + } // scalar + for (std::size_t i = 0; i < num_dims - 1; i++) { + node_key.append(std::to_string(shape.GetDim(i))); node_key.push_back(','); } - node_key.append(std::to_string(dims[dims.size() - 1])); + node_key.append(std::to_string(shape.GetDim(num_dims - 1))); } return SUCCESS; } @@ -65,8 +97,10 @@ bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::sha GE_CHECK_NOTNULL(task); std::lock_guard lock(mutex_); auto iter = 
reg_node_tasks_.find(node_key); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter != reg_node_tasks_.end(), return false, - "AiCoreNodeTaskRegistry[%s] AddTask failed, key already exist.", node_key.c_str()); + if (iter != reg_node_tasks_.end()) { + GELOGE(FAILED, "AiCoreNodeTaskRegistry(%s) AddTask failed, key already exists.", node_key.c_str()); + return false; + } auto ret = reg_node_tasks_.emplace(node_key, task); return ret.second; } @@ -80,231 +114,84 @@ std::shared_ptr AiCoreNodeTaskRegistry::GetTask(const std::string &nod Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreNodeExecutor[%s] CompileTask Start.", node->GetName().c_str()); + GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); AiCoreNodeTaskRegistry &registry = AiCoreNodeTaskRegistry::GetInstance(); std::string node_key; - GE_CHK_STATUS_RET(GenNodeKey(node, node_key), "GenNodeKey failed. op name = %s", node->GetName().c_str()); + GE_CHK_STATUS_RET(GenNodeKey(node, node_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); + node_key = std::to_string(model.GetModelId()) + "/" + node_key; GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str()); task = registry.GetTask(node_key); - GE_CHK_TRUE_EXEC_INFO(task != nullptr, return SUCCESS, "AiCoreNodeExecutor[%s] CompileTask Skip.", - node->GetName().c_str()); + if (task != nullptr) { + GELOGI("AiCoreNodeExecutor(%s) CompileTask Skip.", node->GetName().c_str()); + return SUCCESS; + } std::vector task_defs; - GE_CHK_STATUS_RET_NOLOG(compiler_->CompileOp(node, task_defs)); + GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", node->GetName().c_str()); GELOGD("successfully generated task_defs: %s", node->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(CreateTask(model, task_defs, node, task)); + AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs); + std::unique_ptr node_task; + GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), "[%s] Failed to build op tasks.", node->GetName().c_str()); + task = std::move(node_task); GELOGD("successfully created node task: %s", node->GetName().c_str()); - GE_CHK_BOOL_EXEC(registry.AddTask(node_key, task), return INTERNAL_ERROR, "Add NodeTask failed. op name = %s", - node->GetName().c_str()); // should not happen. 
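// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the patch. The cache key
// used by the task registry above is GenNodeKey()'s "<op_id>/<num_inputs>/..."
// string, now additionally prefixed with the model id in CompileTask(). A
// self-contained equivalent of that keying scheme (the function name is
// hypothetical):
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

std::string MakeNodeKey(uint32_t model_id, int64_t op_id,
                        const std::vector<std::vector<int64_t>> &input_shapes) {
  std::string key = std::to_string(model_id) + "/" + std::to_string(op_id) +
                    "/" + std::to_string(input_shapes.size());
  for (const auto &dims : input_shapes) {
    key.push_back('/');  // a scalar input contributes an empty segment
    for (std::size_t i = 0; i < dims.size(); ++i) {
      if (i > 0) key.push_back(',');
      key.append(std::to_string(dims[i]));
    }
  }
  return key;  // e.g. "1/42/2/8,16/8" for two inputs shaped {8,16} and {8}
}
// Prefixing the key with the model id keeps a task compiled for one loaded
// model from being reused by another model whose op ids happen to collide.
// ---------------------------------------------------------------------------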
- GELOGI("AiCoreNodeExecutor[%s] CompileTask End.", node->GetName().c_str()); - return SUCCESS; -} - -Status AiCoreNodeExecutor::BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, - AiCoreOpTask **task) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(task); - - const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, - "Only TBE kernel is supported, but [%s] got %u", op_desc->GetName().c_str(), - context.kernel_type()); - - auto *aicore_task = new (std::nothrow) AiCoreOpTask(); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_task == nullptr, return MEMALLOC_FAILED, "Create AiCore op task failed."); - - auto builder = AiCoreTaskBuilder(op_desc, kernel_def); - auto ret = builder.BuildTask(*aicore_task); - GE_IF_BOOL_EXEC(ret != SUCCESS, delete aicore_task; aicore_task = nullptr; return ret); - - *task = aicore_task; - return SUCCESS; -} - -Status AiCoreNodeExecutor::CreateTask(const HybridModel &model, const std::vector &task_defs, - const NodePtr &node, std::shared_ptr &task) { - GE_CHECK_NOTNULL(node); - GELOGD("To CreateTask, task def size = %zu", task_defs.size()); - std::vector> aicore_op_tasks; - aicore_op_tasks.reserve(task_defs.size()); - for (size_t i = 0; i < task_defs.size(); ++i) { - const domi::TaskDef &task_def = task_defs[i]; - GELOGD("Op[%s] Task[%d], type = %u, DebugString = %s", node->GetName().c_str(), i, task_def.type(), - task_def.DebugString().c_str()); - auto task_type = static_cast(task_def.type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(task_type == RT_MODEL_TASK_KERNEL_EX, return UNSUPPORTED, - "BuildKernelExTask is not supported"); - GE_CHK_BOOL_TRUE_EXEC_INFO(task_type != RT_MODEL_TASK_KERNEL, continue, "Skip task type %d", - static_cast(task_type)); - - const domi::KernelDef &kernel_def = task_def.kernel(); - AiCoreOpTask *aicore_op_task = nullptr; - // not use hybrid model now - GE_CHK_STATUS_RET_NOLOG(BuildAiCoreTask(kernel_def, node->GetOpDesc(), &aicore_op_task)); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_op_task == nullptr, return FAILED, "BuildAiCoreTask[%s] failed.", - node->GetName().c_str()); - - aicore_op_tasks.emplace_back(std::unique_ptr(aicore_op_task)); + if (!registry.AddTask(node_key, task)) { + GELOGE(INTERNAL_ERROR, "Add NodeTask failed, op name = %s.", node->GetName().c_str()); + return INTERNAL_ERROR; } - if (!aicore_op_tasks.empty()) { - auto aic_task = std::shared_ptr(new AiCoreNodeTask(std::move(aicore_op_tasks))); - task = std::move(aic_task); - GELOGD("Generate AiCoreOpTask success"); - return SUCCESS; - } - - GELOGE(INTERNAL_ERROR, "Failed to build task. 
node = %s", node->GetName().c_str()); - return INTERNAL_ERROR; -} - -Status AiCoreNodeExecutor::Initialize() { - std::shared_ptr ge_lib = GELib::GetInstance(); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((ge_lib == nullptr) || !ge_lib->InitFlag(), return GE_CLI_GE_NOT_INITIALIZED, - "Get ge_lib failed."); - - auto &kernel_manager = ge_lib->OpsKernelManagerObj(); - auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_ops_store == nullptr, return GE_CLI_GE_NOT_INITIALIZED, - "Failed to get kernel info store for AIcoreEngine."); - - compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); - GE_CHECK_NOTNULL(compiler_); + GELOGI("AiCoreNodeExecutor(%s) CompileTask End.", node->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeExecutor::Finalize() { return NodeExecutor::Finalize(); } - Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); - GELOGI("AiCoreNodeTask[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); - for (size_t i = 0; i < tasks_.size(); i++) { - GE_CHECK_NOTNULL(tasks_[i]); - GE_CHK_STATUS_RET_NOLOG(tasks_[i]->LaunchKernel(context.GetStream())); + GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->LaunchKernel(context.GetStream())); } if (done_callback != nullptr) { GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); } - GELOGI("AiCoreNodeTask[%s] ExecuteAsync End.", op_desc->GetName().c_str()); + GELOGD("[%s] ExecuteAsync End.", op_desc->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeTask::UpdateAtomicArgs(TaskContext &context, std::unique_ptr &task) { - GE_CHECK_NOTNULL(task); +Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); - - // refresh atomic output addr - std::vector atomic_output_indexes; // here atomic just clean output - (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); - GE_RETURN_WITH_LOG_IF_TRUE(atomic_output_indexes.size() > static_cast(context.NumOutputs()), - "AtomicAddrClean op's arg_size error."); - auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; - auto *arg_base = reinterpret_cast(arg_off); - int index = 0; - for (size_t i = 0; i < atomic_output_indexes.size(); ++i) { - const auto output = context.GetOutput(atomic_output_indexes[i]); - GE_CHECK_NOTNULL(output); - arg_base[index++] = reinterpret_cast(output->GetData()); - } - - // refresh atomic workspace addr - auto workspace_sizes = op_desc->GetWorkspaceBytes(); - uint64_t ops_workspace_num = static_cast(workspace_sizes.size()); - uint64_t workspace_num = static_cast(context.NumWorkspaces()); - GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, - "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, - workspace_num); - GE_IF_BOOL_EXEC(workspace_num == 0, return SUCCESS); - - map> workspace_info; - workspace_info = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); - if (!workspace_info.empty()) { - bool is_fusion_node = false; - (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_FUSION_NODE, is_fusion_node); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(is_fusion_node, return PARAM_INVALID, - "Atomic desc[%s] shouldn't be fusion_node in AiCoreNodeTask", - op_desc->GetName().c_str()); - - for (auto iter = 
workspace_info.begin(); iter != workspace_info.end(); ++iter) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc->GetName() != iter->first, return PARAM_INVALID, - "The node name %s and the node name %s in workspace info are inconsistent.", - op_desc->GetName().c_str(), iter->first.c_str()); - GE_IF_BOOL_EXEC(iter->second.empty(), continue); - - for (auto &info_iter : iter->second) { - auto workspace_index = static_cast(info_iter.first); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(workspace_index >= workspace_num, return PARAM_INVALID, - "The workspace index %lu is more than the size %lu of workspace vector.", - workspace_index, workspace_num); - - const auto workspace = context.MutableWorkspace(workspace_index); - arg_base[index++] = reinterpret_cast(workspace); - } - } + GELOGI("[%s] AiCoreNodeTask UpdateArgs Start.", op_desc->GetName().c_str()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->UpdateArgs(context)); } - + GELOGI("[%s] AiCoreNodeTask UpdateArgs End.", op_desc->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeTask::UpdateAllArgs(TaskContext &context, std::unique_ptr &task) { - GE_CHECK_NOTNULL(task); - auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; - auto *arg_base = reinterpret_cast(arg_off); - int index = 0; - for (int i = 0; i < context.NumInputs(); ++i) { - const auto input = context.GetInput(i); - GE_CHECK_NOTNULL(input); - arg_base[index++] = reinterpret_cast(input->GetData()); - } - - for (int i = 0; i < context.NumOutputs(); ++i) { - const auto output = context.GetOutput(i); - GE_CHECK_NOTNULL(output); - arg_base[index++] = reinterpret_cast(output->GetData()); - } - - auto op_desc = context.GetNodeItem().op_desc; - GE_CHECK_NOTNULL(op_desc); - auto workspace_sizes = op_desc->GetWorkspaceBytes(); - int ops_workspace_num = static_cast(workspace_sizes.size()); - int workspace_num = static_cast(context.NumWorkspaces()); - GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, - "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, - workspace_num); - for (int i = 0; i < workspace_num; ++i) { - const auto workspace = context.MutableWorkspace(i); - arg_base[index++] = reinterpret_cast(workspace); +Status AiCoreNodeTask::UpdateTilingData(TaskContext &context) { + GELOGD("[%s] PrepareWithShape started", context.GetNodeName()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->PrepareWithShape(context)); } - + GELOGD("[%s] Done PrepareWithShape successfully.", context.GetNodeName()); return SUCCESS; } -Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { - auto op_desc = context.GetNodeItem().op_desc; - GE_CHECK_NOTNULL(op_desc); - GELOGI("AiCoreNodeTask[%s] UpdateArgs Start.", op_desc->GetName().c_str()); - GE_IF_BOOL_EXEC(tasks_.size() == 1, return UpdateAllArgs(context, tasks_[0])); - - std::vector atomic_output_indexes; // here atomic just clean output - (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(atomic_output_indexes.empty(), return FAILED, "ATOMIC_ATTR_OUTPUT_INDEX is empty."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(tasks_.size() != 2, return FAILED, "AtomicAddrClean op task num != 2."); - - GE_CHK_STATUS_RET_NOLOG(UpdateAtomicArgs(context, tasks_[0])); - GE_CHK_STATUS_RET_NOLOG(UpdateAllArgs(context, tasks_[1])); +bool AiCoreNodeTask::IsSupportDynamicShape() { + for (size_t i = 0; i < tasks_.size(); ++i) { + if (!tasks_[i]->IsDynamicShapeSupported()) { + GELOGD("[%s] Task does not support 
dynamic shape.", tasks_[i]->GetName().c_str()); + return false; + } + } - GELOGI("AiCoreNodeTask[%s] UpdateArgs End.", op_desc->GetName().c_str()); - return SUCCESS; + return true; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h index a8b24e68..506202fa 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -25,7 +25,6 @@ namespace ge { namespace hybrid { - class AiCoreNodeTaskRegistry { public: ~AiCoreNodeTaskRegistry() = default; @@ -47,32 +46,27 @@ class AiCoreNodeTaskRegistry { class AiCoreNodeTask : public NodeTask { public: explicit AiCoreNodeTask(std::vector> &&tasks); - ~AiCoreNodeTask() = default; - Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + ~AiCoreNodeTask() override = default; + bool IsSupportDynamicShape() override; + Status UpdateTilingData(TaskContext &context) override; + Status UpdateArgs(TaskContext &context) override; + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; private: - static Status UpdateAllArgs(TaskContext &context, std::unique_ptr &task); - static Status UpdateAtomicArgs(TaskContext &context, std::unique_ptr &task); std::vector> tasks_; }; class AiCoreNodeExecutor : public NodeExecutor { public: Status Initialize() override; - Status Finalize() override; - Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override; private: - static Status CreateTask(const HybridModel &model, const std::vector &task_defs, const NodePtr &node, - std::shared_ptr &task); - static Status BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, AiCoreOpTask **task); static Status GenNodeKey(const NodePtr &node, std::string &node_key); std::unique_ptr compiler_; }; - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 27256e9a..f5a4af83 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -14,19 +14,305 @@ * limitations under the License. 
*/ -#include "aicore_op_task.h" +#include "hybrid/node_executor/aicore/aicore_op_task.h" +#include "cce/taskdown_common.hpp" #include "framework/common/debug/log.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/node_executor/aicore/aicore_task_builder.h" + +using optiling::OpRunInfo; namespace ge { namespace hybrid { +namespace { +constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; +constexpr char const *kAttrOpParamSize = "op_para_size"; +constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; +} // namespace -Status AiCoreOpTask::LaunchKernel(rtStream_t stream) { - GELOGI("AiCoreOpTask LaunchKernel Start (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); +Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET_NOLOG(InitWithTaskDef(op_desc, task_def)); + GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(op_desc)); + return SUCCESS; +} + +Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET(ValidateTaskDef(task_def), "[%s] Failed to validate task def: [%s]", op_desc.GetName().c_str(), + task_def.DebugString().c_str()); + + const domi::KernelDef &kernel_def = task_def.kernel(); + const domi::KernelContext &context = kernel_def.context(); + stub_name_ = kernel_def.stub_func(); + GE_CHK_RT_RET(rtGetFunctionByName(stub_name_.c_str(), &stub_func_)); + args_size_ = kernel_def.args_size(); + block_dim_ = kernel_def.block_dim(); + + // malloc args memory + args_.reset(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_); + errno_t err = memcpy_s(args_.get(), args_size_, kernel_def.args().data(), args_size_); + if (err != EOK) { + GELOGE(INTERNAL_ERROR, "AiCoreTask memcpy args failed."); + return INTERNAL_ERROR; + } + + if (context.args_offset().size() < sizeof(uint16_t)) { + GELOGE(INTERNAL_ERROR, "Invalid args_offset, size = %zu.", context.args_offset().size()); + return INTERNAL_ERROR; + } + + const auto *args_offset_buffer = reinterpret_cast(context.args_offset().data()); + uint32_t offset = *args_offset_buffer; + if (offset > args_size_) { + GELOGE(INTERNAL_ERROR, "[%s] Arg offset out of range. offset = %u, arg size = %u", GetName().c_str(), offset, + args_size_); + return INTERNAL_ERROR; + } + + arg_base_ = reinterpret_cast(args_.get() + offset); + max_arg_count_ = (args_size_ - offset) / sizeof(void *); + GELOGD("[%s] Done setting kernel args successfully. 
stub_func = %s, block_dim = %d, arg base = %p, arg size = %u", + op_desc.GetName().c_str(), stub_name_.c_str(), block_dim_, arg_base_, args_size_); + + return SUCCESS; +} + +Status AiCoreOpTask::ValidateTaskDef(const domi::TaskDef &task_def) { + auto task_type = static_cast(task_def.type()); + if (task_type != RT_MODEL_TASK_KERNEL) { + GELOGE(INTERNAL_ERROR, "Invalid task type (%d) in AiCore CreateTask.", static_cast(task_type)); + return INTERNAL_ERROR; + } + + const domi::KernelDef &kernel_def = task_def.kernel(); + const domi::KernelContext &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type != cce::ccKernelType::TE) { + GELOGE(INTERNAL_ERROR, "Invalid kernel type(%d) in AiCore TaskDef.", static_cast(kernel_type)); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status AiCoreOpTask::PrepareWithShape(TaskContext &context) { + if (tiling_buffer_ != nullptr) { + return UpdateTilingInfo(context); + } + + return SUCCESS; +} + +Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { + auto node = context.GetNodeItem().node; + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + GELOGD("[%s] Start to update tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); + OpRunInfo tiling_info; + tiling_info.block_dim = -1; // initialized explicitly so it is never read uninitialized + + auto execution_context = context.GetExecutionContext(); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); + GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); + + // update op args by tiling info + block_dim_ = static_cast(tiling_info.block_dim); + op_desc->SetWorkspaceBytes(tiling_info.workspaces); + + tiling_data_ = tiling_info.tiling_data.str(); + if (tiling_data_.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Tiling data is empty.", stub_name_.c_str()); + return INTERNAL_ERROR; + } + + if (tiling_data_.size() > tiling_buffer_->GetSize()) { + GELOGE(INTERNAL_ERROR, "[%s] Tiling data size (%zu) must not exceed the pre-allocated buffer size (%zu).", + stub_name_.c_str(), tiling_data_.size(), tiling_buffer_->GetSize()); + return INTERNAL_ERROR; + } + + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] Start"); + GE_CHK_RT_RET(rtMemcpy(tiling_buffer_->GetData(), tiling_buffer_->GetSize(), tiling_data_.c_str(), + tiling_data_.size(), RT_MEMCPY_HOST_TO_DEVICE)); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] End"); + + GELOGD("[%s] Done updating tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); + return SUCCESS; +} + +Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { + GELOGD("[%s] Start to invoke OpParaCalculate.", node->GetName().c_str()); + GE_CHK_STATUS_RET(OpParaCalculate(*node, tiling_info), "Failed to calc tiling data of node %s.", + node->GetName().c_str()); + GELOGD("[%s] Done invoking OpParaCalculate successfully.", node->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { + size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces(); + if (tiling_buffer_ != nullptr) { + ++expected_arg_count; + } + if (expected_arg_count > max_arg_count_) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid arg memory, max arg count = %u, but expect = %zu", GetName().c_str(), + 
max_arg_count_, expected_arg_count); + return INTERNAL_ERROR; + } + int index = 0; + for (int i = 0; i < task_context.NumInputs(); ++i) { + const auto input = task_context.GetInput(i); + GE_CHECK_NOTNULL(input); + arg_base_[index++] = reinterpret_cast(input->GetData()); + } + + for (int i = 0; i < task_context.NumOutputs(); ++i) { + const auto output = task_context.GetOutput(i); + GE_CHECK_NOTNULL(output); + arg_base_[index++] = reinterpret_cast(output->GetData()); + } + + int workspace_num = static_cast(task_context.NumWorkspaces()); + for (int i = 0; i < workspace_num; ++i) { + const auto workspace = task_context.MutableWorkspace(i); + GE_CHECK_NOTNULL(workspace); + arg_base_[index++] = reinterpret_cast(workspace); + } + + if (tiling_buffer_ != nullptr) { + arg_base_[index++] = reinterpret_cast(tiling_buffer_->GetData()); + } + + if (task_context.IsTraceEnabled()) { + for (int i = 0; i < index; ++i) { + GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]); + } + } + + return SUCCESS; +} + +Status AiCoreOpTask::LaunchKernel(rtStream_t stream) { + GELOGD("AiCoreOpTask LaunchKernel Start (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), args_size_, nullptr, stream)); - GELOGI("AiCoreOpTask LaunchKernel End (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); + GELOGD("AiCoreOpTask LaunchKernel End (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); return SUCCESS; } +Status AiCoreOpTask::InitTilingInfo(const OpDesc &op_desc) { + bool dynamic_supported = false; + (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, dynamic_supported); + if (!dynamic_supported) { + GELOGD("[%s] Dynamic shape is not supported.", op_desc.GetName().c_str()); + return SUCCESS; + } + + GELOGD("Start alloc tiling data of node %s.", op_desc.GetName().c_str()); + int64_t max_size = -1; + (void)AttrUtils::GetInt(op_desc, GetKeyForOpParamSize(), max_size); + GELOGD("Got op param size by key: %s, ret = %ld", GetKeyForOpParamSize().c_str(), max_size); + if (max_size <= 0) { + GELOGE(PARAM_INVALID, "[%s] Invalid op_param_size: %ld.", op_desc.GetName().c_str(), max_size); + return PARAM_INVALID; + } + + auto allocator = NpuMemoryAllocator::GetAllocator(); + GE_CHECK_NOTNULL(allocator); + tiling_buffer_ = TensorBuffer::Create(allocator, static_cast(max_size)); + GE_CHECK_NOTNULL(tiling_buffer_); + + GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc.GetName().c_str(), max_size); + return SUCCESS; +} + +bool AiCoreOpTask::IsDynamicShapeSupported() { return tiling_buffer_ != nullptr; } + +const std::string &AiCoreOpTask::GetName() const { return stub_name_; } + +std::string AiCoreOpTask::GetKeyForOpParamSize() const { return kAttrOpParamSize; } + +Status AtomicAddrCleanOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::Init(op_desc, task_def)); + return InitAtomicAddrCleanIndices(op_desc); +} + +Status AtomicAddrCleanOpTask::InitAtomicAddrCleanIndices(const OpDesc &op_desc) { + GELOGD("[%s] Start to setup AtomicAddrClean task.", op_desc.GetName().c_str()); + std::vector atomic_output_indices; + (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + map> workspace_info; // op_name, ws_index, ws_offset + workspace_info = op_desc.TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); + if (atomic_output_indices.empty() && workspace_info.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Neither 
ATOMIC_ATTR_OUTPUT_INDEX nor EXT_ATTR_ATOMIC_WORKSPACE_INFO is set.", + op_desc.GetName().c_str()); + return INTERNAL_ERROR; + } + + for (auto output_index : atomic_output_indices) { + GELOGD("[%s] Adding output index [%ld]", op_desc.GetName().c_str(), output_index); + GE_CHECK_GE(output_index, 0); + GE_CHECK_LE(output_index, INT32_MAX); + atomic_output_indices_.emplace_back(static_cast(output_index)); + } + + for (auto &iter : workspace_info) { + for (auto &info_iter : iter.second) { + auto workspace_index = info_iter.first; + GELOGD("[%s] Adding workspace index [%ld]", op_desc.GetName().c_str(), workspace_index); + GE_CHECK_GE(workspace_index, 0); + GE_CHECK_LE(workspace_index, INT32_MAX); + atomic_workspace_indices_.emplace_back(static_cast(workspace_index)); + } + } + + size_t arg_count = atomic_workspace_indices_.size() + atomic_output_indices_.size(); + if (tiling_buffer_ != nullptr) { + arg_count += 1; + } + + if (arg_count > max_arg_count_) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid arg memory, max arg count = %u, but expect = %zu", GetName().c_str(), + max_arg_count_, arg_count); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +std::string AtomicAddrCleanOpTask::GetKeyForOpParamSize() const { return kAttrAtomicOpParamSize; } + +Status AtomicAddrCleanOpTask::UpdateArgs(TaskContext &task_context) { + // refresh atomic output addr + int index = 0; + for (auto atomic_output_index : atomic_output_indices_) { + const auto output_tensor = task_context.GetOutput(atomic_output_index); + GE_CHECK_NOTNULL(output_tensor); + arg_base_[index++] = reinterpret_cast(output_tensor->GetData()); + } + + // refresh atomic workspace addr + for (auto atomic_ws_index : atomic_workspace_indices_) { + const auto workspace_tensor = task_context.GetOutput(atomic_ws_index); + GE_CHECK_NOTNULL(workspace_tensor); + arg_base_[index++] = reinterpret_cast(workspace_tensor->GetData()); + } + + if (tiling_buffer_ != nullptr) { + arg_base_[index++] = reinterpret_cast(tiling_buffer_->GetData()); + } else { + GELOGD("[%s] Not a dynamic op", GetName().c_str()); + } + + if (task_context.IsTraceEnabled()) { + for (int i = 0; i < index; ++i) { + GELOGD("[%s] Arg[%d] = %lu", GetName().c_str(), i, arg_base_[i]); + } + } + + return SUCCESS; +} } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h index d23688a5..74876588 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -18,27 +18,69 @@ #define GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ #include +#include #include "common/ge_inner_error_codes.h" #include "runtime/stream.h" +#include "hybrid/common/tensor_value.h" +#include "hybrid/node_executor/task_context.h" +#include "proto/task.pb.h" +#include "register/op_tiling.h" + namespace ge { namespace hybrid { class AiCoreOpTask { public: AiCoreOpTask() = default; - ~AiCoreOpTask() = default; + virtual ~AiCoreOpTask() = default; + + virtual Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def); + + bool IsDynamicShapeSupported(); + + // do preparation with shape(without actual io memory) + Status PrepareWithShape(TaskContext &context); + + virtual Status UpdateArgs(TaskContext &task_context); + Status LaunchKernel(rtStream_t stream); + const std::string &GetName() const; + + protected: + Status UpdateTilingInfo(TaskContext &context); + virtual std::string GetKeyForOpParamSize() const; + virtual 
Status CalcTilingInfo(const NodePtr &node, optiling::OpRunInfo &tiling_info); + + std::unique_ptr tiling_buffer_ = nullptr; + std::string tiling_data_; + uintptr_t *arg_base_ = nullptr; + uint32_t max_arg_count_ = 0; + private: - friend class AiCoreTaskBuilder; - friend class AiCoreNodeTask; + static Status ValidateTaskDef(const domi::TaskDef &task_def); + Status InitWithTaskDef(const OpDesc &node, const domi::TaskDef &task_def); + Status InitTilingInfo(const OpDesc &op_desc); + std::string stub_name_; void *stub_func_ = nullptr; std::unique_ptr args_ = nullptr; uint32_t args_size_ = 0; uint32_t block_dim_ = 1; - uint16_t offset_ = 0; }; +class AtomicAddrCleanOpTask : public AiCoreOpTask { + public: + Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; + Status UpdateArgs(TaskContext &task_context) override; + + protected: + std::string GetKeyForOpParamSize() const override; + + private: + Status InitAtomicAddrCleanIndices(const OpDesc &op_desc); + std::vector atomic_output_indices_; + std::vector atomic_workspace_indices_; +}; } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc index 5b263007..bad91806 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc @@ -15,76 +15,78 @@ */ #include "aicore_task_builder.h" -#include -#include "graph/op_desc.h" -#include "cce/taskdown_common.hpp" -#include "framework/common/debug/log.h" -#include "graph/debug/ge_attr_define.h" +#include "common/debug/log.h" +#include "aicore_node_executor.h" namespace ge { namespace hybrid { -std::mutex g_reg_mutex; - -AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) - : op_desc_(op_desc), kernel_def_(kernel_def) { - std::string session_graph_id; - GE_IF_BOOL_EXEC(AttrUtils::GetStr(*op_desc_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), - GELOGD("Get original type of session_graph_id.")); - // get bin_file_key - stub_name_ = (session_graph_id.empty()) ? 
op_desc_->GetName() : session_graph_id + "_" + op_desc_->GetName(); -} - -Status AiCoreTaskBuilder::SetKernelArgs(AiCoreOpTask &task) { - const domi::KernelContext &context = kernel_def_.context(); - // get kernel_type - auto kernel_type = static_cast(context.kernel_type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, - "Invalid kernel type[%d] in AiCore TaskDef.", static_cast(kernel_type)); - - task.args_size_ = kernel_def_.args_size(); - task.block_dim_ = kernel_def_.block_dim(); - - // malloc args memory - task.args_.reset(new (std::nothrow) uint8_t[task.args_size_]); - // task.args_ = std::make_unique(task.args_size_); - GE_CHECK_NOTNULL(task.args_); - errno_t err = memcpy_s(task.args_.get(), task.args_size_, kernel_def_.args().data(), task.args_size_); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(err != EOK, return INTERNAL_ERROR, "AiCoreTask memcpy failed."); - - const auto *args_offset_tmp = reinterpret_cast(const_cast(context.args_offset().data())); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(context.args_offset().size() / sizeof(uint16_t) < 1, return FAILED, - "context.args_offset().size() / sizeof(uint16_t) less than 1"); - task.offset_ = *args_offset_tmp; - return SUCCESS +namespace { +const size_t kNumTaskWithAtomicAddrCleanTask = 2; } - const char *AiCoreKernelRegistry::GetUnique(const string &stub_key) { std::lock_guard lock(mutex_); auto it = unique_stubs_.find(stub_key); - GE_IF_BOOL_EXEC(it != unique_stubs_.end(), return it->c_str()); + if (it != unique_stubs_.end()) { + return it->c_str(); + } it = unique_stubs_.insert(unique_stubs_.end(), stub_key); return it->c_str(); } -Status AiCoreTaskBuilder::SetStub(AiCoreOpTask &task) { - AiCoreKernelRegistry &registry = AiCoreKernelRegistry::GetInstance(); - std::lock_guard lock(g_reg_mutex); - const char *unique_key = registry.GetUnique(stub_name_); +AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector &task_defs) + : op_desc_(op_desc), task_defs_(task_defs) {} - GE_CHK_RT_RET(rtGetFunctionByName(unique_key, &(task.stub_func_))); - task.stub_name_ = stub_name_; +Status AiCoreTaskBuilder::BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic) { + GE_CHECK_NOTNULL(op_desc_); + if (task_defs_.size() > kNumTaskWithAtomicAddrCleanTask) { + GELOGE(INTERNAL_ERROR, "[%s] At most 2 tasks are supported, but got %zu", op_desc_->GetName().c_str(), + task_defs_.size()); + return INTERNAL_ERROR; + } - return SUCCESS; -} + std::vector> op_tasks; + if (ExpectAtomicAddrCleanTask()) { + if (task_defs_.size() != kNumTaskWithAtomicAddrCleanTask) { + if (ignore_failure_on_atomic) { + GELOGI("[%s] AtomicAddrClean task was expected, but got %zu task_defs", op_desc_->GetName().c_str(), + task_defs_.size()); + return SUCCESS; + } else { + GELOGE(INTERNAL_ERROR, "[%s] AtomicAddrClean task was expected, but got %zu task_defs", + op_desc_->GetName().c_str(), task_defs_.size()); + return INTERNAL_ERROR; + } + } -Status AiCoreTaskBuilder::BuildTask(AiCoreOpTask &task) { - GE_CHECK_NOTNULL(op_desc_); - GELOGI("AiCoreTaskBuilder[%s] BuildTask Start.", op_desc_->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(SetKernelArgs(task)); - GE_CHK_STATUS_RET_NOLOG(SetStub(task)); - GELOGI("AiCoreTaskBuilder[%s] BuildTask End.", op_desc_->GetName().c_str()); + GELOGD("[%s] Build AtomicAddrClean task.", op_desc_->GetName().c_str()); + auto atomic_task = std::unique_ptr(new (std::nothrow) AtomicAddrCleanOpTask()); + GE_CHECK_NOTNULL(atomic_task); + GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, 
task_defs_.front()), "[%s] Failed to init task for AtomicAddrClean", + op_desc_->GetName().c_str()); + op_tasks.emplace_back(std::move(atomic_task)); + } + + // build aicore task + auto aicore_task = std::unique_ptr(new (std::nothrow) AiCoreOpTask()); + GE_CHECK_NOTNULL(aicore_task); + GE_CHK_STATUS_RET(aicore_task->Init(*op_desc_, task_defs_.back()), "[%s] Failed to init task for AtomicAddrClean", + op_desc_->GetName().c_str()); + op_tasks.emplace_back(std::move(aicore_task)); + + node_task.reset(new (std::nothrow) AiCoreNodeTask(std::move(op_tasks))); + GE_CHECK_NOTNULL(node_task); return SUCCESS; } +bool AiCoreTaskBuilder::ExpectAtomicAddrCleanTask() { + if (op_desc_->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX)) { + GELOGD("[%s] Node has ATOMIC_ATTR_OUTPUT_INDEX", op_desc_->GetName().c_str()); + return true; + } + map> workspace_info; + workspace_info = op_desc_->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); + + return !workspace_info.empty(); +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h index 18cb309c..4610e57a 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h @@ -17,14 +17,13 @@ #ifndef GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ #define GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ -#include +#include #include -#include -#include #include "aicore_op_task.h" -#include "proto/task.pb.h" +#include "framework/common/debug/ge_log.h" #include "graph/utils/attr_utils.h" #include "graph/op_kernel_bin.h" +#include "proto/task.pb.h" namespace ge { namespace hybrid { @@ -45,16 +44,16 @@ class AiCoreKernelRegistry { class AiCoreTaskBuilder { public: - AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); + AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector &task_defs); ~AiCoreTaskBuilder() = default; - Status BuildTask(AiCoreOpTask &task); + + Status BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic); private: - Status SetKernelArgs(AiCoreOpTask &task); - Status SetStub(AiCoreOpTask &task); - const OpDescPtr &op_desc_; - const domi::KernelDef &kernel_def_; - std::string stub_name_; + bool ExpectAtomicAddrCleanTask(); + + OpDescPtr op_desc_; + const std::vector &task_defs_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc index ac89afbd..9119bebb 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -34,7 +34,6 @@ Status AiCoreTaskCompiler::DoCompileOp(OpsKernelInfoStore &ops_store, const Node GE_CHECK_NOTNULL(node); vector node_vec; node_vec.emplace_back(node); - std::lock_guard lk(mu_); GE_CHK_STATUS_RET(ops_store.CompileOpRun(node_vec), "Failed to execute CompileOp, node = %s", node->GetName().c_str()); GE_CHK_STATUS_RET(ops_store.CalcOpRunningParam(*node), "Failed to execute CalcOpRunningParam, node = %s", @@ -44,9 +43,8 @@ Status AiCoreTaskCompiler::DoCompileOp(OpsKernelInfoStore &ops_store, const Node Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vector &tasks) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreTaskCompiler[%s] CompileOp Start.", node->GetName().c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_kernel_store_ == nullptr, return FAILED, - "Failed to get AiCore kernel store, node = %s", 
node->GetName().c_str()); + GELOGI("AiCoreTaskCompiler(%s) CompileOp Start.", node->GetName().c_str()); + GE_CHECK_NOTNULL(aic_kernel_store_); GE_CHK_STATUS_RET_NOLOG(DoCompileOp(*aic_kernel_store_, node)); GELOGD("successfully compiled op: %s", node->GetName().c_str()); @@ -58,7 +56,7 @@ Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vectorSetOutputOffset(output_offsets); GE_CHK_STATUS_RET_NOLOG(DoGenerateTask(*aic_kernel_store_, *node, tasks)); GELOGD("successfully generated task: %s", node->GetName().c_str()); - GELOGI("AiCoreTaskCompiler[%s] CompileOp End.", node->GetName().c_str()); + GELOGI("AiCoreTaskCompiler(%s) CompileOp End.", node->GetName().c_str()); return SUCCESS; } @@ -91,6 +89,5 @@ Status AiCoreTaskCompiler::DoGenerateTask(OpsKernelInfoStore &store, const Node GE_CHK_RT(rtModelDestroy(rt_model_)); return ret; } - } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc index d5c3c03c..332675bf 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc @@ -199,6 +199,5 @@ void AicpuExtInfoHandler::GetShapeAndType(const AicpuShapeAndType *shape_and_typ data_type = static_cast(shape_and_type->type); shape = std::move(GeShape(dims)); } - } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h index e96d794c..a42678b1 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h @@ -24,7 +24,6 @@ namespace ge { namespace hybrid { - using AicpuShapeAndType = aicpu::FWKAdapter::ShapeAndType; using AicpuExtInfo = aicpu::FWKAdapter::ExtInfo; diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index 372f35f5..46d9a0aa 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -40,19 +40,28 @@ Status AicpuNodeTaskBase::AllocTensorBuffer(size_t size, std::unique_ptris_dynamic) { + // dynamic node must have ext info + GE_CHK_STATUS_RET(aicpu_ext_handle_.Parse(kernel_ext_info), + "Node[%s] parse kernel ext info failed, kernel_ext_info_size=%zu.", node_name_.c_str(), + kernel_ext_info.size()); + } + + // if no ext info no need copy to device. + if (kernel_ext_info.empty()) { + GELOGI("Node[%s] kernel_ext_info is empty, no need copy to device, is_dynamic=%s.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false"); + return SUCCESS; + } // copy task args buf GE_CHK_STATUS_RET(AllocTensorBuffer(kernel_ext_info.size(), ext_info_addr_dev_), "Node[%s] alloc kernel_ext_info buf failed, size=%zu", node_name_.c_str(), kernel_ext_info.size()); - // if no input and no output(DEPEND_COMPUTE equal no output), copy once, or else copy when update args. 
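// ---------------------------------------------------------------------------
// [Editor's note] A minimal sketch of the ext-info policy this hunk puts in
// place: a dynamic node must carry ext info, a node without ext info skips
// the device copy entirely, and the compile-time defaults are otherwise
// copied to the device once at init time. All names here are hypothetical
// stand-ins; the real code parses via AicpuExtInfoHandler and copies with
// rtMemcpy into a TensorBuffer.
#include <string>
#include <vector>

enum SketchStatus { kSketchSuccess = 0, kSketchFailed = 1 };

SketchStatus InitExtInfoSketch(bool is_dynamic, const std::string &ext_info,
                               std::vector<char> &device_buffer) {
  if (is_dynamic && ext_info.empty()) {
    return kSketchFailed;  // dynamic shapes cannot be resolved without ext info
  }
  if (ext_info.empty()) {
    return kSketchSuccess;  // nothing to parse and nothing to copy to device
  }
  device_buffer.assign(ext_info.begin(), ext_info.end());  // copy defaults once
  return kSketchSuccess;
}
// Per-launch work (UpdateArgs below) then refreshes the buffer only for
// dynamic nodes, instead of re-deciding the copy on every launch.
// ---------------------------------------------------------------------------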
- if (node_item_->num_inputs == 0 && ((unknown_type_ == DEPEND_COMPUTE) || (node_item_->num_outputs == 0))) { - GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_->GetData(), ext_info_addr_dev_->GetSize(), kernel_ext_info.data(), - kernel_ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE)); - } + // copy default ext info to device + GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_->GetData(), ext_info_addr_dev_->GetSize(), kernel_ext_info.data(), + kernel_ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE)); + return SUCCESS; } @@ -139,16 +148,18 @@ Status AicpuNodeTaskBase::UpdateExtInfo() { } Status AicpuNodeTaskBase::UpdateArgs(TaskContext &context) { - GELOGI("Node[%s] update args begin. unknown_type=%d", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] update args begin. is_dynamic=%s, unknown_type=%d", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); if (node_item_->num_inputs == 0 && node_item_->num_outputs == 0) { GELOGI("Node[%s] has no input and output, no need update args.", node_name_.c_str()); return SUCCESS; } GE_CHK_STATUS_RET(UpdateIoAddr(context), "Node[%s] update io addr failed.", node_name_.c_str()); - - GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); - + if (node_item_->is_dynamic) { + // dynamic node need update ext info. + GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); + } GELOGI("Node[%s] update args end.", node_name_.c_str()); return SUCCESS; } @@ -275,9 +286,12 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = reinterpret_cast(kernel_workspace_->GetData()); fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = reinterpret_cast(input_output_addr_->GetData()); - // set ext info addr and ext info num - fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); - fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info_addr_dev_->GetSize(); + + if (ext_info_addr_dev_ != nullptr) { + // set ext info addr and ext info num + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info_addr_dev_->GetSize(); + } fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = GetStepIdAddr(model); @@ -506,7 +520,8 @@ Status AicpuTfNodeTask::UpdateIoAddr(TaskContext &context) { io_addrs.emplace_back(reinterpret_cast(inputData->GetData())); } - if (unknown_type_ != DEPEND_COMPUTE) { + // known shape or not depend compute + if (!node_item_->is_dynamic || unknown_type_ != DEPEND_COMPUTE) { // unknown type 4 do this in call back. GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); for (auto j = 0; j < node_item_->num_outputs; ++j) { @@ -548,14 +563,17 @@ Status AicpuTfNodeTask::LaunchTask(TaskContext &context) { } Status AicpuTfNodeTask::TaskCallback(TaskContext &context) { - GELOGI("Node[%s] task callback start. unknown_type=%d.", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] task callback start. is_dynamic=%s, unknown_type=%d.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); Status callback_ret = SUCCESS; - // check need update shape, call update shape. 
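// ---------------------------------------------------------------------------
// [Editor's note] The callback gating introduced above, distilled into a
// self-contained sketch. The enum and the two callables are stand-ins for
// node_item_->is_dynamic / unknown_type_ and the real update routines:
#include <functional>

enum class ShapeDepSketch { kInShape, kShapeRange, kCompute };

int TaskCallbackSketch(bool is_dynamic, ShapeDepSketch dep,
                       const std::function<int()> &update_from_ext_info,
                       const std::function<int()> &update_from_result_summary) {
  if (!is_dynamic) {
    return 0;  // static shape: outputs were fully sized before launch
  }
  switch (dep) {
    case ShapeDepSketch::kShapeRange:
      return update_from_ext_info();  // device wrote shapes back via ext info
    case ShapeDepSketch::kCompute:
      return update_from_result_summary();  // shapes/data via result summary
    default:
      return 0;  // shape known from inputs: nothing to refresh after the run
  }
}
// ---------------------------------------------------------------------------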
- if (unknown_type_ == DEPEND_SHAPE_RANGE) { - // check result - callback_ret = UpdateOutputShapeFromExtInfo(); - } else if (unknown_type_ == DEPEND_COMPUTE) { - callback_ret = UpdateShapeAndDataByResultSummary(context); + if (node_item_->is_dynamic) { + // check need update shape, call update shape. + if (unknown_type_ == DEPEND_SHAPE_RANGE) { + // check result + callback_ret = UpdateOutputShapeFromExtInfo(); + } else if (unknown_type_ == DEPEND_COMPUTE) { + callback_ret = UpdateShapeAndDataByResultSummary(context); + } } GELOGI("Node[%s] task callback end.", node_name_.c_str()); return callback_ret; @@ -612,8 +630,13 @@ Status AicpuNodeTask::Init(const HybridModel &model) { GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info), "Node[%s] init ext info failed.", node_name.c_str()); - aicpu_param_head->extInfoLength = ext_info_addr_dev_->GetSize(); - aicpu_param_head->extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + if (ext_info_addr_dev_ == nullptr) { + aicpu_param_head->extInfoLength = 0; + aicpu_param_head->extInfoAddr = 0; + } else { + aicpu_param_head->extInfoLength = ext_info_addr_dev_->GetSize(); + aicpu_param_head->extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + } GELOGI("Node[%s] init end.", node_name.c_str()); return SUCCESS; @@ -664,10 +687,12 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { } Status AicpuNodeTask::TaskCallback(TaskContext &context) { - GELOGI("Node[%s] task callback start, unknown_type=%d.", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] task callback start, is_dynamic = %s, unknown_type=%d.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); Status callback_ret = SUCCESS; + // check need update shape, call update shape. - if (unknown_type_ == DEPEND_SHAPE_RANGE) { + if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { // check result callback_ret = UpdateOutputShapeFromExtInfo(); } else { diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index ce3f9707..8aca6ff7 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -24,7 +24,6 @@ namespace ge { namespace hybrid { - class AicpuNodeTaskBase : public NodeTask { public: AicpuNodeTaskBase(const NodeItem *node_item, const domi::TaskDef &task_def) @@ -70,8 +69,10 @@ class AicpuNodeTaskBase : public NodeTask { const std::string node_type_; + // valid when node_item_->is_dynamic is true UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE; + // valid when node_item_->is_dynamic is true AicpuExtInfoHandler aicpu_ext_handle_; // ext info addr, device mem @@ -169,7 +170,6 @@ class AiCpuNodeExecutor : public NodeExecutor { Status PrepareTask(NodeTask &task, TaskContext &context) const override; }; - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICPU_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index 81960c48..afa53724 100644 --- a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -26,7 +26,6 @@ namespace ge { namespace hybrid { - REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor); Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { @@ -98,8 +97,11 @@ Status 
KnownNodeTask::Init(TaskContext &context) { GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed."); // init davinicmodel - davinci_model_->InitRuntimeParams(); - GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); + if (!load_flag_) { + davinci_model_->InitRuntimeParams(); + GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); + } + // allocate mem base void *buffer = nullptr; if (davinci_model_->TotalMemSize() != 0) { @@ -161,6 +163,5 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, context.GetNodeItem().NodeName().c_str()); return SUCCESS; } - } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.cc b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc new file mode 100644 index 00000000..1f18db3d --- /dev/null +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -0,0 +1,318 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "control_op_executor.h" +#include "graph/utils/node_utils.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_executor.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor); + +Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context, + const std::function &done_callback) { + GELOGD("[%s] Start to execute subgraph.", subgraph->GetName().c_str()); + auto execution_context = const_cast(task_context.GetExecutionContext()); + auto executor = MakeShared(subgraph, execution_context, task_context.IsForceInferShape()); + GE_CHECK_NOTNULL(executor); + GE_CHK_STATUS_RET(executor->ExecuteAsync(task_context), "[%s] Failed to execute partitioned call.", + subgraph->GetName().c_str()); + + auto callback = [executor, done_callback]() mutable { + if (done_callback != nullptr) { + done_callback(); + } + // executor must outlive task context + executor.reset(); + }; + + GE_CHK_STATUS_RET_NOLOG(task_context.RegisterCallback(callback)); + GELOGD("[%s] Done executing subgraph successfully.", subgraph->GetName().c_str()); + return SUCCESS; +} + +Status ControlOpNodeTask::CopyTensorValueToHost(const TensorValue &tensor, int32_t &value) { + GE_CHECK_NOTNULL(tensor.GetData()); + GE_CHECK_GE(tensor.GetSize(), sizeof(value)); + GE_CHK_RT_RET(rtMemcpy(&value, sizeof(value), tensor.GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST)); + return SUCCESS; +} + +Status ControlOpNodeTask::UpdateArgs(TaskContext &context) { + // do nothing + return SUCCESS; +} + +Status ControlOpNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { + auto ret = DoExecuteAsync(task_context, done_callback); + task_context.SetStatus(ret); + + if (done_callback) { + done_callback(); + } + + return ret; +} + +Status IfOpNodeTask::Init(const 
NodePtr &node, const HybridModel &model) { + GELOGD("[%s] Start to init IfOpNodeTask.", node->GetName().c_str()); + auto then_subgraph = NodeUtils::GetSubgraph(*node, kThenBranchIndex); + GE_CHECK_NOTNULL(then_subgraph); + GELOGD("[%s] Adding subgraph [%s] to then-subgraph.", node->GetName().c_str(), then_subgraph->GetName().c_str()); + then_ = model.GetSubgraphItem(then_subgraph); + GE_CHECK_NOTNULL(then_); + + auto else_subgraph = NodeUtils::GetSubgraph(*node, kElseBranchIndex); + GE_CHECK_NOTNULL(else_subgraph); + GELOGD("[%s] Adding subgraph [%s] to else-subgraph.", node->GetName().c_str(), else_subgraph->GetName().c_str()); + else_ = model.GetSubgraphItem(else_subgraph); + GE_CHECK_NOTNULL(else_); + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +const GraphItem *IfOpNodeTask::SelectBranch(int32_t cond) const { return cond != 0 ? then_ : else_; } + +Status IfOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + auto cond_tensor = task_context.GetInput(kIfCondIndex); + GE_CHECK_NOTNULL(cond_tensor); + int32_t cond_val = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(*cond_tensor, cond_val), "[%s] Failed to get cond value.", + task_context.GetNodeName()); + + auto subgraph = SelectBranch(cond_val); + GELOGD("[%s] Taking subgraph [%s] by cond = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), cond_val); + GE_CHK_STATUS_RET(ExecuteSubgraph(subgraph, task_context, done_callback), + "[%s] Failed to execute subgraph. cond = %d", task_context.GetNodeName(), cond_val); + + GELOGD("[%s] Done executing with cond = %d successfully.", task_context.GetNodeName(), cond_val); + return SUCCESS; +} + +Status CaseOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { + size_t num_subgraphs = node->GetOpDesc()->GetSubgraphInstanceNames().size(); + GE_CHECK_LE(num_subgraphs, kMaxBranchNum); + GE_CHECK_GE(num_subgraphs, kMinBranchNum); + auto num_branches = static_cast(num_subgraphs); + GELOGD("[%s] Start to init CaseOpNodeTask with %u branches.", node->GetName().c_str(), num_branches); + + for (uint32_t i = 0; i < num_branches; ++i) { + auto sub_graph = NodeUtils::GetSubgraph(*node, i); + GE_CHECK_NOTNULL(sub_graph); + auto graph_item = model.GetSubgraphItem(sub_graph); + GE_CHECK_NOTNULL(graph_item); + GELOGD("[%s] Adding subgraph [%s] to branch %u.", node->GetName().c_str(), sub_graph->GetName().c_str(), i); + subgraphs_.emplace_back(graph_item); + } + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +const GraphItem *CaseOpNodeTask::SelectBranch(int32_t branch_index) const { + // subgraphs_ is non-empty; checked in Init + if (branch_index < 0 || static_cast(branch_index) >= subgraphs_.size()) { + GELOGI("Branch index out of range. 
index = %d, num_subgraphs = %zu, will take the last branch.", branch_index, + subgraphs_.size()); + branch_index = subgraphs_.size() - 1; + } + + return subgraphs_[branch_index]; +} + +Status CaseOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + auto branch_tensor = task_context.GetInput(kCaseBranchIndex); + GE_CHECK_NOTNULL(branch_tensor); + int32_t branch_index = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(*branch_tensor, branch_index), "[%s] Failed to get branch index.", + task_context.GetNodeName()); + + const GraphItem *subgraph = SelectBranch(branch_index); + GELOGI("[%s] Taking subgraph [%s] by branch = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), + branch_index); + + std::vector inputs; + std::vector outputs; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + inputs.emplace_back(*input_tensor); + } + + GE_CHK_STATUS_RET(ExecuteSubgraph(subgraph, task_context, done_callback), "[%s] Failed to execute subgraph.", + task_context.GetNodeName()); + + GELOGD("[%s] Done executing subgraph[%d] successfully.", task_context.GetNodeName(), branch_index); + return SUCCESS; +} + +Status WhileOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { + GELOGD("[%s] Start to init WhileOpNodeTask.", node->GetName().c_str()); + auto cond_subgraph = NodeUtils::GetSubgraph(*node, kCondBranchIndex); + GE_CHECK_NOTNULL(cond_subgraph); + GELOGD("[%s] Adding subgraph [%s] to cond-subgraph.", node->GetName().c_str(), cond_subgraph->GetName().c_str()); + cond_ = model.GetSubgraphItem(cond_subgraph); + GE_CHECK_NOTNULL(cond_); + + auto body_subgraph = NodeUtils::GetSubgraph(*node, kBodyBranchIndex); + GE_CHECK_NOTNULL(body_subgraph); + GELOGD("[%s] Adding subgraph [%s] to body-subgraph.", node->GetName().c_str(), body_subgraph->GetName().c_str()); + body_ = model.GetSubgraphItem(body_subgraph); + GE_CHECK_NOTNULL(body_); + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + if (task_context.NumInputs() != task_context.NumOutputs()) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid while args. num_inputs = %d, num_outputs = %d", task_context.GetNodeName(), + task_context.NumInputs(), task_context.NumOutputs()); + return INTERNAL_ERROR; + } + + // graph build cannot set an accurate unknown_shape_status flag yet. + // Treating all nodes in while scope as unknown shape. + task_context.SetForceInferShape(true); + + int iteration = 0; + while (true) { + bool is_continue = false; + GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); + GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), "[%s] Failed to execute iteration %d.", + task_context.GetNodeName(), iteration); + + if (!is_continue) { + GELOGD("[%s] Quit from loop. 
current iteration = %d", task_context.GetNodeName(), iteration); + break; + } + + ++iteration; + } + + return SUCCESS; +} + +Status WhileOpNodeTask::ExecuteCond(TaskContext &task_context, bool &is_continue) const { + std::vector inputs; + std::vector input_desc; + std::vector output_desc; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + inputs.emplace_back(*input_tensor); + input_desc.emplace_back(task_context.GetInputDesc(i)); + } + + auto execution_context = const_cast(task_context.GetExecutionContext()); + auto executor = MakeShared(cond_, execution_context, task_context.IsForceInferShape()); + GE_CHECK_NOTNULL(executor); + GELOGD("[%s] Start to execute cond-subgraph.", task_context.GetNodeName()); + GE_CHK_STATUS_RET(executor->ExecuteAsync(inputs, input_desc), "Failed to execute partitioned call."); + GELOGD("[%s] Done executing cond-subgraph successfully.", cond_->GetName().c_str()); + GE_CHK_STATUS_RET_NOLOG(task_context.RegisterCallback([executor]() mutable { executor.reset(); })); + + // get cond output + GE_CHK_STATUS_RET(executor->Synchronize(), "[%s] Failed to sync cond-subgraph result.", cond_->GetName().c_str()); + std::vector cond_outputs; + GE_CHK_STATUS_RET(executor->GetOutputs(cond_outputs), "[%s] Failed to get cond-output.", cond_->GetName().c_str()); + if (cond_outputs.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Cond output is empty.", task_context.GetNodeName()); + return INTERNAL_ERROR; + } + + int cond_val = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(cond_outputs[0], cond_val), "[%s] Failed to get cond result.", + task_context.GetNodeName()); + is_continue = cond_val != 0; + return SUCCESS; +} + +Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) { + // set outputs to inputs for next iteration + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.MutableInput(i); + auto output_tensor = task_context.MutableOutput(i); + GE_CHECK_NOTNULL(input_tensor); + GE_CHECK_NOTNULL(output_tensor); + *input_tensor = *output_tensor; + output_tensor->Destroy(); + + auto output_tensor_desc = task_context.MutableOutputDesc(i); + GE_CHECK_NOTNULL(output_tensor_desc); + GELOGD("[%s] To update input shape[%d] by output shape. 
from [%s] to [%s]", task_context.GetNodeName(), i, + task_context.MutableInputDesc(i)->GetShape().ToString().c_str(), + output_tensor_desc->GetShape().ToString().c_str()); + *task_context.MutableInputDesc(i) = *output_tensor_desc; + } + + return SUCCESS; +} + +Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { + GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), "[%s] Failed to execute cond-subgraph", + task_context.GetNodeName()); + if (!is_continue) { + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + task_context.SetOutput(i, *input_tensor); + } + return SUCCESS; + } + + GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); + GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), "[%s] Failed to execute cond-subgraph", + task_context.GetNodeName()); + GELOGD("[%s] Done executing body-subgraph successfully.", task_context.GetNodeName()); + + // set outputs to inputs for next iteration + GE_CHK_STATUS_RET(MoveOutputs2Inputs(task_context), "[%s] Failed to move outputs to inputs", + task_context.GetNodeName()); + + return SUCCESS; +} + +Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, + shared_ptr &task) const { + auto node_item = model.GetNodeItem(node); + GE_CHECK_NOTNULL(node_item); + + unique_ptr node_task; + auto node_type = node->GetType(); + if (node_type == IF) { + node_task.reset(new (std::nothrow) IfOpNodeTask()); + } else if (node_type == CASE) { + node_task.reset(new (std::nothrow) CaseOpNodeTask()); + } else if (node_type == WHILE) { + node_task.reset(new (std::nothrow) WhileOpNodeTask()); + } else { + GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); + return PARAM_INVALID; + } + + GE_CHECK_NOTNULL(node_task); + GE_CHK_STATUS_RET(node_task->Init(node, model), "[%s] Failed to init ControlOpNodeTask.", node->GetName().c_str()); + + task = std::move(node_task); + return SUCCESS; +} + +Status ControlOpNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { return SUCCESS; } +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.h b/src/ge/hybrid/node_executor/controlop/control_op_executor.h new file mode 100644 index 00000000..0619c6a0 --- /dev/null +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.h b/src/ge/hybrid/node_executor/controlop/control_op_executor.h
new file mode 100644
index 00000000..0619c6a0
--- /dev/null
+++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.h
@@ -0,0 +1,100 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_
+#define GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_
+
+#include <vector>
+#include "hybrid/node_executor/node_executor.h"
+#include "hybrid/model/graph_item.h"
+
+namespace ge {
+namespace hybrid {
+class ControlOpNodeTask : public NodeTask {
+ public:
+  virtual Status Init(const NodePtr &node, const HybridModel &model) = 0;
+  Status UpdateArgs(TaskContext &context) override;
+
+  Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
+
+ protected:
+  virtual Status DoExecuteAsync(TaskContext &task_context, const std::function<void()> &done_callback) const = 0;
+  static Status CopyTensorValueToHost(const TensorValue &tensor_value, int32_t &value);
+  static Status ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context,
+                                const std::function<void()> &done_callback);
+};
+
+class IfOpNodeTask : public ControlOpNodeTask {
+ public:
+  Status Init(const NodePtr &node, const HybridModel &model) override;
+
+ protected:
+  const GraphItem *SelectBranch(int32_t cond) const;
+  Status DoExecuteAsync(TaskContext &task_context, const std::function<void()> &done_callback) const override;
+
+ private:
+  static constexpr int kIfCondIndex = 0;
+  static constexpr int kThenBranchIndex = 0;
+  static constexpr int kElseBranchIndex = 1;
+
+  const GraphItem *then_ = nullptr;
+  const GraphItem *else_ = nullptr;
+};
+
+class CaseOpNodeTask : public ControlOpNodeTask {
+ public:
+  Status Init(const NodePtr &node, const HybridModel &model) override;
+
+ protected:
+  const GraphItem *SelectBranch(int32_t branch_index) const;
+  Status DoExecuteAsync(TaskContext &task_context, const std::function<void()> &done_callback) const override;
+
+ private:
+  static constexpr int kCaseBranchIndex = 0;
+  static constexpr size_t kMaxBranchNum = INT32_MAX;
+  static constexpr size_t kMinBranchNum = 1;
+
+  std::vector<const GraphItem *> subgraphs_;
+};
+
+class WhileOpNodeTask : public ControlOpNodeTask {
+ public:
+  Status Init(const NodePtr &node, const HybridModel &model) override;
+
+ protected:
+  Status DoExecuteAsync(TaskContext &task_context, const std::function<void()> &done_callback) const override;
+  Status ExecuteCond(TaskContext &task_context, bool &is_continue) const;
+
+  static Status MoveOutputs2Inputs(TaskContext &task_context);
+
+  Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const;
+
+ private:
+  static constexpr int kCondBranchIndex = 0;
+  static constexpr int kBodyBranchIndex = 1;
+
+  const GraphItem *cond_ = nullptr;
+  const GraphItem *body_ = nullptr;
+};
+
+class ControlOpNodeExecutor : public NodeExecutor {
+ public:
+  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const override;
+  Status PrepareTask(NodeTask &task, TaskContext &context) const override;
+};
+}  // namespace hybrid
+}  // namespace ge
+#endif  // GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_
diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
new file mode 100644
index 00000000..e86c0cb0
--- /dev/null
+++ b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc
@@ -0,0 +1,207 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#include "graph/manager/util/hcom_util.h"
+#include "framework/common/debug/ge_log.h"
+#include "framework/common/fmk_error_codes.h"
+#include "common/ge/ge_util.h"
+#include "common/ge/plugin_manager.h"
+#include "graph/attr_value.h"
+#include "graph/debug/ge_attr_define.h"
+#include "hccl/hcom.h"
+
+namespace ge {
+namespace hybrid {
+
+REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, HcclNodeExecutor);
+
+Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GELOGI("[%s] HcclNodeTask::ExecuteAsync in.", context.GetNodeName());
+  if (context.handle_ == nullptr) {
+    GELOGE(FAILED, "hccl handle is nullptr!");
+    return FAILED;
+  }
+  auto EnqueueHcomOpertion = (hcclResult_t(*)(HcomOpertion, std::function<void(hcclResult_t status)>))dlsym(
+    context.handle_, "EnqueueHcomOpertion");
+  if (EnqueueHcomOpertion == nullptr) {
+    GELOGE(FAILED, "Failed to resolve EnqueueHcomOpertion function for hcom unknown-shape node.");
+    if (dlclose(context.handle_) != 0) {
+      GELOGW("Failed to close handle %s", dlerror());
+    }
+    return FAILED;
+  }
+
+  vector<void *> inputs;
+  for (int i = 0; i < context.NumInputs(); ++i) {
+    TensorValue *tv = context.MutableInput(i);
+    GE_CHECK_NOTNULL(tv);
+    inputs.emplace_back(tv->MutableData());
+  }
+
+  vector<void *> outputs;
+  for (int i = 0; i < context.NumOutputs(); ++i) {
+    TensorValue *tv = context.MutableOutput(i);
+    GE_CHECK_NOTNULL(tv);
+    outputs.emplace_back(tv->MutableData());
+  }
+
+  const NodeItem &node_item = context.GetNodeItem();
+  const OpDescPtr op_desc = MakeShared<OpDesc>(*(node_item.op_desc));
+  GE_CHECK_NOTNULL(op_desc);
+
+  HcomOpertion op_info;
+  op_info.hcclType = op_desc->GetType();
+  op_info.inputPtr = inputs.empty() ? nullptr : inputs[0];
+  op_info.outputPtr = outputs.empty() ? nullptr : outputs[0];
+  ge::DataType src_data_type = op_desc->GetInputDescPtr(0)->GetDataType();
+  auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type));
+  if (iter == kConstOpHcclDataType.end()) {
+    GELOGE(PARAM_INVALID, "kConstOpHcclDataType find failed.");
+    return PARAM_INVALID;
+  }
+  op_info.dataType = iter->second;
+  hcclRedOp_t op_type = HCCL_REP_OP_SUM;
+  if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER ||
+      op_desc->GetType() == HVDCALLBACKALLREDUCE) {
+    GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "GetHcclOperationType failed");
+    op_info.opType = op_type;
+  }
+  int64_t root_id = 0;
+  if (op_desc->GetType() == HCOMBROADCAST) {
+    GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed");
+  }
+  op_info.root = root_id;
+  auto callback = [this](hcclResult_t status) {
+    if (status != HCCL_SUCCESS) {
+      GELOGE(HCCL_E_INTERNAL, "Call EnqueueHcomOpertion failed, ret: 0x%X", status);
+    }
+    std::lock_guard<std::mutex> lock(this->hccl_mutex_);
+    this->cond_.notify_all();
+    GELOGI("hccl callback success.");
+  };
+  int32_t count = 0;
+  GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomCount(op_desc, static_cast<hcclDataType_t>(op_info.dataType), false, count),
+                    "GetHcomCount failed");
+  GELOGI("[%s] HcclNodeTask::ExecuteAsync hccl_type %s, count %d, data_type %d, op_type %d, root %d.",
+         context.GetNodeName(), op_info.hcclType.c_str(), count, op_info.dataType, op_info.opType, op_info.root);
+  op_info.count = count;
+
+  hcclResult_t hccl_ret = EnqueueHcomOpertion(op_info, callback);
+  if (hccl_ret != HCCL_SUCCESS) {
+    GELOGE(HCCL_E_INTERNAL, "Call EnqueueHcomOpertion failed, ret: 0x%X", hccl_ret);
+    return HCCL_E_INTERNAL;
+  }
+
+  // pending until hccl finished
+  std::unique_lock<std::mutex> ulock(hccl_mutex_);
+  cond_.wait(ulock);
+
+  context.RegisterCallback(done_callback);
+  GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }
+
+Status HcclNodeTask::Init(TaskContext &context) {
+  GELOGI("[%s] HcclNodeExecutor::Init success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
+  GELOGI("[%s] HcclNodeExecutor::PrepareTask in.", context.GetNodeName());
+
+  GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed.");
+  // allocate output mem
+  GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed.");
+
+  GE_CHK_STATUS_RET(task.UpdateArgs(context), "hccl node task update args failed.");
+  GELOGI("[%s] HcclNodeExecutor::PrepareTask success.", context.GetNodeName());
+  return SUCCESS;
+}
+
+Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
+  GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str());
+  GE_CHECK_NOTNULL(node);
+
+  task = MakeShared<HcclNodeTask>();
+  GE_CHECK_NOTNULL(task);
+  GELOGI("[%s] HcclNodeExecutor::LoadTask success.", node->GetName().c_str());
+  return SUCCESS;
+}
+
+Status HcclNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context,
+                                     const std::function<void()> &callback) const {
+  context.handle_ = handle_;
node = %s", + context.GetNodeItem().NodeName().c_str()); + return SUCCESS; +} + +Status HcclNodeExecutor::Initialize() { + std::string file_name = "libhccl.so"; + std::string path = PluginManager::GetPath(); + path.append(file_name); + string canonical_path = RealPath(path.c_str()); + if (canonical_path.empty()) { + GELOGW("failed to get realpath of %s", path.c_str()); + return FAILED; + } + + GELOGI("FileName:%s, Path:%s.", file_name.c_str(), canonical_path.c_str()); + handle_ = dlopen(canonical_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ == nullptr) { + GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", dlerror()); + return FAILED; + } + auto HcomExcutorInitialize = (hcclResult_t(*)())dlsym(handle_, "HcomExcutorInitialize"); + if (HcomExcutorInitialize == nullptr) { + GELOGE(FAILED, "Failed to invoke HcomExcutorInitialize hcom unknown node function."); + return FAILED; + } + hcclResult_t hccl_ret = HcomExcutorInitialize(); + if (hccl_ret == HCCL_E_PTR) { + GELOGI("Hccl comm is null, hcom executor initialize is not required."); + } else if (hccl_ret == HCCL_SUCCESS) { + GELOGI("Hcom executor initialize success."); + } else { + GELOGE(FAILED, "Call HcomExcutorInitialize failed, ret: 0x%X", hccl_ret); + return FAILED; + } + return SUCCESS; +} + +Status HcclNodeExecutor::Finalize() { + auto HcomExcutorFinalize = (hcclResult_t(*)())dlsym(handle_, "HcomExcutorFinalize"); + if (HcomExcutorFinalize == nullptr) { + GELOGE(FAILED, "Failed to invoke HcomExcutorFinalize hcom unknown node function."); + return FAILED; + } + hcclResult_t hccl_ret = HcomExcutorFinalize(); + if (hccl_ret != HCCL_SUCCESS) { + GELOGE(FAILED, "Call HcomExcutorFinalize failed, ret: 0x%X", hccl_ret); + return FAILED; + } + // dlclose file handle + if (dlclose(handle_) != 0) { + GELOGW("Failed to close handle %s", dlerror()); + } + GELOGI("Hcom executor finalize success."); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h new file mode 100644 index 00000000..8791c4e3 --- /dev/null +++ b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h
new file mode 100644
index 00000000..8791c4e3
--- /dev/null
+++ b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef HYBRID_HCCL_NODE_EXECUTOR_H_
+#define HYBRID_HCCL_NODE_EXECUTOR_H_
+#include "hybrid/node_executor/node_executor.h"
+#include "hybrid/model/hybrid_model.h"
+#include "graph/op_desc.h"
+
+namespace ge {
+namespace hybrid {
+class HybridModel;
+
+class HcclNodeTask : public NodeTask {
+ public:
+  HcclNodeTask() {}
+
+  ~HcclNodeTask() {}
+
+  Status UpdateArgs(TaskContext &context) override;
+  Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
+  Status Init(TaskContext &context) override;
+
+ private:
+  std::shared_ptr<DavinciModel> davinci_model_ = nullptr;
+  bool load_flag_ = false;
+  std::mutex hccl_mutex_;
+  std::condition_variable cond_;
+};
+
+class HcclNodeExecutor : public NodeExecutor {
+ public:
+  Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const;
+  Status PrepareTask(NodeTask &task, TaskContext &context) const;
+  Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
+  Status Initialize() override;
+  Status Finalize() override;
+  ~HcclNodeExecutor() {}
+
+ private:
+  void *handle_;
+};
+}  // namespace hybrid
+}  // namespace ge
+
+#endif  // HYBRID_HCCL_NODE_EXECUTOR_H_
diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc
index c3bc9a41..d353dff1 100644
--- a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc
+++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc
@@ -17,14 +17,12 @@
 #include "hybrid/node_executor/hostcpu/ge_local_node_executor.h"
 #include "graph/debug/ge_attr_define.h"
 #include "framework/common/util.h"
-#include "framework/common/types.h"
+#include "hybrid/model/hybrid_model.h"
 #include "inc/kernel.h"
 #include "inc/kernel_factory.h"
-#include "common/ge/ge_util.h"
 
 namespace ge {
 namespace hybrid {
-
 REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::GE_LOCAL, GeLocalNodeExecutor);
 
 const std::unordered_map<std::string, std::vector<uint32_t>> RefInputTask::out_ref_input_index_ = {
@@ -132,7 +130,7 @@ Status DependInputShapeTask::Execute(TaskContext &context) {
   }
 
   // alloc output
-  GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
+  GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs(NpuMemoryAllocator::AttrWithDefaultPadding()));
 
   // copy data to output
   for (auto i = 0; i < output_num; ++i) {
@@ -194,6 +192,16 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &no
                node_type.c_str());
       return MEMALLOC_FAILED;
     }
+  } else if (node_type == CONSTANTOP || node_type == VARIABLE) {
+    GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str());
+    auto tensor = model.GetVariable(node->GetName());
+    if (tensor == nullptr) {
+      GELOGE(INTERNAL_ERROR, "Failed to get tensor by name: %s", node->GetName().c_str());
+      return INTERNAL_ERROR;
+    }
+
+    task = MakeShared<ConstantNodeTask>(tensor);
+    GE_CHECK_NOTNULL(task);
   } else {
     GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.", node->GetName().c_str(),
            node_type.c_str());
@@ -202,5 +210,20 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &no
   return SUCCESS;
 }
 
+ConstantNodeTask::ConstantNodeTask(const TensorValue *tensor) : tensor_(tensor) {}
+
+Status ConstantNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; }
+
+Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
+  GELOGD("[%s] Start execute.", context.GetNodeName());
+  GE_CHK_STATUS_RET(context.SetOutput(0, *tensor_), "[%s] Failed to set
output.", context.GetNodeName()); + if (done_callback) { + GELOGD("[%s] Start invoke callback.", context.GetNodeName()); + done_callback(); + } + + GELOGD("[%s] Done execute successfully.", context.GetNodeName()); + return SUCCESS; +} } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h index beb1f50d..0195e76c 100644 --- a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h +++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h @@ -23,7 +23,6 @@ namespace ge { namespace hybrid { - class RefInputTask : public NodeTask { public: explicit RefInputTask(const NodePtr &node) : node_name_(node->GetName()), node_type_(node->GetType()) {} @@ -68,6 +67,18 @@ class DependInputShapeTask : public NodeTask { static const std::unordered_set depend_input_shape_ops_; }; +class ConstantNodeTask : public NodeTask { + public: + explicit ConstantNodeTask(const TensorValue *tensor); + ~ConstantNodeTask() = default; + Status UpdateArgs(TaskContext &context) override; + + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + private: + const TensorValue *tensor_; +}; + class GeLocalNodeExecutor : public NodeExecutor { public: Status PrepareTask(NodeTask &task, TaskContext &context) const override; diff --git a/src/ge/hybrid/node_executor/node_executor.cc b/src/ge/hybrid/node_executor/node_executor.cc index f3b86948..016ec6ef 100644 --- a/src/ge/hybrid/node_executor/node_executor.cc +++ b/src/ge/hybrid/node_executor/node_executor.cc @@ -16,6 +16,7 @@ #include "hybrid/node_executor/node_executor.h" #include "framework/common/debug/log.h" +#include "graph/utils/node_utils.h" #include "init/gelib.h" #include "hybrid/model/hybrid_model.h" @@ -25,9 +26,11 @@ namespace { const char *const kEngineNameAiCore = "AIcoreEngine"; const char *const kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE"; const char *const kEngineNameAiCpu = "aicpu_kernel"; +const char *const kEngineNameHccl = "ops_kernel_info_hccl"; } // namespace Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); + GE_CHK_STATUS_RET_NOLOG(task.UpdateTilingData(context)); // update op_desc before alloc ws GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; @@ -48,6 +51,7 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, } Status NodeExecutorManager::EnsureInitialized() { + GE_CHK_STATUS_RET(InitializeExecutors()); std::lock_guard lk(mu_); if (initialized_) { return SUCCESS; @@ -56,6 +60,7 @@ Status NodeExecutorManager::EnsureInitialized() { engine_mapping_.emplace(kEngineNameAiCore, NodeExecutorManager::ExecutorType::AICORE); engine_mapping_.emplace(kEngineNameGeLocal, NodeExecutorManager::ExecutorType::GE_LOCAL); engine_mapping_.emplace(kEngineNameAiCpu, NodeExecutorManager::ExecutorType::AICPU_TF); + engine_mapping_.emplace(kEngineNameHccl, NodeExecutorManager::ExecutorType::HCCL); std::shared_ptr instance_ptr = GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { @@ -69,22 +74,6 @@ Status NodeExecutorManager::EnsureInitialized() { kernel_stores_.emplace(it.first, it.second); } - GELOGI("Start to Initialize NodeExecutors"); - for (auto &it : builders_) { - auto engine_type = it.first; - auto build_fn = it.second; - GE_CHECK_NOTNULL(build_fn); - auto 
executor = std::unique_ptr(build_fn()); - if (executor == nullptr) { - GELOGE(INTERNAL_ERROR, "Failed to create executor for engine type = %d", engine_type); - return INTERNAL_ERROR; - } - - GELOGD("Executor of engine type = %d was created successfully", engine_type); - GE_CHK_STATUS_RET(executor->Initialize(), "Failed to initialize NodeExecutor of type = %d", engine_type); - executors_.emplace(engine_type, std::move(executor)); - } - initialized_ = true; GELOGI("Initializing NodeExecutors successfully"); return SUCCESS; @@ -93,6 +82,11 @@ Status NodeExecutorManager::EnsureInitialized() { NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node &node) const { auto op_type = node.GetType(); if (op_type == PARTITIONEDCALL) { + bool is_dynamic = false; + (void)NodeUtils::GetNodeUnknownShapeStatus(node, is_dynamic); + if (is_dynamic) { + return ExecutorType::DYNAMIC_SUBGRAPH; + } return ExecutorType::COMPILED_SUBGRAPH; } @@ -101,6 +95,10 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return ExecutorType::GE_LOCAL; } + if (op_type == IF || op_type == CASE || op_type == WHILE) { + return ExecutorType::CONTROL_OP; + } + auto op_desc = node.GetOpDesc(); // checked before const auto &lib_name = op_desc->GetOpKernelLibName(); auto it = engine_mapping_.find(lib_name); @@ -116,10 +114,11 @@ Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executo auto executor_type = ResolveExecutorType(node); const auto it = executors_.find(executor_type); if (it == executors_.end()) { - GELOGE(INTERNAL_ERROR, "Failed to get executor by type: %d", executor_type); + GELOGE(INTERNAL_ERROR, "Failed to get executor by type: %d.", executor_type); return INTERNAL_ERROR; } + GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), executor_type); *executor = it->second.get(); return SUCCESS; } @@ -132,6 +131,11 @@ void NodeExecutorManager::RegisterExecutorBuilder(NodeExecutorManager::ExecutorT Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { auto op_desc = node.GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() == PARTITIONEDCALL) { + GELOGD("[%s] Skipping CalcOpRunningParam for PartitionedCall.", node.GetName().c_str()); + return SUCCESS; + } + auto it = kernel_stores_.find(op_desc->GetOpKernelLibName()); if (it == kernel_stores_.end()) { GELOGE(INTERNAL_ERROR, "Failed to get OpKernelStore. 
libName = %s, node = %s",
@@ -139,9 +143,91 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const {
     return INTERNAL_ERROR;
   }
 
+  // Calculate HCCL output sizes here rather than in the HCCL ops kernel manager:
+  // its GetSize for an input (i.e. the output size of the upstream op) can
+  // intermittently return an error when called from multiple threads.
+  if (op_desc->GetOpKernelLibName() == kEngineNameHccl) {
+    for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) {
+      GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast<uint32_t>(i));
+      Format format = output_tensor.GetFormat();
+      DataType data_type = output_tensor.GetDataType();
+      GeShape output_shape = output_tensor.GetShape();
+      int64_t output_mem_size = 0;
+      GE_CHK_STATUS_RET(TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size),
+                        "hccl calc tensor mem size failed.");
+      output_mem_size =
+        ((output_mem_size + MEMORY_ALIGN_RATIO * MEMORY_ALIGN_SIZE - 1) / MEMORY_ALIGN_SIZE) * MEMORY_ALIGN_SIZE;
+      TensorUtils::SetSize(output_tensor, output_mem_size);
+      GE_CHK_STATUS_RET(op_desc->UpdateOutputDesc(static_cast<uint32_t>(i), output_tensor),
+                        "hccl update output size failed.");
+      GELOGD("%s output desc[%zu], dim_size: %zu, mem_size: %ld.", node.GetName().c_str(), i,
+             output_tensor.GetShape().GetDimNum(), output_mem_size);
+    }
+    return SUCCESS;
+  }
   return it->second->CalcOpRunningParam(node);
 }
 
+Status NodeExecutorManager::InitializeExecutors() {
+  std::lock_guard<std::mutex> lk(mu_);
+  if (executor_initialized_) {
+    ++ref_count_;
+    GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_);
+    return SUCCESS;
+  }
+
+  GELOGI("Start to Initialize NodeExecutors");
+  for (auto &it : builders_) {
+    auto engine_type = it.first;
+    auto build_fn = it.second;
+    GE_CHECK_NOTNULL(build_fn);
+    auto executor = std::unique_ptr<NodeExecutor>(build_fn());
+    if (executor == nullptr) {
+      GELOGE(INTERNAL_ERROR, "Failed to create executor for engine type = %d", engine_type);
+      return INTERNAL_ERROR;
+    }
+
+    GELOGD("Executor of engine type = %d was created successfully", engine_type);
+    auto ret = executor->Initialize();
+    if (ret != SUCCESS) {
+      GELOGE(ret, "Failed to initialize NodeExecutor of type = %d, clear executors", engine_type);
+      for (auto &executor_it : executors_) {
+        executor_it.second->Finalize();
+      }
+      executors_.clear();
+      return ret;
+    }
+
+    executors_.emplace(engine_type, std::move(executor));
+  }
+
+  ++ref_count_;
+  executor_initialized_ = true;
+  GELOGI("Initializing NodeExecutors successfully.");
+  return SUCCESS;
+}
+
+void NodeExecutorManager::FinalizeExecutors() {
+  std::lock_guard<std::mutex> lk(mu_);
+  if (!executor_initialized_) {
+    GELOGD("No need to finalize; executors were never initialized.");
+    return;
+  }
+
+  if (--ref_count_ > 0) {
+    GELOGD("Ref count = %d, do not finalize executors.", ref_count_);
+    return;
+  }
+
+  GELOGD("Start to invoke Finalize on executors.");
+  for (auto &it : executors_) {
+    it.second->Finalize();
+  }
+  executors_.clear();
+  executor_initialized_ = false;
+  GELOGD("Done invoking Finalize successfully.");
+}
+
 NodeExecutorRegistrar::NodeExecutorRegistrar(NodeExecutorManager::ExecutorType executor_type,
                                              NodeExecutor *(*builder)()) {
   NodeExecutorManager::GetInstance().RegisterExecutorBuilder(executor_type, builder);
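[Editor's note] The rounding expression above pads each HCCL output to a multiple of MEMORY_ALIGN_SIZE with MEMORY_ALIGN_RATIO extra blocks of headroom. A worked example of the arithmetic (values taken from the constants introduced in this change):

#include <cassert>
#include <cstdint>

int main() {
  // With MEMORY_ALIGN_RATIO = 2 and MEMORY_ALIGN_SIZE = 32, a 100-byte
  // tensor becomes 160 bytes: (100 + 2*32 - 1) = 163; 163/32 = 5 (integer
  // division); 5 * 32 = 160 -- the next 32-byte boundary plus one extra block.
  const uint32_t ratio = 2;
  const uint32_t align = 32;
  int64_t size = 100;
  size = ((size + ratio * align - 1) / align) * align;
  assert(size == 160);
  return 0;
}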
diff --git a/src/ge/hybrid/node_executor/node_executor.h b/src/ge/hybrid/node_executor/node_executor.h
index 613c0bb1..cc456fa3 100644
--- a/src/ge/hybrid/node_executor/node_executor.h
+++ b/src/ge/hybrid/node_executor/node_executor.h
@@ -14,70 +14,182 @@
  * limitations under the License.
  */
 
-#ifndef GE_HYBRID_KERNEL_NODE_EXECUTOR_H_
-#define GE_HYBRID_KERNEL_NODE_EXECUTOR_H_
+#ifndef GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_
+#define GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_
 
 #include "external/ge/ge_api_error_codes.h"
 #include "common/opskernel/ops_kernel_info_store.h"
 #include "graph/node.h"
-#include "proto/task.pb.h"
 #include "task_context.h"
 
 namespace ge {
+const uint32_t MEMORY_ALIGN_RATIO = 2;
+const uint32_t MEMORY_ALIGN_SIZE = 32;
 namespace hybrid {
 class HybridModel;
-
+// Base class of Node Task
 class NodeTask {
  public:
   NodeTask() = default;
   virtual ~NodeTask() = default;
+
+  /**
+   * Update tiling data
+   * @param context instance of TaskContext
+   * @return SUCCESS on success, error code otherwise
+   */
+  virtual Status UpdateTilingData(TaskContext &context) { return SUCCESS; }
+
+  /**
+   * Init
+   * @param context instance of TaskContext
+   * @return SUCCESS on success, error code otherwise
+   */
+  virtual Status Init(TaskContext &context) { return SUCCESS; }
+
+  /**
+   * Whether this task supports dynamic shape
+   * @return true if this task supports dynamic shape, false otherwise
+   */
+  virtual bool IsSupportDynamicShape() { return true; }
+
+  /**
+   * Update args for execution
+   * @param context instance of TaskContext
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status UpdateArgs(TaskContext &context) = 0;
+
+  /**
+   * Execute task async
+   * @param context instance of TaskContext
+   * @param done_callback callback function, will be invoked after task is done
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) = 0;
-  virtual Status Init(TaskContext &context) { return SUCCESS; }
 };
 
+// Node executor
 class NodeExecutor {
  public:
   NodeExecutor() = default;
   virtual ~NodeExecutor() = default;
 
+  /**
+   * Initialize node executor
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status Initialize() { return SUCCESS; }
 
+  /**
+   * Finalize node executor
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status Finalize() { return SUCCESS; }
 
+  /**
+   * Load task in load stage
+   * @param model instance of HybridModel
+   * @param node node
+   * @param task generated node task
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const;
 
+  /**
+   * Compile task in run stage
+   * @param model instance of HybridModel
+   * @param node node
+   * @param task generated node task
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr<NodeTask> &task) const;
 
+  /**
+   * Preparation actions before execution
+   * @param task instance of NodeTask
+   * @param context instance of TaskContext
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status PrepareTask(NodeTask &task, TaskContext &context) const;
+
+  /**
+   * Execute task
+   * @param task instance of NodeTask
+   * @param context instance of TaskContext
+   * @param callback callback function which will be invoked after computation is done
+   * @return SUCCESS on success, error code otherwise
+   */
   virtual Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
 };
 
 class NodeExecutorManager {
  public:
-  enum class ExecutorType { AICORE, GE_LOCAL, AICPU_TF, AICPU_CUSTOM, COMPILED_SUBGRAPH, HCCL, RESERVED };
+  enum class ExecutorType {
+    AICORE,
+    AICPU_TF,
+    AICPU_CUSTOM,
+    COMPILED_SUBGRAPH,
+
DYNAMIC_SUBGRAPH, + GE_LOCAL, + CONTROL_OP, + HCCL, + RESERVED + }; static NodeExecutorManager &GetInstance() { static NodeExecutorManager instance; return instance; } - Status CalcOpRunningParam(Node &node) const; - + /** + * Register build of executor + * @param executor_type type of executor + * @param builder build function + */ void RegisterExecutorBuilder(ExecutorType executor_type, const std::function &builder); + /** + * Initialize executor if needed + * @return SUCCESS on success, error code otherwise + */ Status EnsureInitialized(); + Status InitializeExecutors(); + + void FinalizeExecutors(); + + /** + * CalcOpRunningParam + * @param node node + * @return SUCCESS on success, error code otherwise + */ + Status CalcOpRunningParam(Node &node) const; + + /** + * Get executor by node + * @param node node + * @param executor executor + * @return SUCCESS on success, error code otherwise + */ Status GetExecutor(Node &node, const NodeExecutor **executor) const; + /** + * Resolve executor type by node + * @param node node + * @return executor type + */ ExecutorType ResolveExecutorType(Node &node) const; + private: std::map> executors_; std::map> builders_; std::map> kernel_stores_; std::map engine_mapping_; std::mutex mu_; bool initialized_ = false; + bool executor_initialized_ = false; + int ref_count_ = 0; }; class NodeExecutorRegistrar { @@ -99,4 +211,4 @@ class NodeExecutorRegistrar { ::ge::hybrid::NodeExecutorRegistrar( \ engine_type, []() -> ::ge::hybrid::NodeExecutor * { return new (std::nothrow) executor(); }) -#endif // GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ +#endif // GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc new file mode 100644 index 00000000..cda9a275 --- /dev/null +++ b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "partitioned_call_node_executor.h" +#include "graph/utils/node_utils.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::DYNAMIC_SUBGRAPH, PartitionedCallNodeExecutor); + +PartitionedCallNodeTask::PartitionedCallNodeTask(const GraphItem *graph_item) : graph_item_(graph_item) {} + +PartitionedCallNodeTask::~PartitionedCallNodeTask() { + GELOGD("[%s] PartitionedCallNodeTask destroyed.", graph_item_->GetName().c_str()); +} + +Status PartitionedCallNodeTask::Init(TaskContext &context) { + auto execution_context = const_cast(context.GetExecutionContext()); + subgraph_executor_.reset(new (std::nothrow) SubgraphExecutor(graph_item_, execution_context)); + GE_CHECK_NOTNULL(subgraph_executor_); + return SUCCESS; +} + +Status PartitionedCallNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GE_CHK_STATUS_RET(subgraph_executor_->ExecuteAsync(context), "[%s] Failed to set inputs", + graph_item_->GetName().c_str()); + + auto callback = [=]() { Callback(done_callback); }; + + GE_CHK_STATUS_RET(context.RegisterCallback(callback), "[%s] Failed to register callback", + graph_item_->GetName().c_str()); + GELOGD("[%s] Done executing subgraph successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeTask::Callback(const std::function &done_callback) { + GELOGD("[%s] On subgraph callback", graph_item_->GetName().c_str()); + if (done_callback != nullptr) { + done_callback(); + } + + GELOGD("[%s] To release sub graph tensors.", graph_item_->GetName().c_str()); + subgraph_executor_.reset(); + GELOGD("[%s] Done releasing sub graph tensors.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } + +Status PartitionedCallNodeExecutor::LoadTask(const ge::hybrid::HybridModel &model, const ge::NodePtr &node, + std::shared_ptr &task) const { + GELOGD("Load dynamic partitioned call: [%s]", node->GetName().c_str()); + auto subgraph = NodeUtils::GetSubgraph(*node, 0); + GE_CHECK_NOTNULL(subgraph); + auto partitioned_call = model.GetSubgraphItem(subgraph); + GE_CHECK_NOTNULL(partitioned_call); + task.reset(new (std::nothrow) PartitionedCallNodeTask(partitioned_call)); + GE_CHECK_NOTNULL(task); + GELOGD("Done loading dynamic partitioned call: [%s]", node->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + GE_CHK_STATUS_RET(task.Init(context), "[%s] Failed to init task.", context.GetNodeName()); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h new file mode 100644 index 00000000..fd87d6c1 --- /dev/null +++ b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ +#define GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ + +#include "hybrid/node_executor/node_executor.h" +#include "hybrid/model/hybrid_model.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/subgraph_executor.h" +#include "common/thread_pool.h" + +namespace ge { +namespace hybrid { +class PartitionedCallNodeTask : public NodeTask { + public: + explicit PartitionedCallNodeTask(const GraphItem *graph_item); + ~PartitionedCallNodeTask() override; + + Status Init(TaskContext &context) override; + + Status UpdateArgs(TaskContext &context) override; + + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + private: + Status Callback(const std::function &done_callback); + + const GraphItem *graph_item_; + std::unique_ptr subgraph_executor_; + GraphExecutionContext *context_ = nullptr; +}; + +class PartitionedCallNodeExecutor : public NodeExecutor { + public: + Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; + Status PrepareTask(NodeTask &task, TaskContext &context) const override; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/task_context.cc b/src/ge/hybrid/node_executor/task_context.cc index 42c653be..ee35bffa 100644 --- a/src/ge/hybrid/node_executor/task_context.cc +++ b/src/ge/hybrid/node_executor/task_context.cc @@ -19,12 +19,16 @@ #include "framework/common/debug/log.h" #include "graph/utils/tensor_utils.h" #include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_executor.h" namespace ge { namespace hybrid { -TaskContext::TaskContext(GraphExecutionContext *execution_context) : execution_context_(execution_context) {} +TaskContext::TaskContext(GraphExecutionContext *execution_context, const NodeItem *node_item, + SubgraphContext *subgraph_context) + : node_item_(node_item), execution_context_(execution_context), subgraph_context_(subgraph_context) {} + TaskContext::~TaskContext() { - GELOGD("To execute ~TaskContext(). node = %s", node_item_->NodeName().c_str()); + GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); for (auto ws_addr : workspaces_) { execution_context_->allocator->Deallocate(ws_addr); } @@ -38,19 +42,28 @@ TaskContext::~TaskContext() { } } -std::unique_ptr TaskContext::Create(const NodeItem &node_item, GraphExecutionContext *graph_context) { - GELOGI("To create task context for node %s, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d", +std::unique_ptr TaskContext::Create(const NodeItem &node_item, GraphExecutionContext *execution_context, + SubgraphContext *subgraph_context) { + GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", node_item.NodeName().c_str(), node_item.input_start, node_item.num_inputs, node_item.output_start, node_item.num_outputs); - auto task_context = std::unique_ptr(new (std::nothrow) TaskContext(graph_context)); + if (node_item.input_start < 0 || node_item.output_start < 0) { + GELOGE(INTERNAL_ERROR, "NodeItem not property initialized. 
input_start = %d, output_start = %d", + node_item.input_start, node_item.output_start); + return nullptr; + } + + auto task_context = + std::unique_ptr(new (std::nothrow) TaskContext(execution_context, &node_item, subgraph_context)); if (task_context == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create instance of TaskContext. node = %s", node_item.NodeName().c_str()); + GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); return nullptr; } task_context->node_item_ = &node_item; - task_context->inputs_start_ = graph_context->all_inputs.data() + node_item.input_start; - task_context->outputs_start_ = graph_context->all_outputs.data() + node_item.output_start; + task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; + task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; + task_context->iteration_ = execution_context->iteration; return task_context; } @@ -59,7 +72,7 @@ int TaskContext::NumInputs() const { return node_item_->num_inputs; } int TaskContext::NumOutputs() const { return node_item_->num_outputs; } TensorValue *TaskContext::MutableInput(int index) { - if (index < 0 || index > node_item_->num_inputs) { + if (index < 0 || index >= node_item_->num_inputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_inputs = %d", index, node_item_->num_inputs); return nullptr; } @@ -68,7 +81,7 @@ TensorValue *TaskContext::MutableInput(int index) { } const TensorValue *TaskContext::GetOutput(int index) const { - if (index < 0 || index > node_item_->num_outputs) { + if (index < 0 || index >= node_item_->num_outputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); return nullptr; } @@ -77,7 +90,7 @@ const TensorValue *TaskContext::GetOutput(int index) const { } TensorValue *TaskContext::MutableOutput(int index) { - if (index < 0 || index > node_item_->num_outputs) { + if (index < 0 || index >= node_item_->num_outputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); return nullptr; } @@ -97,7 +110,7 @@ void *TaskContext::MutableWorkspace(int index) { } const TensorValue *TaskContext::GetInput(int index) const { - if (index < 0 || index > node_item_->num_inputs) { + if (index < 0 || index >= node_item_->num_inputs) { GELOGE(PARAM_INVALID, "Index out of range. 
index = %d, num_inputs = %d", index, node_item_->num_inputs); return nullptr; } @@ -120,7 +133,14 @@ Status TaskContext::AllocateWorkspaces() { } Status TaskContext::RegisterCallback(const std::function &callback_fun) const { - return execution_context_->callback_manager->RegisterCallback(callback_fun); + auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); + if (ret != SUCCESS) { + GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); + execution_context_->callback_manager->Destroy(); + return ret; + } + + return SUCCESS; } string TaskContext::TensorDesc2String(const GeTensorDesc &desc) { @@ -137,7 +157,7 @@ string TaskContext::TensorDesc2String(const GeTensorDesc &desc) { return ss.str(); } -Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor) { +Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr) { int64_t size = 0; if (ge::TensorUtils::GetSize(tensor_desc, size) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to get tensor size"); @@ -148,13 +168,14 @@ Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue GELOGW("size from tensor_desc == 0"); } - auto buffer = TensorBuffer::Create(execution_context_->allocator, size); + auto buffer = TensorBuffer::Create(execution_context_->allocator, size, attr); GE_CHECK_NOTNULL(buffer); tensor = TensorValue(shared_ptr(buffer.release())); return SUCCESS; } -Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor) { +Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor, + AllocationAttr *attr) { GELOGI("To allocate output for node: %s. index = %d, tensor desc = %s", node_item_->NodeName().c_str(), index, TensorDesc2String(tensor_desc).c_str()); @@ -178,9 +199,29 @@ Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, T GE_CHECK_NOTNULL(ref_tensor); outputs_start_[index] = *ref_tensor; } else { - GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index])); - GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", node_item_->NodeName().c_str(), index, - outputs_start_[index].GetSize()); + auto reuse_input = node_item_->reuse_inputs.find(index); + if (reuse_input != node_item_->reuse_inputs.end()) { + GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); + outputs_start_[index] = inputs_start_[reuse_input->second]; + } else { + GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); + GELOGD("Allocating output successfully. node: %s. 
index = %d, size = %zu", node_item_->NodeName().c_str(), index, + outputs_start_[index].GetSize()); + } + } + + // Temp modification + if (node_item_->node_type == "UnsortedSegmentSum" || node_item_->node_type == "UnsortedSegmentSumD" || + node_item_->node_type == "ScatterNd") { + auto &out_tensor = outputs_start_[index]; + GELOGD("[%s] clear output tensor: %s", GetNodeName(), out_tensor.DebugString().c_str()); + auto *ctx = GetExecutionContext(); + string name = "rtMemsetAsync" + node_item_->node_name; + RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] Start"); }); + RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] Start"); + GE_CHK_RT_RET(rtMemsetAsync(out_tensor.MutableData(), out_tensor.GetSize(), 0, out_tensor.GetSize(), GetStream())); + RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] End"); + RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] End"); }); } if (execution_context_->trace_enabled) { @@ -194,11 +235,11 @@ Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, T return SUCCESS; } -Status TaskContext::AllocateOutputs() { +Status TaskContext::AllocateOutputs(AllocationAttr *attr) { for (int i = 0; i < node_item_->num_outputs; ++i) { const auto &output_desc = node_item_->op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(output_desc); - GE_CHK_STATUS_RET_NOLOG(AllocateOutput(i, *output_desc, nullptr)); + GE_CHK_STATUS_RET_NOLOG(AllocateOutput(i, *output_desc, nullptr, attr)); } return SUCCESS; @@ -230,7 +271,7 @@ Status TaskContext::SetOutput(int index, const TensorValue &tensor) { rtStream_t TaskContext::GetStream() { return execution_context_->stream; } -int64_t TaskContext::GetSessionId() { return execution_context_->session_id; } +int64_t TaskContext::GetSessionId() const { return execution_context_->session_id; } Status TaskContext::GetStatus() const { return status_; } @@ -238,7 +279,13 @@ void TaskContext::SetStatus(Status status) { status_ = status; } Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { GE_CHECK_NOTNULL(buffer); - *buffer = execution_context_->allocator->Allocate(size, ori_addr); + if (ori_addr == nullptr) { + *buffer = execution_context_->allocator->Allocate(size, nullptr); + } else { + AllocationAttr attr(ori_addr); + *buffer = execution_context_->allocator->Allocate(size, &attr); + } + if (*buffer == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to allocate workspace of size = %zu", size); return MEMALLOC_FAILED; @@ -261,16 +308,21 @@ Status TaskContext::PropagateOutputs() { for (auto &dst_input_index_and_node : output_nodes) { auto dst_input_idx = dst_input_index_and_node.first; auto dst_node_item = dst_input_index_and_node.second; + auto input_offset = dst_node_item->input_start + dst_input_idx; GELOGI( "Propagate output of node %s, output index = %d, dst node = %s, " - "dst_input_index = %d, dst_input_offset = %d, addr = %p", - node_item_->NodeName().c_str(), i, dst_node_item->NodeName().c_str(), dst_input_idx, - dst_node_item->input_start + dst_input_idx, - execution_context_->all_inputs.data() + dst_node_item->input_start + dst_input_idx); - execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx] = *tensor; + "dst_input_index = %d, dst_input_offset = %d.", + node_item_->NodeName().c_str(), i, dst_node_item->NodeName().c_str(), dst_input_idx, input_offset); + + if (subgraph_context_->all_inputs_.size() <= 
static_cast(input_offset)) { + GELOGE(INTERNAL_ERROR, "[%s] input index out of range. index = %d, total input num = %zu", GetNodeName(), + input_offset, subgraph_context_->all_inputs_.size()); + return INTERNAL_ERROR; + } + + subgraph_context_->all_inputs_[input_offset] = *tensor; if (execution_context_->trace_enabled) { - execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx].SetName(node_item_->NodeName() + - "_in_" + std::to_string(i)); + subgraph_context_->all_inputs_[input_offset].SetName(node_item_->NodeName() + "_in_" + std::to_string(i)); } } } @@ -289,5 +341,37 @@ void TaskContext::ReleaseInput(int index) { GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); } } + +ConstGeTensorDescPtr TaskContext::GetOutputDesc(int index) { + return node_item_->op_desc->MutableOutputDesc(static_cast(index)); +} + +ConstGeTensorDescPtr TaskContext::GetInputDesc(int index) { + return node_item_->op_desc->MutableInputDesc(static_cast(index)); +} + +GeTensorDescPtr TaskContext::MutableInputDesc(int index) { + return node_item_->op_desc->MutableInputDesc(static_cast(index)); +} + +GeTensorDescPtr TaskContext::MutableOutputDesc(int index) { + return node_item_->op_desc->MutableOutputDesc(static_cast(index)); +} + +bool TaskContext::IsForceInferShape() const { return force_infer_shape_; } + +void TaskContext::SetForceInferShape(bool force_infer_shape) { force_infer_shape_ = force_infer_shape; } + +void TaskContext::NodeDone() { subgraph_context_->NodeDone(node_item_->node); } + +void TaskContext::OnError(Status error) { subgraph_context_->OnError(error); } + +bool TaskContext::IsTraceEnabled() const { return execution_context_->trace_enabled; } + +TensorValue *TaskContext::GetVariable(const std::string &name) { return execution_context_->model->GetVariable(name); } + +uint64_t TaskContext::GetIterationNumber() const { return iteration_; } + +bool TaskContext::IsDumpEnabled() const { return execution_context_->dump_enabled; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/task_context.h b/src/ge/hybrid/node_executor/task_context.h index 841dcb17..5c42a347 100644 --- a/src/ge/hybrid/node_executor/task_context.h +++ b/src/ge/hybrid/node_executor/task_context.h @@ -22,16 +22,19 @@ #include #include "external/ge/ge_api_error_codes.h" #include "hybrid/common/tensor_value.h" +#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/executor/rt_callback_manager.h" #include "hybrid/model/node_item.h" namespace ge { namespace hybrid { class GraphExecutionContext; +class SubgraphContext; class TaskContext { public: - static std::unique_ptr Create(const NodeItem &node_item, GraphExecutionContext *graph_context); + static std::unique_ptr Create(const NodeItem &node_item, GraphExecutionContext *execution_context, + SubgraphContext *subgraph_context); ~TaskContext(); @@ -41,19 +44,33 @@ class TaskContext { const NodeItem &GetNodeItem() const; const char *GetNodeName() const; TensorValue *MutableInput(int index); + ConstGeTensorDescPtr GetInputDesc(int index); + ConstGeTensorDescPtr GetOutputDesc(int index); + GeTensorDescPtr MutableInputDesc(int index); + GeTensorDescPtr MutableOutputDesc(int index); void ReleaseInput(int index); const TensorValue *GetInput(int index) const; const TensorValue *GetOutput(int index) const; TensorValue *MutableOutput(int index); + TensorValue *GetVariable(const std::string &name); rtStream_t GetStream(); - int64_t GetSessionId(); + int64_t GetSessionId() const; + uint64_t GetIterationNumber() const; + + void 
NodeDone(); + void OnError(Status error); Status SetOutput(int index, const TensorValue &tensor); - Status AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor); - Status AllocateOutputs(); + Status AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor, + AllocationAttr *attr = nullptr); + Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); + bool IsTraceEnabled() const; + + bool IsDumpEnabled() const; + const GraphExecutionContext *GetExecutionContext() { return execution_context_; } Status AllocateTemp(size_t size, TensorValue &tensor); @@ -68,17 +85,25 @@ class TaskContext { void SetStatus(Status status); + bool IsForceInferShape() const; + void SetForceInferShape(bool force_infer_shape); + void *handle_ = nullptr; + private: - explicit TaskContext(GraphExecutionContext *execution_context); - TensorValue *inputs_start_ = nullptr; - TensorValue *outputs_start_ = nullptr; + TaskContext(GraphExecutionContext *execution_context, const NodeItem *node_item, SubgraphContext *subgraph_context); + static string TensorDesc2String(const GeTensorDesc &desc); - Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor); + Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); - GraphExecutionContext *execution_context_; const NodeItem *node_item_ = nullptr; + bool force_infer_shape_ = false; + GraphExecutionContext *execution_context_; + SubgraphContext *subgraph_context_; + TensorValue *inputs_start_ = nullptr; + TensorValue *outputs_start_ = nullptr; Status status_ = SUCCESS; std::vector workspaces_; + uint64_t iteration_ = 0; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/inc/kernel_factory.h b/src/ge/inc/kernel_factory.h index c0624e14..61455836 100644 --- a/src/ge/inc/kernel_factory.h +++ b/src/ge/inc/kernel_factory.h @@ -103,5 +103,5 @@ class KernelFactory { return ptr; \ } \ KernelFactory::Registerar g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) -}; // end namespace ge +} // namespace ge #endif // GE_INC_KERNEL_FACTORY_H_ diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index 5fcb0cd7..f7740a3c 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -37,6 +37,7 @@ #include "graph/load/new_model_manager/model_manager.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" +#include "graph/common/ge_call_wrapper.h" #include "omm/csa_interact.h" #include "runtime/kernel.h" @@ -46,6 +47,9 @@ namespace ge { namespace { const int kDecimal = 10; const int kSocVersionLen = 50; +const uint32_t kAicoreOverflow = (0x1 << 0); +const uint32_t kAtomicOverflow = (0x1 << 1); +const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); } // namespace static std::shared_ptr instancePtr_ = nullptr; @@ -75,7 +79,7 @@ Status GELib::Initialize(const map &options) { instancePtr_ = nullptr; return ret; } - GE_TIMESTAMP_END(Init, "GELib::Initialize"); + GE_TIMESTAMP_EVENT_END(Init, "GELib::Initialize"); return SUCCESS; } @@ -126,16 +130,6 @@ Status GELib::InnerInitialize(const map &options) { return initSmStatus; } - GELOGI("memoryMallocSize initial."); - GE_TIMESTAMP_START(SetMemoryMallocSize); - Status initMemStatus = VarManager::Instance(0)->SetMemoryMallocSize(options); - GE_TIMESTAMP_END(SetMemoryMallocSize, "InnerInitialize::SetMemoryMallocSize"); - if (initMemStatus != SUCCESS) { - 
GELOGE(initMemStatus, "failed to set malloc size"); - RollbackInit(); - return initMemStatus; - } - GELOGI("Start to initialize HostCpuEngine"); GE_TIMESTAMP_START(HostCpuEngineInitialize); Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); @@ -160,37 +154,6 @@ Status GELib::SystemInitialize(const map &options) { } } - iter = options.find(HEAD_STREAM); - head_stream_ = (iter != options.end()) ? std::strtol(iter->second.c_str(), nullptr, kDecimal) : false; - - iter = options.find(OPTION_EXEC_ENABLE_DUMP); - if (iter != options.end()) { - int32_t enable_dump_flag = 1; - auto path_iter = options.find(OPTION_EXEC_DUMP_PATH); - if (iter->second == std::to_string(enable_dump_flag) && path_iter != options.end()) { - std::string dump_path = path_iter->second; - if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; - } - - PropertiesManager::Instance().AddDumpPropertyValue(DUMP_ALL_MODEL, {}); - GELOGD("Get dump path %s successfully", dump_path.c_str()); - PropertiesManager::Instance().SetDumpOutputPath(dump_path); - } - auto step_iter = options.find(OPTION_EXEC_DUMP_STEP); - if (step_iter != options.end()) { - std::string dump_step = step_iter->second; - GELOGD("Get dump step %s successfully", dump_step.c_str()); - PropertiesManager::Instance().SetDumpStep(dump_step); - } - auto mode_iter = options.find(OPTION_EXEC_DUMP_MODE); - if (mode_iter != options.end()) { - std::string dump_mode = mode_iter->second; - GELOGD("Get dump mode %s successfully", dump_mode.c_str()); - PropertiesManager::Instance().SetDumpMode(dump_mode); - } - } - // In train and infer, profiling is always needed. InitOptions(options); InitProfiling(this->options_); diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h index 0dfec391..b5621dfd 100644 --- a/src/ge/init/gelib.h +++ b/src/ge/init/gelib.h @@ -62,9 +62,6 @@ class GELib { // get TrainMode flag bool isTrainMode() { return is_train_mode_; } - // add head stream to model - bool HeadStream() const { return head_stream_; } - // get incre build flag bool IsIncreBuild() const { return is_incre_build_; } @@ -86,6 +83,8 @@ class GELib { Status SetRTSocVersion(const map &options, map &new_options); void RollbackInit(); void InitOptions(const map &options); + void SetDumpModelOptions(const map &options); + void SetOpDebugOptions(const map &options); DNNEngineManager engineManager_; OpsKernelManager opsManager_; @@ -98,7 +97,6 @@ class GELib { bool is_shutdown = false; bool is_use_hcom = false; bool is_incre_build_ = false; - bool head_stream_ = false; std::string incre_build_cache_path_; }; } // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc index 12c85bc0..352e5dc2 100644 --- a/src/ge/ir_build/atc_ir_common.cc +++ b/src/ge/ir_build/atc_ir_common.cc @@ -32,8 +32,29 @@ const int64_t kDynamicImageSizeNum = 2; // datatype/formats from user to GE, Unified to util interface file later const std::map kOutputTypeSupportDatatype = { {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; -const std::set kBufferOptimizeSupportOption = {"l2_optimize", "off_optimize"}; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize", + "l1_and_l2_optimize"}; +// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. 
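[Editor's note] The option table above advertises four buffer_optimize values while the comment notes that only two are implemented. A minimal sketch of the corresponding check, mirroring that comment (illustrative only, not the verbatim ATC validation; kBufferOptimizeSupport is the message constant defined just below):

// Accept only the subset that is actually implemented today.
static bool CheckBufferOptimizeParamValid(const std::string &buffer_optimize) {
  if (buffer_optimize != "l2_optimize" && buffer_optimize != "off_optimize") {
    GELOGE(ge::PARAM_INVALID, "Input parameter[--buffer_optimize]'s value[%s] is invalid, %s.",
           buffer_optimize.c_str(), kBufferOptimizeSupport);
    return false;
  }
  return true;
}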
+const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; const std::string IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; +const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; +const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; +const char *const kSplitError1 = "size not equal to 2 split by \":\""; +const char *const kEmptyError = "can not be empty"; +const char *const kFloatNumError = "exist float number"; +const char *const kDigitError = "is not digit"; +const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; + +vector SplitInputShape(const std::string &input_shape) { + vector shape_pair_vec; + size_t pos = input_shape.rfind(":"); + if (pos != std::string::npos) { + shape_pair_vec.emplace_back(input_shape.substr(0, pos)); + shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); + } + return shape_pair_vec; +} } // namespace bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, @@ -42,7 +63,7 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { vector shape = iter->second; if (shape.size() < 1) { - ErrorManager::GetInstance().ATCReportErrMessage("E10017"); + ErrorManager::GetInstance().ATCReportErrMessage("E10012"); GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size."); return false; } @@ -61,14 +82,14 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> } if (size == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10043"); + ErrorManager::GetInstance().ATCReportErrMessage("E10031"); GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size."); return false; } for (char c : dynamic_batch_size) { if (!isdigit(c) && (c != ',') && (c != ' ')) { - ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"value"}, {dynamic_batch_size}); + ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"value"}, {dynamic_batch_size}); GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid.", dynamic_batch_size.c_str()); return false; @@ -169,7 +190,7 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz vector>> user_shape_map; is_dynamic_input = true; if (input_shape.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"input_shape"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"}); GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic batchsize scenario."); return ge::PARAM_INVALID; } @@ -200,21 +221,19 @@ bool ParseInputShape(const string &input_shape, unordered_map shape_vec = StringUtils::Split(input_shape, ';'); const int DEFAULT_SHAPE_PAIR_SIZE = 2; for (const auto &shape : shape_vec) { - vector shape_pair_vec = StringUtils::Split(shape, ':'); + vector shape_pair_vec = SplitInputShape(shape); if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { - ErrorManager::GetInstance().ATCReportErrMessage("E10010", {"shape"}, {shape}); - GELOGW( - "Input parameter[--input_shape]’s shape is [%s], " - "correct sample is input_name1:n1,c1,h1,w1", - shape.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kSplitError1, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct 
sample is %s.", + shape.c_str(), kSplitError1, kInputShapeSample1); return false; } if (shape_pair_vec[1].empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape"}, {shape}); - GELOGW( - "Input parameter[--input_shape]’s shape is [%s], can not empty, " - "correct sample is input_name1:n1,c1,h1,w1", - shape.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kEmptyError, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kEmptyError, kInputShapeSample1); return false; } @@ -223,34 +242,48 @@ bool ParseInputShape(const string &input_shape, unordered_map caffe_support_input_format = {"NCHW", "ND"}; static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; static std::set onnx_support_input_format = {"NCHW", "ND"}; +static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; static std::map input_format_str_to_geformat = { {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index 0be75b51..a64591da 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -296,7 +296,6 @@ graphStatus Impl::BuildModel(const Graph &graph, const std::map(model.data.get()), static_cast(model.length)); } - } // namespace ge diff --git a/src/ge/model/ge_model.h b/src/ge/model/ge_model.h index 6305211a..be4b65bc 100644 --- a/src/ge/model/ge_model.h +++ b/src/ge/model/ge_model.h @@ -87,6 +87,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder uint8_t platform_type_ = {0}; uint32_t model_id_ = INVALID_MODEL_ID; }; -}; // namespace ge +} // namespace ge using GeModelPtr = std::shared_ptr; #endif // GE_MODEL_GE_MODEL_H_ diff --git a/src/ge/offline/main.cc b/src/ge/offline/main.cc index f77f006d..fad4134c 100644 --- a/src/ge/offline/main.cc +++ b/src/ge/offline/main.cc @@ -66,6 +66,10 @@ static bool is_dynamic_input = false; // 310 limited 8G size const char *const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; +const char *const kModeSupport = + "only support 0(model to framework model), " + "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"; +const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow)"; DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); @@ -138,10 +142,6 @@ DEFINE_string(optypelist_for_implmode, "", "Optional; Nodes need use implmode selected in op_select_implmode " "Format:\"node_name1,node_name2\""); -DEFINE_string(head_stream, "0", - "Optional; Is need head stream, default is not need." - "Format: \"0: no head stream; 1: add head stream;\""); - DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); @@ -173,7 +173,8 @@ DEFINE_string(dynamic_image_size, "", DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); -DEFINE_bool(enable_compress_weight, false, "Optional; enable compress weight. 
true: enable; false(default): disable"); +DEFINE_string(enable_compress_weight, "false", + "Optional; enable compress weight. true: enable; false(default): disable"); DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); @@ -183,6 +184,10 @@ DEFINE_string(log, "default", "Optional; generate atc log. Support debug, info, DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); +DEFINE_int32(op_debug_level, 0, + "Optional; configure debug level of compiler. 0(default): close debug;" + "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); + class GFlagUtils { public: /** @@ -235,7 +240,7 @@ class GFlagUtils { "\"check_result.json\"\n" " --disable_reuse_memory The switch of reuse memory. Default value is : 0." "0 means reuse memory, 1 means do not reuse memory.\n" - " --input_fp16_nodes Input node datatype is fp16 and format is NCHW. Separate multiple nodes with semicolons " + " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " "(;)." "Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1;node_name2\"\n" @@ -255,7 +260,6 @@ class GFlagUtils { " --optypelist_for_implmode Appoint which op to use op_select_implmode, used with op_select_implmode ." "Separate multiple nodes with commas (,). Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1,node_name2\"\n" - " --head_stream Add head stream. 0(default): disable; 1: enable\n" " --soc_version The soc version. E.g.: \"Ascend310\"\n" " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. " "Default value is: AiCore\n" @@ -283,7 +287,7 @@ class GFlagUtils { static Status CheckDumpInfershapeJsonFlags() { Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", FLAGS_weight.c_str()); return domi::SUCCESS; @@ -292,7 +296,7 @@ class GFlagUtils { static Status CheckFlags() { // No model file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_model == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"model"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); return domi::PARAM_INVALID, "Input parameter[--model]'s value is empty!"); // check param disable_reuse_memory GE_CHK_BOOL_EXEC(ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) == ge::SUCCESS, @@ -304,7 +308,7 @@ class GFlagUtils { return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); // No output file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); return domi::PARAM_INVALID, "Input parameter[--output]'s value is empty!"); Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); @@ -323,16 +327,16 @@ class GFlagUtils { GELOGI("domi will run with encrypt!"); 
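The switch from DEFINE_bool to DEFINE_string for --enable_compress_weight means the flag value is now validated by hand rather than coerced by gflags, and --op_debug_level documents an accepted range of 0, 1 or 2. Below is a standalone sketch of that style of check; the helper name IsBoolString and the hard-coded sample values are placeholders, not code from this patch.

```cpp
#include <iostream>
#include <string>

// Accept only the exact spellings "true" and "false"; a string-typed flag lets
// values like "1" or "True" be rejected instead of silently coerced.
bool IsBoolString(const std::string &value) { return value == "true" || value == "false"; }

int main() {
  // Hypothetical stand-ins for FLAGS_enable_compress_weight and FLAGS_op_debug_level.
  const std::string enable_compress_weight = "false";
  const int op_debug_level = 1;

  if (!IsBoolString(enable_compress_weight)) {
    std::cerr << "Input parameter[--enable_compress_weight]'s value must be true or false." << std::endl;
    return 1;
  }
  if (op_debug_level < 0 || op_debug_level > 2) {  // 0: close debug, 1: TBE + mapping files, 2: ccec compiler
    std::cerr << "Input parameter[--op_debug_level]'s value must be 0, 1 or 2." << std::endl;
    return 1;
  }
  const bool compress_enabled = (enable_compress_weight == "true");
  std::cout << std::boolalpha << "compress weight: " << compress_enabled
            << ", op_debug_level: " << op_debug_level << std::endl;
  return 0;
}
```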
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), return domi::FAILED, - "encrypt_key file %s not found!!", FLAGS_encrypt_key.c_str()); + "encrypt_key file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), return domi::FAILED, - "certificate file %s not found!!", FLAGS_certificate.c_str()); + "certificate file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), return domi::FAILED, - "hardware_key file %s not found!!", FLAGS_hardware_key.c_str()); + "hardware_key file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), return domi::FAILED, - "private_key file %s not found!!", FLAGS_private_key.c_str()); + "private_key file not found!!"); } else { // No encryption GELOGI("domi will run without encrypt!"); } @@ -341,43 +345,37 @@ class GFlagUtils { /** * Check the validity of the I / O file path */ - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "model"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "--model"), return domi::FAILED, "model file %s not found!!", FLAGS_model.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), return domi::FAILED, "weight file %s not found!!", FLAGS_weight.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "cal_conf"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), return domi::FAILED, "calibration config file %s not found!!", FLAGS_cal_conf.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "op_name_map"), return domi::FAILED, + FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), return domi::FAILED, "op config file %s not found!!", FLAGS_op_name_map.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_head_stream != "" && FLAGS_head_stream != "0" && FLAGS_head_stream != "1", - ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"head_stream"}); - return domi::FAILED, "Input parameter[--head_stream] must be 0 or 1!!"); - GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, return ge::FAILED, "check insert op conf failed!"); GE_CHK_BOOL_EXEC( - ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight ? 
std::string("true") : std::string("false"), - FLAGS_compress_weight_conf) == ge::SUCCESS, + ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, return ge::FAILED, "check compress weight failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "check_report"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), return domi::FAILED, "check_report file %s not found!!", FLAGS_check_report.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output) || !CheckPathWithName(FLAGS_output)), - return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output, "--output") || + !CheckPathWithName(FLAGS_output)), + return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( FLAGS_save_original_model != "" && FLAGS_save_original_model != "true" && FLAGS_save_original_model != "false", - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); return domi::FAILED, "Input parameter[--save_original_model]'s value[%s] must be true or false.", FLAGS_save_original_model.c_str()); @@ -398,18 +396,18 @@ class GFlagUtils { static Status CheckConverJsonParamFlags() { // No model path passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"om"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); return domi::PARAM_INVALID, "Input parameter[--om]'s value is empty!!"); // JSON path not passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"json"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); return domi::PARAM_INVALID, "Input parameter[--json]'s value is empty!!"); // Check if the model path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "om"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "--om"), return domi::PARAM_INVALID, "model file path is invalid: %s.", FLAGS_om.c_str()); // Check whether the JSON path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "om"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "--json"), return domi::PARAM_INVALID, "json file path is invalid: %s.", FLAGS_json.c_str()); return domi::SUCCESS; @@ -446,7 +444,8 @@ class GFlagUtils { if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { // No framework information was passed in or the entered framework is illegal - ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter"}, {"framework"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter", "support"}, + {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); DOMI_LOGE( "Input parameter[--framework] is mandatory and it's value must be: " "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); @@ -519,31 +518,29 @@ static bool 
CheckInputFormat() { if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { return true; } - ErrorManager::GetInstance().ATCReportErrMessage("E10031", {"value"}, {FLAGS_input_format}); // only support NCHW ND - GELOGE(ge::FAILED, - "Input parameter[--input_format]'s value[%s] is wrong, " - "only support NCHW, ND in Caffe model.", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), + ge::kCaffeFormatSupport); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { return true; } - ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"value"}, {FLAGS_input_format}); // only support NCHW NHWC ND NCDHW NDHWC - GELOGE(ge::FAILED, - "Input parameter[--input_format]'s value[%s] is wrong, " - "only support NCHW, NHWC, ND, NCDHW, NDHWC in tf model", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { return true; } // only support NCHW ND - GELOGE(ge::FAILED, "Input parameter[--input_format]'s value[%s] is error, Only support NCHW, ND in onnx model", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); return false; } return true; @@ -625,8 +622,7 @@ void LoadModelParserLib(std::string caffe_parser_path) { return; } -void LoadCustomOpLib() { - OpRegistry::Instance()->registrationDatas.clear(); +void LoadCustomOpLib(bool need_load_ops_plugin) { std::string plugin_path; GetCustomOpPath(plugin_path); @@ -642,7 +638,11 @@ void LoadCustomOpLib() { } LoadModelParserLib(caffe_parser_path); - + if (!need_load_ops_plugin) { + GELOGI("No need to load ops plugin so."); + return; + } + OpRegistry::Instance()->registrationDatas.clear(); // load other so files except lib_caffe_parser.so in the plugin so path for (auto elem : fileList) { ge::StringUtils::Trim(elem); @@ -657,17 +657,21 @@ void LoadCustomOpLib() { std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; for (OpRegistrationData reg_data : registrationDatas) { - bool ret = ge::OpRegistrationTbe::Instance()->Finalize(reg_data); - if (ret) { - OpRegistry::Instance()->Register(reg_data); - } + (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); + (void)OpRegistry::Instance()->Register(reg_data); } } void SaveCustomCaffeProtoPath() { GELOGI("Enter save custom caffe proto path."); - string customop_path; + std::string path_base = ge::GELib::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + 
ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; + + string customop_path; const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { std::string path = path_env; @@ -676,10 +680,6 @@ void SaveCustomCaffeProtoPath() { ge::GetParserContext().custom_proto_path = customop_path; return; } - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); customop_path = path_base + "ops/framework/custom/caffe/"; ge::GetParserContext().custom_proto_path = customop_path; return; @@ -723,15 +723,6 @@ Status CreateInputsForInference(const ge::Graph &graph, vector &in return ge::SUCCESS; } -void ChangeStringToBool(std::string &arg_s, bool arg_b) { - if (arg_s == "true") { - arg_b = true; - } else { - arg_b = false; - } - return; -} - domi::Status GenerateInfershapeJson() { if (!CheckInputFormat()) { GELOGE(ge::FAILED, "Check input_format failed"); @@ -740,8 +731,6 @@ domi::Status GenerateInfershapeJson() { Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); - // Load custom operator Library - LoadCustomOpLib(); ge::GeGenerator ge_generator; std::map options; ge::Status geRet = ge_generator.Initialize(options); @@ -783,24 +772,25 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s return ret; } - if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10068", {"param", "value", "supports"}, - {"framework", std::to_string(fwk_type), "only support 0(Caffe) 3(TensorFlow)"}); - GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); + if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); + GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); return ge::FAILED; } - // Since the Caffe model's conversion to JSON file depends on lib_caffe_parser.so, loadcustomoplib is called here. - LoadCustomOpLib(); - if (FLAGS_dump_mode == "0") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_tensorflow_parser.so. + LoadCustomOpLib(false); ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); return ret; } else if (FLAGS_dump_mode == "1") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_tensorflow_parser.so and ops plugin so. 
+ LoadCustomOpLib(true); ret = GenerateInfershapeJson(); return ret; } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); return ge::FAILED; } @@ -831,7 +821,7 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Model load_model = ge::Model("loadmodel", "version2"); auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10056", {"parameter"}, {FLAGS_model}); + ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); DOMI_LOGE( "Load model from %s failed, please check model file or " "input parameter[--framework] is correct", @@ -934,10 +924,11 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); + options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); } domi::Status GenerateSingleOp(const std::string &json_file_path) { - if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output)) { + if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); return domi::FAILED; } @@ -1003,7 +994,7 @@ domi::Status GenerateOmModel() { "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); #if !defined(__ANDROID__) && !defined(ANDROID) // Load custom operator Library - LoadCustomOpLib(); + LoadCustomOpLib(true); SaveCustomCaffeProtoPath(); @@ -1041,8 +1032,6 @@ domi::Status GenerateOmModel() { options.insert(std::pair(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes)); } - options.insert(std::pair(string(ge::HEAD_STREAM), FLAGS_head_stream)); - options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); options.insert( @@ -1060,7 +1049,7 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); - options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), FLAGS_enable_compress_weight + options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), (FLAGS_enable_compress_weight == "true") ? 
ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); @@ -1075,6 +1064,8 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); } + options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); + // print atc option map ge::PrintOptionMap(options, "atc option"); @@ -1098,8 +1089,8 @@ domi::Status ConvertModelToJson() { return domi::SUCCESS; } -bool CheckRet(domi::Status ret, ge::Status geRet) { - if (ret != domi::SUCCESS || geRet != ge::SUCCESS) { +bool CheckRet(domi::Status ret) { + if (ret != domi::SUCCESS) { if (FLAGS_mode == ONLY_PRE_CHECK) { GELOGW("ATC precheck failed."); } else if (FLAGS_mode == GEN_OM_MODEL) { @@ -1148,7 +1139,7 @@ int init(int argc, char *argv[]) { int ret = -1; const std::set log_level = {"default", "null", "debug", "info", "warning", "error"}; if (log_level.count(FLAGS_log) == 0) { - std::cout << "E10016: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" + std::cout << "E10010: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" << std::endl; return ret; } @@ -1158,12 +1149,18 @@ int init(int argc, char *argv[]) { return ret; } + std::string path_base = ge::GELib::GetPath(); + ret = ErrorManager::GetInstance().Init(path_base); + if (ret != 0) { + DOMI_LOGE("ErrorManager init fail !"); + return ret; + } + return 0; } int main(int argc, char *argv[]) { Status ret = domi::SUCCESS; - ge::Status geRet = ge::SUCCESS; std::cout << "ATC start working now, please wait for a moment." << std::endl; try { // Initialize @@ -1188,12 +1185,9 @@ int main(int argc, char *argv[]) { GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; break, "ATC convert pbtxt to json execute failed!!"); } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"value"}, {std::to_string(FLAGS_mode)}); - DOMI_LOGE( - "Invalid value for --mode[%d], only support " - "0(model to framework model), 1(framework model to json), 3(only pre-check), " - "5(pbtxt to json)!", - FLAGS_mode); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--mode", std::to_string(FLAGS_mode), kModeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); ret = domi::FAILED; break; } @@ -1208,8 +1202,12 @@ int main(int argc, char *argv[]) { std::cout << "ATC run failed, some exceptions occur !" << std::endl; } - if (!CheckRet(ret, geRet)) { + if (!CheckRet(ret)) { std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); + if (result != 0) { + DOMI_LOGE("ErrorManager outputErrMessage fail !"); + } return ret; } else { std::cout << "ATC run success, welcome to the next use." 
<< std::endl; diff --git a/src/ge/offline/single_op_parser.cc b/src/ge/offline/single_op_parser.cc index 4d589565..b8947a65 100644 --- a/src/ge/offline/single_op_parser.cc +++ b/src/ge/offline/single_op_parser.cc @@ -200,13 +200,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.input_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input index[%d]'s dataType is invalid", index); + GELOGE(false, "Input's dataType is invalid when the index is %d", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input index[%d]'s format is invalid", index); + GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); return false; } ++index; @@ -216,13 +216,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.output_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output[%d] dataType is invalid", index); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output[%d] format is invalid", index); + GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); return false; } ++index; @@ -316,17 +316,15 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &options_const) { GetExternalEnginePath(extern_engine_path); GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", extern_engine_path.c_str()); + op_tiling_manager_.LoadSo(); + ret = plugin_manager_.LoadSo(extern_engine_path, func_check_list); if (ret == SUCCESS) { initialize_ = options; diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index 8d98ad3f..1d464201 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -24,6 +24,7 @@ #include "common/debug/log.h" #include "common/ge/plugin_manager.h" +#include "common/ge/op_tiling_manager.h" #include "common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/optimizer/graph_optimizer.h" @@ -105,6 +106,7 @@ class OpsKernelManager { Status InitGraphOptimizerPriority(); PluginManager plugin_manager_; + OpTilingManager op_tiling_manager_; // opsKernelInfoStore map ops_kernel_store_{}; // graph_optimizer diff --git a/src/ge/session/inner_session.cc b/src/ge/session/inner_session.cc index 74495e82..b97862e1 100644 --- a/src/ge/session/inner_session.cc +++ b/src/ge/session/inner_session.cc @@ -29,6 +29,34 @@ #include "runtime/mem.h" namespace ge { +namespace { +Status CheckReuseMemoryOption(const std::map &options) { + const int kDecimal = 10; + auto dump_op_env = std::getenv("DUMP_OP"); + int dump_op_flag = (dump_op_env != nullptr) ? 
std::strtol(dump_op_env, nullptr, kDecimal) : 0; + auto iter = options.find(OPTION_EXEC_DISABLE_REUSED_MEMORY); + if (iter != options.end()) { + if (iter->second == "0") { + GELOGD("%s=0, reuse memory is open", OPTION_EXEC_DISABLE_REUSED_MEMORY); + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with ge option %s=0", OPTION_EXEC_DISABLE_REUSED_MEMORY); + } + } else if (iter->second == "1") { + GELOGD("%s=1, reuse memory is close", OPTION_EXEC_DISABLE_REUSED_MEMORY); + } else { + GELOGE(PARAM_INVALID, "option %s=%s is invalid", OPTION_EXEC_DISABLE_REUSED_MEMORY, iter->second.c_str()); + return FAILED; + } + } else { + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with default reuse memory"); + } + } + + return SUCCESS; +} +} // namespace + static std::mutex mutex_; // BuildGraph and RunGraph use InnerSession::InnerSession(uint64_t session_id, const std::map &options) @@ -39,13 +67,36 @@ Status InnerSession::Initialize() { GELOGW("[InnerSession:%lu] session already initialize.", session_id_); return SUCCESS; } + + // If the global options and the session options are duplicated, the session options is preferred. + auto all_options = options_; + all_options.insert(GetMutableGlobalOptions().begin(), GetMutableGlobalOptions().end()); + + Status ret = CheckReuseMemoryOption(all_options); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerSession:%lu] check reuse memory option failed.", session_id_); + return ret; + } + UpdateThreadContext(std::map{}); GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); - Status ret = graph_manager_.Initialize(options_); + PropertiesManager::Instance().GetDumpProperties(session_id_).InitByOptions(); + + ret = graph_manager_.Initialize(options_); if (ret != SUCCESS) { GELOGE(ret, "[InnerSession:%lu] initialize failed.", session_id_); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + return ret; + } + + ret = VarManager::Instance(session_id_)->SetMemoryMallocSize(all_options); + if (ret != SUCCESS) { + GELOGE(ret, "failed to set malloc size"); + (void)graph_manager_.Finalize(); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; } @@ -55,6 +106,7 @@ Status InnerSession::Initialize() { ret = VarManager::Instance(session_id_)->Init(version, session_id_, DEFAULT_DEVICE_ID, DEFAULT_JOB_ID); if (ret != SUCCESS) { GELOGE(ret, "failed to init session instance"); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); } init_flag_ = true; return SUCCESS; @@ -78,6 +130,9 @@ Status InnerSession::Finalize() { // release var memory GELOGI("VarManager free var memory."); (void)VarManager::Instance(session_id_)->FreeVarMemory(); + + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; @@ -223,6 +278,7 @@ void InnerSession::UpdateThreadContext(const std::map GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); GetThreadLocalContext().SetSessionOption(options_); GetThreadLocalContext().SetGraphOption(options); + GetContext().SetSessionId(session_id_); } void InnerSession::UpdateThreadContext(uint32_t graph_id) { diff --git a/src/ge/session/omg.cc b/src/ge/session/omg.cc index 4754f9b9..26103063 100644 --- a/src/ge/session/omg.cc +++ b/src/ge/session/omg.cc @@ -65,6 +65,9 @@ namespace ge { namespace { const std::string kGraphDefaultName = "domi_default"; const std::string kScopeIdAttr = "fusion_scope"; +const char *const kOutputTypeSample = 
"correct sample is \"opname:index:dtype\""; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; } // namespace // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored @@ -78,7 +81,7 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p if ((s == "true") || (s == "false")) { return true; } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s}); GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); return false; } @@ -97,12 +100,12 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph) { std::string node_name = it.first; ge::NodePtr node = graph->FindNode(node_name); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, {"input_shape", node_name}); + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str()); return PARAM_INVALID; } if (node->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, {"input_shape", node_name}); + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str()); return PARAM_INVALID; } @@ -133,18 +136,19 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) { ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "Can not find node [%s] in graph, please check input_fp16_nodes param", + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model", input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); if (op_desc->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "input_fp16_nodes: %s is not a input node name", input_fp16_nodes_vec[i].c_str()); + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname", + input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } if (ge::AttrUtils::SetBool(op_desc, "input_fp16", true)) { @@ -302,14 +306,32 @@ Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format forma return domi::SUCCESS; } +bool CheckDigitStr(std::string &str) { + for (char c : str) { + if (!isdigit(c)) { + GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); + return false; + } + } + return true; +} + Status StringToInt(std::string &str, int32_t &value) { try { + if 
(!CheckDigitStr(str)) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", str, "is not positive integer"}); + return PARAM_INVALID; + } value = stoi(str); } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); return PARAM_INVALID; } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); return PARAM_INVALID; } return SUCCESS; @@ -325,8 +347,9 @@ Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { } for (uint32_t i = 0; i < out_type_vec.size(); ++i) { if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10059", {"value"}, {out_type_vec[i]}); - GELOGE(domi::FAILED, "Can not find this node (%s) in out_nodes.", out_type_vec[i].c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", out_type_vec[i], kOutputTypeError}); + GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); return domi::FAILED; } } @@ -339,9 +362,9 @@ Status ParseOutputType(const std::string &output_type, std::map node_index_type_v = StringUtils::Split(node, ':'); if (node_index_type_v.size() != 3) { // The size must be 3. 
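The CheckDigitStr guard added in front of stoi matters because std::stoi("1.5") succeeds and silently returns 1, so float-looking values would otherwise slip through; only fully non-numeric or out-of-range strings throw. A self-contained sketch of the combined guard, using nothing beyond the standard library:

```cpp
#include <cctype>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

bool CheckDigitStr(const std::string &str) {
  if (str.empty()) {
    return false;
  }
  for (char c : str) {
    if (!isdigit(static_cast<unsigned char>(c))) {
      return false;  // catches float strings like "1.5", which stoi would silently truncate
    }
  }
  return true;
}

bool StringToInt(const std::string &str, int32_t &value) {
  if (!CheckDigitStr(str)) {
    return false;  // first line of defence: pure digit strings only
  }
  try {
    value = std::stoi(str);
  } catch (const std::invalid_argument &) {
    return false;
  } catch (const std::out_of_range &) {
    return false;  // e.g. "99999999999" does not fit in int32_t
  }
  return true;
}

int main() {
  int32_t v = 0;
  std::cout << std::boolalpha << StringToInt("42", v) << " " << v << std::endl;  // true 42
  std::cout << std::boolalpha << StringToInt("1.5", v) << std::endl;            // false
  return 0;
}
```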
- ErrorManager::GetInstance().ATCReportErrMessage("E10058", {"value"}, {node}); - GELOGE(PARAM_INVALID, - "The param of output_type is invalid, the correct format is [opname:index:dtype]," - "while the actual input is %s.", - node.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", node, kOutputTypeSample}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); return domi::FAILED; } ge::DataType tmp_dt; @@ -363,13 +384,15 @@ Status ParseOutputType(const std::string &output_type, std::mapsecond; @@ -396,6 +419,22 @@ Status ParseOutputType(const std::string &output_type, std::mapGetOutputsSize(); + if (index < 0 || index >= out_size) { + GELOGE(domi::FAILED, + "out_node [%s] output index:%d must be smaller " + "than node output size:%d and can not be negative!", + op_desc->GetName().c_str(), index, out_size); + std::string fail_reason = "output index:" + to_string(index) + + " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"out_nodes", op_desc->GetName(), fail_reason}); + return domi::FAILED; + } + return domi::SUCCESS; +} + Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output) { ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); @@ -404,7 +443,6 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::vector output_formats = domi::GetContext().output_formats; std::vector> output_nodes_info; std::vector output_nodes_name; - std::map> out_type_index_map; std::map> out_type_dt_map; if (!output_type.empty()) { @@ -423,6 +461,10 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const } auto op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { + GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } if (i < output_formats.size()) { if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); @@ -445,18 +487,43 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const if (user_out_nodes.empty()) { for (ge::NodePtr node : compute_graph->GetDirectNode()) { if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) { - Status ret = GetOutputLeaf(node, output_nodes_info, output_nodes_name); + Status ret = GetOutputLeaf(node, output_nodes_info); GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); } } } + GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); compute_graph->SetGraphOutNodesInfo(output_nodes_info); domi::GetContext().net_out_nodes = output_nodes_name; return domi::SUCCESS; } -Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info, - std::vector &output_nodes_name) { +void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { + output_nodes_name.clear(); + if (domi::GetContext().out_top_names.empty()) { + // tf process, no top name. 
+ for (const auto output_node_info : output_nodes_info) { + std::string node_name = output_node_info.first->GetName(); + int32_t index = output_node_info.second; + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + return; + } + // caffe process, need add top name after node_name:index + for (size_t i = 0; i < output_nodes_info.size(); ++i) { + std::string node_name = output_nodes_info[i].first->GetName(); + int32_t index = output_nodes_info[i].second; + if (i < domi::GetContext().out_top_names.size()) { + output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]); + } else { + GELOGW("Get top name of node [%s] fail.", node_name.c_str()); + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + } +} + +Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); if (tmpDescPtr == nullptr) { GELOGE(domi::FAILED, "Get outnode op desc fail."); @@ -466,7 +533,6 @@ Status GetOutputLeaf(NodePtr node, std::vector> if (node->GetType() != NETOUTPUT) { for (size_t index = 0; index < size; ++index) { output_nodes_info.push_back(std::make_pair(node, index)); - output_nodes_name.push_back(node->GetName() + ":" + std::to_string(index)); } } else { const auto in_anchors = node->GetAllInDataAnchors(); @@ -478,7 +544,6 @@ Status GetOutputLeaf(NodePtr node, std::vector> } auto out_node = out_anchor->GetOwnerNode(); output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); - output_nodes_name.push_back(out_node->GetName() + ":" + std::to_string(out_anchor->GetIdx())); } } return SUCCESS; @@ -538,8 +603,9 @@ Status ParseOutNodes(const string &out_nodes) { for (const string &node : nodes_v) { vector key_value_v = StringUtils::Split(node, ':'); if (key_value_v.size() != 2) { // The size must be 2. 
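For reference, the overall shape of the --out_nodes parsing is: split the option on ';' into "name:index" entries, verify the index part is a pure digit string before calling stoi, and accumulate the indices per node name the way out_nodes_map is filled. A minimal standalone sketch under those assumptions (the sample input and error handling are illustrative, not taken from this patch):

```cpp
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::string out_nodes = "node_name1:0;node_name1:1;node_name2:0";  // sample input
  std::map<std::string, std::vector<int>> out_nodes_map;

  std::stringstream ss(out_nodes);
  std::string item;
  while (std::getline(ss, item, ';')) {
    const size_t pos = item.find(':');
    if (pos == std::string::npos || pos + 1 >= item.size()) {
      std::cerr << "invalid --out_nodes entry: " << item << std::endl;  // the real code reports E10001 here
      return 1;
    }
    const std::string index_str = item.substr(pos + 1);
    if (index_str.find_first_not_of("0123456789") != std::string::npos) {
      std::cerr << "index is not a positive integer: " << index_str << std::endl;
      return 1;
    }
    out_nodes_map[item.substr(0, pos)].push_back(std::stoi(index_str));
  }

  for (const auto &kv : out_nodes_map) {
    std::cout << kv.first << " has " << kv.second.size() << " requested output(s)" << std::endl;
  }
  return 0;
}
```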
- ErrorManager::GetInstance().ATCReportErrMessage("E10069", {"param", "value", "supports"}, - {"out_nodes", node, "opname:index"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); GELOGE(PARAM_INVALID, "The input format of --out_nodes is invalid, the correct format is " "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", @@ -548,6 +614,12 @@ Status ParseOutNodes(const string &out_nodes) { } auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); // stoi: The method may throw an exception: invalid_argument/out_of_range + if (!CheckDigitStr(key_value_v[1])) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not positive integer"}); + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } int32_t index = stoi(StringUtils::Trim(key_value_v[1])); if (iter != domi::GetContext().out_nodes_map.end()) { iter->second.emplace_back(index); @@ -561,9 +633,11 @@ Status ParseOutNodes(const string &out_nodes) { } } catch (std::invalid_argument &) { GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes}); return PARAM_INVALID; } catch (std::out_of_range &) { GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); return PARAM_INVALID; } @@ -575,7 +649,7 @@ Status ParseOutNodes(const string &out_nodes) { /// @param [in] graph Input network graph /// @return SUCCESS: Input parameters are correct; PARAM_INVALID: Input parameters are incorrect /// -static Status CheckOpNameMap(const ComputeGraphPtr &graph) { +static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { GE_CHECK_NOTNULL(graph); unordered_map graphNodeTypes; for (const NodePtr &node : graph->GetAllNodes()) { @@ -590,7 +664,9 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph) { GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(propertiesMap.empty(), "op_name_map file is empty, please check file!"); for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), - ErrorManager::GetInstance().ATCReportErrMessage("E10060", {"parameter"}, {"op_name_map"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); } return SUCCESS; @@ -647,7 +723,8 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::mapCreateModelParser(framework); GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, framework:%d.", framework); return model_parser->ToJson(model_file, json_file); } - ErrorManager::GetInstance().ATCReportErrMessage("E10045", {"parameter"}, {"model"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow)"}); GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's 
value must be: 0(Caffe) 3(TensorFlow)."); return PARAM_INVALID; } diff --git a/src/ge/session/session_manager.cc b/src/ge/session/session_manager.cc index c3439b0b..68a8aa70 100644 --- a/src/ge/session/session_manager.cc +++ b/src/ge/session/session_manager.cc @@ -51,11 +51,11 @@ Status SessionManager::Finalize() { return SUCCESS; } -Status SessionManager::SetrtContext(rtContext_t rt_context) { +Status SessionManager::SetRtContext(SessionId session_id, rtContext_t rt_context) { GELOGI("set rt_context RT_CTX_NORMAL_MODE, device id:%u.", GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, RT_CTX_NORMAL_MODE, static_cast(GetContext().DeviceId()))); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); return SUCCESS; } @@ -85,7 +85,7 @@ Status SessionManager::CreateSession(const std::map &o session_id = next_session_id; // create a context - ret = SetrtContext(rtContext_t()); + ret = SetRtContext(session_id, rtContext_t()); return ret; } @@ -106,7 +106,7 @@ Status SessionManager::DestroySession(SessionId session_id) { } // Unified destruct rt_context - RtContextUtil::GetInstance().DestroyrtContexts(); + RtContextUtil::GetInstance().DestroyRtContexts(session_id); SessionPtr innerSession = it->second; Status ret = innerSession->Finalize(); @@ -300,4 +300,4 @@ bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) } return innerSession->IsGraphNeedRebuild(graph_id); } -}; // namespace ge +} // namespace ge diff --git a/src/ge/session/session_manager.h b/src/ge/session/session_manager.h index 5cce5214..5cdb849f 100644 --- a/src/ge/session/session_manager.h +++ b/src/ge/session/session_manager.h @@ -33,7 +33,6 @@ class SessionManager { friend class GELib; public: - Status SetrtContext(rtContext_t rtContext); /// /// @ingroup ge_session /// @brief create session @@ -163,10 +162,12 @@ class SessionManager { Status GetNextSessionId(SessionId &next_session_id); + Status SetRtContext(SessionId session_id, rtContext_t rtContext); + std::map session_manager_map_; std::mutex mutex_; bool init_flag_ = false; }; -}; // namespace ge +} // namespace ge #endif // GE_SESSION_SESSION_MANAGER_H_ diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index 9578471a..e2d756df 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -50,9 +50,13 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_inputs; ++i) { // preventing from read out of bound size_t aligned_size = GetAlignedSize(inputs[i].length); + GELOGI("Input [%zu], aligned_size:%zu, inputs.length:%u, input_sizes_:%u", i, aligned_size, inputs[i].length, + input_sizes_[i]); if (aligned_size < input_sizes_[i]) { - GELOGE(PARAM_INVALID, "Input size mismatch. index = %zu, model expect %zu, but given %zu(after align)", i, - input_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, + "Input size mismatch. 
index = %zu, model expect %zu," + " but given %zu(after align)", + i, input_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -66,9 +70,13 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_outputs; ++i) { // preventing from write out of bound size_t aligned_size = GetAlignedSize(outputs[i].length); + GELOGI("Output [%zu], aligned_size:%zu, outputs.length:%u, output_sizes_:%u", i, aligned_size, outputs[i].length, + output_sizes_[i]); if (aligned_size < output_sizes_[i]) { - GELOGE(PARAM_INVALID, "Output size mismatch. index = %zu, model expect %zu, but given %zu(after align)", i, - output_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, + "Output size mismatch. index = %zu, model expect %zu," + " but given %zu(after align)", + i, output_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -81,23 +89,11 @@ Status SingleOp::GetArgs(const std::vector &inputs, const std::vecto if (use_physical_addr_) { for (auto &input : inputs) { auto *addr = reinterpret_cast(input.data); - size_t aligned_size = GetAlignedSize(input.length); - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); - return ret; - } args_[arg_index++] = reinterpret_cast(addr); } for (auto &output : outputs) { auto *addr = reinterpret_cast(output.data); - size_t aligned_size = GetAlignedSize(output.length); - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); - return ret; - } args_[arg_index++] = reinterpret_cast(addr); } } else { @@ -117,6 +113,7 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve if (ret != SUCCESS) { return ret; } + // update tbe task args size_t num_args = arg_table_.size(); for (size_t i = 0; i < num_args; ++i) { std::vector &ptr_to_arg_in_tasks = arg_table_[i]; @@ -129,18 +126,34 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve *arg_addr = args_[i]; } } + // update aicpu_TF or aicpu_CC args for (auto &task : tasks_) { + size_t io_addr_num = args_.size(); if (task->GetOpTaskType() == OP_TASK_AICPU) { - GELOGD("Update aicpu task args"); + GELOGD("Update aicpu_TF task args"); AiCpuTask *task_aicpu = dynamic_cast(task); GE_CHECK_NOTNULL(task_aicpu); - auto *dstIOAddr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); - auto rt_ret = rtMemcpyAsync(dstIOAddr, sizeof(uint64_t) * args_.size(), &args_[0], + auto *dst_io_addr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); + GE_CHECK_NOTNULL(dst_io_addr); + auto rt_ret = rtMemcpyAsync(dst_io_addr, sizeof(uint64_t) * args_.size(), &args_[0], sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); return RT_FAILED; } + } else if (task->GetOpTaskType() == OP_TASK_AICPUCC) { + GELOGD("Update aicpu_CC task args"); + AiCpuCCTask *task_aicpu_cc = dynamic_cast(task); + GE_CHECK_NOTNULL(task_aicpu_cc); + const uintptr_t *task_io_addr = reinterpret_cast(task_aicpu_cc->GetIOAddr()); + GE_CHECK_NOTNULL(task_io_addr); + auto io_addr = reinterpret_cast(const_cast(task_io_addr)); + for (size_t i = 0; i < io_addr_num; ++i) { + io_addr[i] = reinterpret_cast(args_[i]); + } + } else { + GELOGW("Only TF_kernel aicpu and aicpu_CC are supported, but got %u", task->GetOpTaskType());
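The two update paths differ in where the I/O table lives: an aicpu_TF task's table is in device memory and must be refreshed through rtMemcpyAsync, while the new aicpu_CC branch patches a host-resident argument blob in place with a plain pointer loop. A minimal sketch of that host-side patch, with a std::vector standing in for the task's argument buffer and placeholder addresses:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Placeholder argument addresses standing in for args_ after GetArgs().
  std::vector<uintptr_t> args = {0x1000, 0x2000, 0x3000};

  // Host-resident I/O table, standing in for the memory behind AiCpuCCTask::GetIOAddr().
  std::vector<uint64_t> cc_io_table(args.size(), 0);

  // The aicpu_CC path can patch the table directly on the host; no device copy
  // (rtMemcpyAsync) is needed, unlike the aicpu_TF branch.
  uint64_t *io_addr = cc_io_table.data();
  for (size_t i = 0; i < args.size(); ++i) {
    io_addr[i] = static_cast<uint64_t>(args[i]);
  }

  for (uint64_t v : cc_io_table) {
    std::cout << "0x" << std::hex << v << std::endl;
  }
  return 0;
}
```

Skipping the device copy is presumably what makes the aicpu_CC path cheaper per launch; it only works because that task type keeps its argument blob in host-accessible memory.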
+ continue; } } return SUCCESS; @@ -164,6 +177,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c return ret; } } + return ret; } diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index 9decdf75..b72a41fc 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -28,6 +28,7 @@ #include "graph/utils/tensor_utils.h" #include "runtime/rt.h" #include "task/aicpu_task_builder.h" +#include "task/aicpu_kernel_task_builder.h" #include "task/tbe_task_builder.h" using domi::TaskDef; @@ -198,11 +199,6 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { int arg_index = 0; for (size_t i = 0; i < input_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + input_offset_list_[i]; - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, input_sizes_[i], addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Input index = %zu", i); - return ret; - } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.input_sizes_.emplace_back(input_sizes_[i]); single_op.input_addr_list_.emplace_back(addr); @@ -210,11 +206,6 @@ for (size_t i = 0; i < output_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + output_offset_list_[i]; - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, output_sizes_[i], addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Output index = %zu", i); - return ret; - } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.output_sizes_.emplace_back(output_sizes_[i]); single_op.output_addr_list_.emplace_back(addr); @@ -234,16 +225,31 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); if (task_type == RT_MODEL_TASK_KERNEL) { - GELOGD("Building TBE task"); - OpTask *task = nullptr; - auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); - if (ret != SUCCESS) { - return ret; + const domi::KernelDef &kernel_def = task_def.kernel(); + const auto &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == cce::ccKernelType::TE) { + GELOGD("Building TBE task"); + OpTask *task = nullptr; + auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); + if (ret != SUCCESS) { + return ret; + } + single_op.tasks_.emplace_back(task); + } else if (kernel_type == cce::ccKernelType::AI_CPU) { + GELOGD("Building AICPU_CC task"); + OpTask *task = nullptr; + auto ret = BuildCpuKernelTask(task_def.kernel(), &task); + if (ret != SUCCESS) { + return ret; + } + single_op.tasks_.emplace_back(task); + } else { + GELOGE(UNSUPPORTED, "Only TBE kernel and AI_CPU kernel are supported, but got %u", context.kernel_type()); + return UNSUPPORTED; } - - single_op.tasks_.emplace_back(task); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { - GELOGD("Building AICPU task"); + GELOGD("Building AICPU_TF task"); OpTask *task = nullptr; auto ret = BuildKernelExTask(task_def.kernel_ex(), single_op, &task); if (ret != SUCCESS) { @@ -281,12 +287,6 @@ void SingleOpModel::ParseArgTable(TbeOpTask *task, SingleOp &op) { Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task) { GE_CHECK_NOTNULL(task); const auto &context = kernel_def.context(); - auto kernel_type = 
static_cast(context.kernel_type()); - if (kernel_type != cce::ccKernelType::TE) { - GELOGE(UNSUPPORTED, "Only TBE kernel is supported, but got %u", context.kernel_type()); - return UNSUPPORTED; - } - auto iter = op_list_.find(context.op_index()); if (iter == op_list_.end()) { GELOGE(INTERNAL_ERROR, "op desc not found. op index = %u", context.op_index()); @@ -323,13 +323,13 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin std::unique_ptr aicpu_task(new (std::nothrow) AiCpuTask()); if (aicpu_task == nullptr) { - GELOGE(MEMALLOC_FAILED, "create aicpu op task failed"); + GELOGE(MEMALLOC_FAILED, "create aicpu_TF op task failed"); return MEMALLOC_FAILED; } auto builder = AiCpuTaskBuilder(iter->second, kernel_def); auto ret = builder.BuildTask(*aicpu_task, model_params_); if (ret != SUCCESS) { - GELOGE(ret, "build aicpu op task failed"); + GELOGE(ret, "build aicpu_TF op task failed"); return ret; } @@ -337,6 +337,24 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin return SUCCESS; } +Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task) { + std::unique_ptr aicpucc_task(new (std::nothrow) AiCpuCCTask()); + if (aicpucc_task == nullptr) { + GELOGE(MEMALLOC_FAILED, "create aicpu_CC op task failed"); + return MEMALLOC_FAILED; + } + + auto builder = AiCpuCCTaskBuilder(kernel_def); + auto ret = builder.BuildTask(*aicpucc_task); + if (ret != SUCCESS) { + GELOGE(ret, "build aicpu_CC op task failed"); + return ret; + } + + *task = aicpucc_task.release(); + return SUCCESS; +} + Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { auto ret = InitModelMem(resource); if (ret != SUCCESS) { diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index 4d8aae30..3b8c2616 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -64,6 +64,7 @@ class SingleOpModel { Status BuildTaskList(SingleOp &single_op); Status BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, SingleOp &single_op, OpTask **task); + Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(TbeOpTask *task, SingleOp &op); diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.cc b/src/ge/single_op/task/aicpu_kernel_task_builder.cc new file mode 100644 index 00000000..936c7b67 --- /dev/null +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "single_op/task/aicpu_kernel_task_builder.h" + +namespace ge { +AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def) : kernel_def_(kernel_def) {} + +Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { + size_t aicpu_arg_size = kernel_def_.args_size(); + if (aicpu_arg_size <= 0) { + GELOGE(RT_FAILED, "aicpu_arg_size is invalid, value = %zu", aicpu_arg_size); + return RT_FAILED; + } + void *aicpu_args = malloc(aicpu_arg_size); + if (aicpu_args == nullptr) { + GELOGE(RT_FAILED, "malloc failed, size = %zu", aicpu_arg_size); + return RT_FAILED; + } + + task.SetKernelArgs(aicpu_args, aicpu_arg_size); + auto err = memcpy_s(aicpu_args, aicpu_arg_size, kernel_def_.args().data(), aicpu_arg_size); + if (err != EOK) { + GELOGE(RT_FAILED, "memcpy_s args failed, size = %zu, err = %d", aicpu_arg_size, err); + return RT_FAILED; + } + + task.SetIoAddr(static_cast(aicpu_args) + sizeof(aicpu::AicpuParamHead)); + return SUCCESS; +} + +Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task) { + auto ret = SetKernelArgs(task); + if (ret != SUCCESS) { + return ret; + } + const std::string &so_name = kernel_def_.so_name(); + const std::string &kernel_name = kernel_def_.kernel_name(); + task.SetSoName(so_name); + task.SetkernelName(kernel_name); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.h b/src/ge/single_op/task/aicpu_kernel_task_builder.h new file mode 100644 index 00000000..c445132e --- /dev/null +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ +#define GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ + +#include +#include "aicpu/common/aicpu_task_struct.h" +#include "single_op/single_op.h" +#include "single_op/single_op_model.h" +#include "runtime/mem.h" + +namespace ge { +class AiCpuCCTaskBuilder { + public: + explicit AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def); + ~AiCpuCCTaskBuilder() = default; + + Status BuildTask(AiCpuCCTask &task); + + private: + Status SetKernelArgs(AiCpuCCTask &task); + const domi::KernelDef &kernel_def_; +}; +} // namespace ge + +#endif // GE_SINGLE_OP_TASK_AICPUCC_TASK_BUILDER_H_ \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc index 1a4c37ca..bc2c76f6 100644 --- a/src/ge/single_op/task/aicpu_task_builder.cc +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -129,7 +129,8 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam task.task_info_ = kernel_def_.task_info(); task.workspace_addr_ = ws_addr_vec[0]; + auto debug_info = BuildTaskUtils::GetTaskInfo(op_desc_); + GELOGI("[TASK_INFO] %s %s", task.task_info_.c_str(), debug_info.c_str()); return SUCCESS; } - } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.cc b/src/ge/single_op/task/build_task_utils.cc index 82b77031..9e97ee57 100644 --- a/src/ge/single_op/task/build_task_utils.cc +++ b/src/ge/single_op/task/build_task_utils.cc @@ -19,7 +19,9 @@ #include "runtime/rt.h" #include "graph/load/new_model_manager/model_utils.h" #include "graph/manager/graph_var_manager.h" +#include "graph/utils/type_utils.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/types.h" namespace ge { namespace { @@ -27,7 +29,7 @@ const uint64_t kSessionId = UINT64_MAX; uint8_t *kVarBase = nullptr; const uint64_t kLogicVarBase = 0; const uint64_t kVarSize = 0; -} +} // namespace std::vector> BuildTaskUtils::GetAddresses(const OpDescPtr &op_desc, const SingleOpModelParam ¶m) { @@ -58,9 +60,46 @@ std::vector BuildTaskUtils::JoinAddresses(const std::vector BuildTaskUtils::GetKernelArgs(const OpDescPtr &op_desc, - const SingleOpModelParam ¶m) { +std::vector BuildTaskUtils::GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m) { auto addresses = GetAddresses(op_desc, param); return JoinAddresses(addresses); } + +std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { + std::stringstream ss; + if (op_desc != nullptr) { + auto op_type = op_desc->GetType(); + if (op_type == ge::NETOUTPUT || op_type == ge::DATA) { + return ss.str(); + } + // Conv2D IN[DT_FLOAT16 NC1HWC0[256, 128, 7, 7, 16],DT_FLOAT16 FRACTAL_Z[128, 32, 16, 16]] + // OUT[DT_FLOAT16 NC1HWC0[256, 32, 7, 7, 16]] + ss << op_type << " IN["; + for (uint32_t idx = 0; idx < op_desc->GetInputsSize(); idx++) { + const GeTensorDescPtr &input = op_desc->MutableInputDesc(idx); + ss << TypeUtils::DataTypeToSerialString(input->GetDataType()) << " "; + ss << TypeUtils::FormatToSerialString(input->GetFormat()); + ss << VectorToString(input->GetShape().GetDims()); + if (idx < op_desc->GetInputsSize() - 1) { + ss << ","; + } + } + ss << "] OUT["; + + for (uint32_t idx = 0; idx < op_desc->GetOutputsSize(); idx++) { + const GeTensorDescPtr &output = op_desc->MutableOutputDesc(idx); + ss << TypeUtils::DataTypeToSerialString(output->GetDataType()) << " "; + Format out_format = output->GetFormat(); + const GeShape &out_shape = output->GetShape(); + const auto &dims = out_shape.GetDims(); + ss << 
TypeUtils::FormatToSerialString(out_format); + ss << VectorToString(dims); + if (idx < op_desc->GetOutputsSize() - 1) { + ss << ","; + } + } + ss << "]\n"; + } + return ss.str(); +} } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.h b/src/ge/single_op/task/build_task_utils.h index a5030e69..f5885fd2 100644 --- a/src/ge/single_op/task/build_task_utils.h +++ b/src/ge/single_op/task/build_task_utils.h @@ -18,6 +18,7 @@ #define GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ #include +#include #include "graph/op_desc.h" #include "single_op/single_op.h" @@ -31,6 +32,21 @@ class BuildTaskUtils { static std::vector> GetAddresses(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); static std::vector JoinAddresses(const std::vector> &addresses); static std::vector GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); + static std::string GetTaskInfo(const OpDescPtr &op_desc); + template + static std::string VectorToString(const std::vector &values) { + std::stringstream ss; + ss << '['; + auto size = values.size(); + for (size_t i = 0; i < size; ++i) { + ss << values[i]; + if (i != size - 1) { + ss << ", "; + } + } + ss << ']'; + return ss.str(); + } }; } // namespace ge #endif // GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc index e93fad71..19e8b6a4 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -16,10 +16,18 @@ #include "single_op/task/op_task.h" +#include +#include + #include "runtime/rt.h" #include "framework/common/debug/ge_log.h" namespace ge { +namespace { +constexpr int kLaunchRetryTimes = 1000; +constexpr int kSleepTime = 10; +} // namespace + void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { this->stub_name_ = name; this->stub_func_ = stub_func; @@ -53,12 +61,20 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGD("To invoke rtKernelLaunch. task = %s, block_dim = %u", this->stub_name_.c_str(), block_dim_); auto *sm_desc = reinterpret_cast(sm_desc_); auto ret = rtKernelLaunch(stub_func_, block_dim_, args_, static_cast(arg_size_), sm_desc, stream); + int retry_times = 0; + while (ret != RT_ERROR_NONE && retry_times < kLaunchRetryTimes) { + retry_times++; + GELOGW("Retry after %d ms, retry_times: %d", kSleepTime, retry_times); + std::this_thread::sleep_for(std::chrono::milliseconds(kSleepTime)); + ret = rtKernelLaunch(stub_func_, block_dim_, args_, arg_size_, sm_desc, stream); + } + if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->stub_name_.c_str()); return RT_FAILED; } - GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->stub_name_.c_str()); + GELOGI("[TASK_INFO] %s", this->stub_name_.c_str()); return SUCCESS; } @@ -88,8 +104,49 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. 
ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; } + GELOGI("[TASK_INFO] %s", this->task_info_.c_str()); + return SUCCESS; +} + +void AiCpuCCTask::SetKernelArgs(void *args, size_t arg_size) { + args_ = args; + arg_size_ = arg_size; + // the blockdim value is defult "1" for rtCpuKernelLaunch + block_dim_ = 1; +} + +void AiCpuCCTask::SetSoName(const std::string &so_name) { so_name_ = so_name; } + +void AiCpuCCTask::SetkernelName(const std::string &kernel_Name) { kernel_name_ = kernel_Name; } + +void AiCpuCCTask::SetIoAddr(void *io_addr) { io_addr_ = io_addr; } + +const void *AiCpuCCTask::GetIOAddr() const { return io_addr_; } + +const void *AiCpuCCTask::GetArgs() const { return args_; } + +size_t AiCpuCCTask::GetArgSize() const { return arg_size_; } + +AiCpuCCTask::~AiCpuCCTask() { + if (args_ != nullptr) { + free(args_); + args_ = nullptr; + } +} - GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->op_type_.c_str()); +Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { + GELOGI("To invoke rtCpuKernelLaunch. block_dim = %u, so_name is %s, kernel_name is %s", block_dim_, so_name_.data(), + kernel_name_.data()); + // sm_desc is nullptr, because l2 buffer does not support + auto *sm_desc = reinterpret_cast(sm_desc_); + auto ret = + rtCpuKernelLaunch(static_cast(so_name_.data()), static_cast(kernel_name_.data()), + block_dim_, args_, static_cast(arg_size_), sm_desc, stream); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. ret = %d", ret); + return RT_FAILED; + } + GELOGD("Invoke rtCpuKernelLaunch succeeded"); return SUCCESS; } } // namespace ge diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index 168a71b3..fd4cc96f 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -28,6 +28,7 @@ namespace ge { enum OpTaskType { OP_TASK_TBE = 0, OP_TASK_AICPU, + OP_TASK_AICPUCC, OP_TASK_INVALID, }; @@ -79,6 +80,34 @@ class AiCpuTask : public OpTask { std::string op_type_; void *io_addr_ = nullptr; }; + +class AiCpuCCTask : public OpTask { + public: + AiCpuCCTask() = default; + ~AiCpuCCTask() override; + AiCpuCCTask(const AiCpuCCTask &) = delete; + AiCpuCCTask &operator=(const AiCpuCCTask &) = delete; + + Status LaunchKernel(rtStream_t stream) override; + OpTaskType GetOpTaskType() override { return OP_TASK_AICPUCC; } + const void *GetIOAddr() const; + const void *GetArgs() const; + void SetKernelArgs(void *args, size_t arg_size); + void SetSoName(const std::string &so_name); + void SetkernelName(const std::string &kernel_Name); + void SetIoAddr(void *io_addr); + size_t GetArgSize() const; + + private: + friend class AiCpuCCTaskBuilder; + std::string so_name_; + std::string kernel_name_; + void *args_ = nullptr; + size_t arg_size_ = 0; + uint32_t block_dim_ = 1; + void *sm_desc_ = nullptr; + void *io_addr_ = nullptr; +}; } // namespace ge #endif // GE_SINGLE_OP_TASK_OP_TASK_H_ diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index c0f6877f..a422fb96 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -290,6 +290,8 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ if (ret != SUCCESS) { return ret; } + auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_); + GELOGI("[TASK_INFO] %s %s", stub_name_.c_str(), task_info.c_str()); void *stub_func = nullptr; auto rtRet = rtGetFunctionByName(stub_name_.c_str(), &stub_func); diff --git 
diff --git a/tests/depends/cce/src/op_kernel_registry.cc b/tests/depends/cce/src/op_kernel_registry.cc index 9bb32a31..5ccd1391 100644 --- a/tests/depends/cce/src/op_kernel_registry.cc +++ b/tests/depends/cce/src/op_kernel_registry.cc @@ -1,19 +1,3 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - #include "register/op_kernel_registry.h" namespace ge { diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index 74163baf..1d83d7bf 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -102,6 +102,11 @@ typedef enum tagHcclDataType { HCCL_DATA_TYPE_RESERVED /**< reserved */ } hcclDataType_t; +constexpr u32 HCCL_UNIQUE_ID_BYTES = 2060; // 2060: unique id length +using hcclUniqueId = struct hcclUniqueIdDef { + char internal[HCCL_UNIQUE_ID_BYTES]; +}; + const u32 HCCL_MAX_SEGMENT_NUM = 8; // The max number of gradient segments. /** @@ -120,6 +125,12 @@ enum GradSplitForceMode { FORCE_RESERVED /**< reserved */ }; +enum OriginalGraphShapeType { + KNOWN_SHAPE, + UNKNOWN_SHAPE, + SHAPE_RESERVED /**< reserved */ +}; + /** * @brief stream handle. */ diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h index a448d411..19bf4fb3 100644 --- a/third_party/fwkacllib/inc/hccl/hcom.h +++ b/third_party/fwkacllib/inc/hccl/hcom.h @@ -22,7 +22,6 @@ #ifndef HCOM_H_ #define HCOM_H_ -#include #include #ifdef __cplusplus @@ -246,8 +245,9 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT * @param segmentIdx A list identifying the index of end gradient in each segment. * @return hcclResult_t */ -hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, - u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE); +hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, + u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE, + OriginalGraphShapeType shapeType = KNOWN_SHAPE); /** * @brief Set the gradient split strategy within the group, according to gradient index.
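A hedged usage sketch for the extended interface; "feature" (a populated model_feature) and the group name are illustrative assumptions, not defined by this patch:

u32 segment_num = 0;
u32 segment_idx[HCCL_MAX_SEGMENT_NUM] = {0};
// Passing UNKNOWN_SHAPE exercises the new parameter; omitting it keeps the old KNOWN_SHAPE behavior.
hcclResult_t ret = hcom_get_split_strategy("hccl_world_group", &feature, HCCL_MAX_SEGMENT_NUM,
                                           &segment_num, segment_idx, FORCE_NONE, UNKNOWN_SHAPE);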
diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h index ce83d143..6ac8f8f6 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h @@ -344,6 +344,8 @@ extern INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); extern INT32 mmDup2(INT32 oldFd, INT32 newFd); +extern INT32 mmDup(INT32 fd); + extern INT32 mmUnlink(const CHAR *filename); extern INT32 mmChmod(const CHAR *filename, INT32 mode); diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h index ef15f371..68a70c27 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h @@ -378,6 +378,7 @@ _declspec(dllexport) INT32 mmGetRealPath(CHAR *path, CHAR *realPath); _declspec(dllexport) INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); _declspec(dllexport) INT32 mmDup2(INT32 oldFd, INT32 newFd); +_declspec(dllexport) INT32 mmDup(INT32 fd); _declspec(dllexport) INT32 mmUnlink(const CHAR *filename); _declspec(dllexport) INT32 mmChmod(const CHAR *filename, INT32 mode); _declspec(dllexport) INT32 mmFileno(FILE *stream); diff --git a/third_party/fwkacllib/inc/ops/all_ops.h b/third_party/fwkacllib/inc/ops/all_ops.h index 031e955c..c30bf32b 100644 --- a/third_party/fwkacllib/inc/ops/all_ops.h +++ b/third_party/fwkacllib/inc/ops/all_ops.h @@ -31,7 +31,9 @@ #include "functional_ops.h" #include "get_data_ops.h" #include "hcom_ops.h" +#include "hvd_ops.h" #include "image_ops.h" +#include "internal_ops.h" #include "linalg_ops.h" #include "logging_ops.h" #include "lookup_ops.h" diff --git a/third_party/fwkacllib/inc/ops/array_ops.h b/third_party/fwkacllib/inc/ops/array_ops.h index 0d2a05a3..7c6f9b2c 100644 --- a/third_party/fwkacllib/inc/ops/array_ops.h +++ b/third_party/fwkacllib/inc/ops/array_ops.h @@ -1084,6 +1084,43 @@ REG_OP(TransShape) .ATTR(outShape,ListInt ,{}) .OP_END_FACTORY_REG(TransShape); +/** +*@brief Computes the (possibly normalized) Levenshtein Edit Distance. + +*@par Inputs: +*@li hypothesis_indices: The indices of the hypothesis list SparseTensor.\n +This is an N x R int64 matrix. +*@li hypothesis_values: The values of the hypothesis list SparseTensor.\n +This is an N-length vector. +*@li hypothesis_shape: The shape of the hypothesis list SparseTensor.\n +This is an R-length vector. +*@li truth_indices: The indices of the truth list SparseTensor.\n +This is an M x R int64 matrix. +*@li truth_values: The values of the truth list SparseTensor.\n +This is an M-length vector. +*@li truth_shape: The shape of the truth list SparseTensor.\n +This is an R-length vector. + +*@par Attributes: +*@li normalize: boolean (if true, edit distances are normalized by length of truth). + +*@par Outputs: +*@li output: A dense float tensor with rank R - 1. + +*@par Third-party framework compatibility +* Compatible with TensorFlow EditDistance operator.
+*/ +REG_OP(EditDistance) + .INPUT(hypothesis_indices, TensorType({DT_INT64})) + .INPUT(hypothesis_values, TensorType::BasicType()) + .INPUT(hypothesis_shape, TensorType({DT_INT64})) + .INPUT(truth_indices, TensorType({DT_INT64})) + .INPUT(truth_values, TensorType::BasicType()) + .INPUT(truth_shape, TensorType({DT_INT64})) + .ATTR(normalize, Bool, true) + .OUTPUT(output, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(EditDistance) + } // namespace ge #endif // GE_OP_ARRAY_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/ctc_ops.h b/third_party/fwkacllib/inc/ops/ctc_ops.h index 00485a14..74b797f3 100644 --- a/third_party/fwkacllib/inc/ops/ctc_ops.h +++ b/third_party/fwkacllib/inc/ops/ctc_ops.h @@ -50,7 +50,6 @@ If not specified, defaults to true *@par Third-party framework compatibility * Compatible with TensorFlow CTCLoss operator. */ - REG_OP(CTCLoss) .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) .INPUT(labels_indices, TensorType({DT_INT64})) @@ -63,6 +62,77 @@ REG_OP(CTCLoss) .ATTR(ignore_longer_outputs_than_inputs, Bool, false) .OP_END_FACTORY_REG(CTCLoss) +/** +*@brief Performs greedy decoding on the logits given in inputs. + +*@par Inputs: +*@li inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +*@li sequence_length: A vector containing sequence lengths, size `(batch_size)`. + +*@par Attributes: +*@li merge_repeated: If True, merge repeated classes in output. + +*@par Outputs: +*@li decoded_indices: Indices matrix, size `(total_decoded_outputs x 2)`,\n +of a `SparseTensor`. The rows store: [batch, time]. +*@li decoded_values: Values vector, size: `(total_decoded_outputs)`,\n +of a `SparseTensor`. The vector stores the decoded classes. +*@li decoded_shape: Shape vector, size `(2)`, of the decoded SparseTensor.\n +Values are: `[batch_size, max_decoded_length]`. +*@li log_probability: Matrix, size `(batch_size x 1)`, containing sequence\n +log-probabilities. + +*@par Third-party framework compatibility +* Compatible with TensorFlow CTCGreedyDecoder operator. +*/ +REG_OP(CTCGreedyDecoder) + .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) + .INPUT(sequence_length, TensorType({DT_INT32})) + .ATTR(merge_repeated, Bool, false) + .OUTPUT(decoded_indices, TensorType({DT_INT64})) + .OUTPUT(decoded_values, TensorType({DT_INT64})) + .OUTPUT(decoded_shape, TensorType({DT_INT64})) + .OUTPUT(log_probability, TensorType({DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(CTCGreedyDecoder) + +/** +*@brief Performs beam search decoding on the logits given in input. + +*@par Inputs: +*@li inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +*@li sequence_length: A vector containing sequence lengths, size `(batch_size)`. + +*@par Attributes: +*@li merge_repeated: If True, merge repeated classes in output. + +*@par Outputs: +*@li decoded_indices: A list (length: top_paths) of indices matrices. Matrix j,\n +size `(total_decoded_outputs[j] x 2)`, has indices of a\n +`SparseTensor`. The rows store: [batch, time]. +*@li decoded_values: A list (length: top_paths) of values vectors. Vector j,\n +size `(length total_decoded_outputs[j])`, has the values of a\n +`SparseTensor`. The vector stores the decoded classes for beam j. +*@li decoded_shape: A list (length: top_paths) of shape vector. Vector j,\n +size `(2)`, stores the shape of the decoded `SparseTensor[j]`.\n +Its values are: `[batch_size, max_decoded_length[j]]`. +*@li log_probability: A matrix, shaped: `(batch_size x top_paths)`. The\n +sequence log-probabilities. 
+ +*@par Third-party framework compatibility +* Compatible with TensorFlow CTCBeamSearchDecoder operator. +*/ +REG_OP(CTCBeamSearchDecoder) + .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) + .INPUT(sequence_length, TensorType({DT_INT32})) + .REQUIRED_ATTR(beam_width, Int) + .REQUIRED_ATTR(top_paths, Int) + .ATTR(merge_repeated, Bool, true) + .DYNAMIC_OUTPUT(decoded_indices, TensorType({DT_INT64})) + .DYNAMIC_OUTPUT(decoded_values, TensorType({DT_INT64})) + .DYNAMIC_OUTPUT(decoded_shape, TensorType({DT_INT64})) + .OUTPUT(log_probability, TensorType({DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(CTCBeamSearchDecoder) + } // namespace ge #endif //GE_OP_CTC_OPS_H \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index 04e1cea3..378eee38 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -483,9 +483,9 @@ REG_OP(Equal) *x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128. *@par Attributes: -*@li base: An optional attribute of type float32, specifying the base gamma. Defaults to "-1". -*@li scale: An optional attribute of type float32, specifying the scale alpha. Defaults to "1". -*@li shift: An optional attribute of type float32, specifying the shift beta. Defaults to "0". +*@li base: An optional attribute of type float32, specifying the base gamma. Defaults to "-1.0". +*@li scale: An optional attribute of type float32, specifying the scale alpha. Defaults to "1.0". +*@li shift: An optional attribute of type float32, specifying the shift beta. Defaults to "0.0". *@par Outputs: *y: A Tensor of the same type as "x". @@ -1016,17 +1016,17 @@ REG_OP(BesselI1e) * y = log_base(shift + scale * x), with "base" > 0. * @par Inputs: -* @li x: A Tensor of type UnaryDataType. +* @li x: A Tensor of type complex64, complex128, float16, float32 or double. * @par Attributes: -* @li base: An optional float32, specifying the base "e". Defaults to "-1" +* @li base: An optional float32, specifying the base "e". Defaults to "-1.0" * @li scale: An optional float32, specifying the scale of input "x". Defaults -* to "1" -* @li shift: An optional float32, specifying the shift. Defaults to "0" +* to "1.0" +* @li shift: An optional float32, specifying the shift. Defaults to "0.0" * @par Outputs: -* y: A Tensor of type UnaryDataType. +* y: A Tensor has same type as "x". * @attention Constraints: * @li "base" is supposed to be greater than 0. Retaining the default @@ -2262,7 +2262,7 @@ REG_OP(ArgMinD) *dtype: The output type, either "int32" or "int64". Defaults to "int64". *@par Outputs: -*y: A multi-dimensional Tensor of type int32, specifying the index with the largest value. The dimension is one less than that of "x". +*y: A multi-dimensional Tensor of type int32 or int64, specifying the index with the largest value. The dimension is one less than that of "x". *@attention Constraints: *@li x: If there are multiple maximum values, the index of the first maximum value is used. @@ -2398,8 +2398,8 @@ REG_OP(ArgMinWithValue) *y: A Tensor. Has the same type and format as "x". *@par Attributes: -*@li N: A required attribute. the number of input x, max size is 32. -*@li model: An optional attribute. Defaults to "1". +*@li N: A required attribute. the number of input x, max size is 32. Type is int. +*@li model: An optional attribute. Type is int. Defaults to "1". * "0": product, "1": sum, "2": max. 
*@li coeff: A required attribute. Must met all of following rules: * size of "coeff" must be equal to len("x") or is null. @@ -2693,6 +2693,86 @@ REG_OP(AdamApplyOne) .OP_END_FACTORY_REG(AdamApplyOne) /** +*@brief A fusion operator for bert lamb. + +*@par Inputs: +*Eleven inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul4_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. + +*/ +REG_OP(AdamApplyOneWithDecayAssign) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneWithDecayAssign) + +/** +*@brief A fusion operator for bert lamb. + +*@par Inputs: +*Ten inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. 
+ +*/ +REG_OP(AdamApplyOneAssign) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneAssign) + +/** *@brief Confuse select, maximum, greater and sqrt. *@par Inputs: @@ -3042,6 +3122,22 @@ REG_OP(KLDiv) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OP_END_FACTORY_REG(KLDiv) +/** +*@brief Copies data from "x" to "y". + +*@par Inputs: +*One input, including: +* @li x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int32, bool. + +*@par Outputs: +*y: A Tensor. Has the same type as "x". +*/ +REG_OP(TensorMove) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OP_END_FACTORY_REG(TensorMove) } // namespace ge
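A hedged IR-construction sketch for the new op, assuming the set_input accessor that REG_OP conventionally generates; "data" stands for an arbitrary upstream operator:

auto tensor_move = ge::op::TensorMove("tensor_move");
tensor_move.set_input_x(data);  // output "y" inherits the dtype and shape of "x"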
diff --git a/third_party/fwkacllib/inc/ops/hvd_ops.h b/third_party/fwkacllib/inc/ops/hvd_ops.h new file mode 100644 index 00000000..09748b8e --- /dev/null +++ b/third_party/fwkacllib/inc/ops/hvd_ops.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_HVD_OPS_H_ +#define GE_OP_HVD_OPS_H_ + +#include "graph/operator_reg.h" + +namespace ge { +/** + * @brief Outputs a tensor gathering all input tensors. + * @par Inputs: + * x: A tensor. Must be one of the following types: uint8, int8, uint16, int16, int32, + * int64, float16, float32, bool. + * @par Attributes: + * @li rank_size: A required integer identifying the number of ranks + * participating in the op. + * @par Outputs: + * y: A Tensor. Has the same type as "x". + */ +REG_OP(HorovodAllgather) + // GE does not support float64 currently + .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + // add rank_size attr + .REQUIRED_ATTR(rank_size, Int) + .OP_END_FACTORY_REG(HorovodAllgather) + +/** + * @brief Outputs a tensor containing the reduction across all input tensors + * passed to op. + * @par Inputs: + * x: A tensor. Must be one of the following types: int32, int64, float16, float32. + * @par Attributes: + * @li reduce_op: A required int identifying the reduction operation to + * perform. The supported operations are: "sum", "max", "min", "prod". + * @par Outputs: + * y: A Tensor. Has the same type as "x". + */ +REG_OP(HorovodAllreduce) + .INPUT(x, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) + .REQUIRED_ATTR(reduce_op, Int) + .OP_END_FACTORY_REG(HorovodAllreduce) + +/** + * @brief Broadcasts the input tensor in root rank to all ranks. + * @par Inputs: + * x: A list of dynamic input tensors. Must be one of the following types: + * uint8, int8, uint16, int16, int32, int64, float16, float32, bool. + * @par Attributes: + * @li root_rank: A required integer identifying the root rank in the op; + * the input of this rank will be broadcast to other ranks. + * @par Outputs: + * y: A list of dynamic output tensors. Has the same type and length as "x". + */ +REG_OP(HorovodBroadcast) + .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .REQUIRED_ATTR(root_rank, Int) + .OP_END_FACTORY_REG(HorovodBroadcast) + +} // namespace ge +#endif // GE_OP_HVD_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/internal_ops.h b/third_party/fwkacllib/inc/ops/internal_ops.h new file mode 100644 index 00000000..e3caa45f --- /dev/null +++ b/third_party/fwkacllib/inc/ops/internal_ops.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_INTERNAL_OPS_H_ +#define GE_OP_INTERNAL_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { + +/** +*@brief aicpu assist help op for auxiliary matrix generation. + +*@par Inputs: +*The input is dynamic for attribute func_name. \n + +*@par Attributes: +*@li func_name: A required attribute, for example "topkv2". \n + +*@par Outputs: +*The output is dynamic for attribute func_name. +*/ + +REG_OP(AssistHelp) + .DYNAMIC_INPUT(x, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE })) + .DYNAMIC_OUTPUT(y, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + .REQUIRED_ATTR(func_name, String) + .OP_END_FACTORY_REG(AssistHelp) + +} // namespace ge + +#endif // GE_OP_INTERNAL_OPS_H_
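A hedged sketch of wiring the dynamic-IO AssistHelp op, assuming the create/set accessors conventionally generated for DYNAMIC_INPUT and DYNAMIC_OUTPUT; "topk_in" is illustrative:

auto assist = ge::op::AssistHelp("assist_help");
assist.create_dynamic_input_x(1);      // declare one dynamic input slot
assist.set_dynamic_input_x(0, topk_in);
assist.create_dynamic_output_y(1);     // declare one dynamic output slot
assist.set_attr_func_name("topkv2");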
diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index 5d34804c..b0c35c28 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -29,9 +29,9 @@ namespace ge { * x: A Tensor of type float16 or float32. *@par Attributes: -*@li power: Optional. Defaults to 1.0. -*@li scale: Optional. Defaults to 1.0. -*@li shift: Optional. Defaults to 0.0. +*@li power: Optional. Type is float32. Defaults to 1.0. +*@li scale: Optional. Type is float32. Defaults to 1.0. +*@li shift: Optional. Type is float32. Defaults to 0.0. *@par Outputs: * y: A Tensor. Has the same type and shape as "x". diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h index 29cf0df3..7cfd762f 100644 --- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h @@ -699,6 +699,45 @@ REG_OP(FullyConnection) .OP_END_FACTORY_REG(FullyConnection) /** +*@brief Also known as a "fully-connected-compress" layer, computes an inner product with a set of learned weights, and (optionally) adds biases. + +*@par Inputs: +* Five inputs, including: +*@li x: A Tensor of type uint8 or int8. +*@li w: A weight matrix of type int8. +*@li compress_index: A compress index matrix of type int8. +*@li b: An optional Tensor of type int32. +*@li offset_w: An optional Tensor of type int8. + +*@par Attributes: +*@li num_output: Reserved. +*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false". +*@li axis: Reserved. +*@li offset_x: Reserved. + +*@par Outputs: +*y: The result tensor of type int32. + +*@par Third-party framework compatibility +* Compatible with the Caffe operator InnerProduct. + +*@par Quantization supported or not +* Yes +*/ +REG_OP(FullyConnectionCompress) + .INPUT(x, TensorType({DT_UINT8, DT_INT8})) + .INPUT(w, TensorType({DT_INT8})) + .INPUT(compress_index, TensorType({DT_INT8})) + .OPTIONAL_INPUT(b, TensorType({DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_INT32})) + .REQUIRED_ATTR(num_output, Int) + .ATTR(transpose, Bool, false) + .ATTR(axis, Int, 1) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(FullyConnectionCompress) + +/** *@brief Computes the confusion matrix from predictions and labels. *@par Inputs: diff --git a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h index e8eb4769..39aaa993 100644 --- a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h @@ -33,12 +33,12 @@ namespace ge { * @li variance: A Tensor. Must be one of the following types: float32. *@par Attributes: -* @li mode: A Tensor. Must be one of the following types: int. -* @li epsilon: A Tensor. Must be one of the following types: float32. -* @li momentum: A Tensor. Must be one of the following types: float32. -* @li is_training: A Tensor. Must be one of the following types: bool. -* @li is_training_fusion: A Tensor. Must be one of the following types: bool. -* @li moving_average_fraction: A Tensor. Must be one of the following types: float32. +* @li mode: A Tensor. Must be one of the following types: int. Defaults to 1. +* @li epsilon: A Tensor. Must be one of the following types: float32. Defaults to 0.000001. +* @li momentum: A Tensor. Must be one of the following types: float32. Defaults to 0.9. +* @li is_training: A Tensor. Must be one of the following types: bool. Defaults to true. +* @li is_training_fusion: A Tensor. Must be one of the following types: bool. Defaults to true. +* @li moving_average_fraction: A Tensor. Must be one of the following types: float32. Defaults to 0.00300002098.
*@par Outputs: *Three outputs, including: @@ -83,8 +83,8 @@ REG_OP(FusedBatchNorm) * @li save_inv_variance1: A Tensor. Must be one of the following types: float32. *@par Attributes: -* @li epsilon: A Tensor. Must be one of the following types: float32. -* @li momentum: A Tensor. Must be one of the following types: float32. +* @li epsilon: A Tensor. Must be one of the following types: float32. Defaults to 0.0. +* @li momentum: A Tensor. Must be one of the following types: float32. Defaults to 0.0. *@par Outputs: *Three outputs, including: @@ -361,14 +361,14 @@ REG_OP(BatchNormGradExt2) *@par Inputs: *@li x: A 4D or 5D Tensor of type float16 or float32, with format NHWC or NCHW for 4D or NC1HWC0 for 5D. *@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. -*@li variance: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the variance used for inference. -*@li momentum: A Tensor of type float32 or float16, represents the mean and the variance's scale factor +*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. +*@li momentum: A Tensor,represents the mean and the variance's scale factor *@li scale: An optional tensor of type float16 or float32, no use *@li offset: An optional tensor of type float16 or float32, no use *@par Attributes: *@li epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.00001". *@li use_global_stats: mean inference mode , only can be "True". -*@li mode: An optional attr, not use +*@li mode: An optional input, not use *@par Outputs:\n *@li y: A 4D or 5D Tensor of type float16 or float32 for the normalized "x" */ @@ -391,11 +391,11 @@ REG_OP(BNInference) *@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. *@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. -*@li momentum: A Tensor of type float32 or float16, the mean and the variance's Scale factor +*@li momentum: An optional float, mean and variance's Scale factor *@par Attributes: *@li epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.00001". *@li use_global_stats: mean inference mode , only can be "True". -*@li mode: An optional inpout, not use +*@li mode: An optional attr, not use *@par Outputs: *@li alpha: A Tensor of type float16 or float32 for the cpu calculate mean *@li beta: A Tensor of type float16 or float32 for the cpu calculate variance @@ -418,8 +418,8 @@ REG_OP(BnHost) *@par Inputs: *@li x: A 4D or 5D Tensor of type float16 or float32, with format NHWC or NCHW for 4D or NC1HWC0 for 5D. -*@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. -*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. +*@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. +*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. 
*@li scale: An optional tensor of type float16 or float32, no use *@li offset: An optional tensor of type float16 or float32, no use *@par Attributes: diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index 5d4e6bff..5818e14b 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -143,31 +143,29 @@ REG_OP(DepthwiseConv2DBackpropFilterD) * @par Inputs: * Three inputs include: \n * @li input_size: 4D shape of input tensor [N, C, H, W] or [N, H, W, C], -* support int32 -* @li filter: 4D filter tensor with shape of [H, W, C, K], support float16, -* float32, double +* support int32, int64 +* @li filter: 4D filter tensor with shape of [H, W, C, K], support float16. * @li out_backprop: 4D tensor with shape [N, C, H, W] or [N, H, W, C]. -* Must be one of the following types: float16, float32, double. +* Must be one of the following types: float16. * @par Attributes: -* @li strides: A required list or tuple. The stride of the sliding window for +* @li strides: A required list or tuple of int32. The stride of the sliding window for * height and width of input "x" of the convolution. * Must be with shape [1, 1, stride_height, stride_width] or [1, stride_height, * stride_width, 1]. -* @li dilations: An optional list or tuple. The dilation factor for each -* dimension of input "x". +* @li dilations: An optional list or tuple of int32. The dilation factor for each +* dimension of input "x". Defaults to "[1, 1, 1, 1]". * If set to k > 1, there will be k-1 skipped cells between each filter element * on that dimension. Must be with shape [1, 1, dilation_height, dilation_width] * or [1, dilation_height, dilation_width, 1]. -* @li pads: A required list or tuple. Padding added to each dimension of the +* @li pads: A required list or tuple of int32. Padding added to each dimension of the * input. * @li data_format: An optional string. Input data format, either "NHWC" or -* "NCHW". +* "NCHW". Defaults to "NHWC". * @par Outputs: * input_grad: Gradient of the deep convolution relative to the input with shape -* [N, C, H, W] or [N, H, W, C] Must be one of the following types: float16, -* float32, double. +* [N, C, H, W] or [N, H, W, C] Must be one of the following types: float16. * @attention Constraints:\n * The feature map is 4D with shape [N, C, Hi, Wi] or [N, Hi, Wi, C], but @@ -259,8 +257,8 @@ REG_OP(DepthwiseConv2DBackpropInputD) *@par Inputs: *Two required inputs and two optional inputs, including: \n -* @li x: A 4D tensor of type float16, with shape [N, C, H, W] or [N, H, W, C] -* @li filter: A 4D tensor of type float16, with shape [H, W, C, K] +* @li x: A 4D tensor of type float16 or int8, with shape [N, C, H, W] or [N, H, W, C] +* @li filter: A 4D tensor of type float16 or int8, with shape [H, W, C, K] * @li bias: An optional tensor of type float16 or int32 * @li offset_w: An optional float16 or int8, used for quantized inference @@ -273,8 +271,8 @@ REG_OP(DepthwiseConv2DBackpropInputD) * dimension of input "x". * If set to k > 1, there will be k-1 skipped cells between each filter element * on that dimension. Must be with shape [1, 1, dilation_height, dilation_width] -* or [1, dilation_height, dilation_width, 1]. -* @li pads: A required list or tuple. Padding added to each dimension of the +* or [1, dilation_height, dilation_width, 1]. Defaults to "[1, 1, 1, 1]". +* @li pads: A required list or tuple of int32. Padding added to each dimension of the * input. 
* @li data_format: An optional string. Input data format, either "NHWC" or * "NCHW". Defaults to "NHWC". @@ -282,7 +280,7 @@ REG_OP(DepthwiseConv2DBackpropInputD) * Defaults to 0. * @par Outputs: -* y: 4D tensor of type float16, with shape [N, C, H, W] or [N, H, W, C] +* y: 4D tensor of type float16 or int32, with shape [N, C, H, W] or [N, H, W, C] * @attention Constraints:\n * The feature map is 4D with shape [N, C, Hi, Wi] or [N, Hi, Wi, C], but @@ -462,24 +460,24 @@ REG_OP(Conv2DBackpropInputD) * @li x: A Tensor. Must have the same type as "filter". 4D with shape * [batch, out_channels, out_height, out_width]. Gradients with respect * to the output of the convolution. - * @li filter: A Tensor of type float16. + * @li filter: A Tensor of type float16, float32, double or int8. * 4D with shape [out_channels, in_channel, filter_height, filter_width].\n * Two optional inputs: - * @li bias: An optional tensor of type float16 - * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved.\n + * @li bias: An optional tensor of type float16, float32, int32 or int64. + * @li offset_w: An optional 1D tensor for quantized deconvolution. Type is int8. Reserved.\n *@par Attributes: * Six attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window - * for H/W dimension. + * for H/W dimension. Defaults to [1, 1, 1, 1]. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] - * padding on the feature map + * padding on the feature map. Defaults to [0, 0, 0, 0]. * @li dilations: A tuple or list of 4 integers. The dilation factor for each * dimension of input. Must be [1, 1, 1, 1]. * @li groups: Number of blocked connections from input channels to - * output channels. - * @li data_format: An optional string from: "NCHW". Defaults to "NCHW".\n + output channels. Defaults to "1". + * @li data_format: An optional string from: "NCHW". Defaults to "NCHW". \n Specify the data format of the input and output data. - * @li offset_x: An optional integer for quantized deconvolution. + * @li offset_x: An optional integer for quantized deconvolution. Defaults to "0". *@par Outputs: * y: A Tensor. Has the same type as "filter". 4D tensor with shape * [batch, channels, height, width]. 
@@ -577,19 +575,19 @@ REG_OP(Conv2DBackpropFilterD) * * The input and output tensor attributes are listed as follows: * @verbatim - Tensor | x | filter | bias | offset_w | y + |Tensor | x | filter | bias | offset_w | y -----------|---------|---------|---------|----------|-------- - Data Type | float16 | float16 | float16 | _ | float16 - |---------|---------|---------|----------|-------- - | float32 | float32 | float32 | _ | float32 - |---------|---------|---------|----------|-------- - | float64 | float64 | float64 | _ | float64 - |---------|---------|---------|----------|-------- - | int8 | int8 | int32 | int8 | int32 + |Data Type | float16 | float16 | float16 | _ | float16 + | |---------|---------|---------|----------|-------- + | | float32 | float32 | float32 | _ | float32 + | |---------|---------|---------|----------|-------- + | | float64 | float64 | float64 | _ | float64 + | |---------|---------|---------|----------|-------- + | | int8 | int8 | int32 | int8 | int32 -----------|---------|---------|---------|----------|-------- - Format | NCHW | NCHW | ND | ND | NCHW - | NHWC | NHWC | | | NHWC - | | HWCN | | | + |Format | NCHW | NCHW | ND | ND | NCHW + | | NHWC | NHWC | | | NHWC + | | | HWCN | | | @endverbatim * It should be noted that the data types must correspond to each other, but the * format does not need to. @@ -604,10 +602,10 @@ REG_OP(Conv2DBackpropFilterD) * for dilated convolution. Has the same dimension order and value as "strides". * @li groups: Number of blocked connections from input channels to output * channels. Input channels and output channels must both be divisible by -* "groups". Must be set to 1. -* @li offset_x: An optional integer for quantized convolution. +* "groups".Type is int32. Must be set to 1. +* @li offset_x: An optional integer for quantized convolution. Type is int32. Defaults to "0". * @li data_format: An optional string from: "NHWC", "NCHW". Specifying the -* data format of the input and output images. Reserved. +* data format of the input and output images. Type is string. Defaults to "NHWC". Reserved. *@par Outputs: * @li y: A 4D Tensor of output images. 
@@ -615,23 +613,23 @@ REG_OP(Conv2DBackpropFilterD) *@attention * @li The parameter scope is listed as follows: * @verbatim - Name | Field | Scope + |Name | Field | Scope ------------------|--------------|---------- - Input Image Size | H dimension | [1, 4096] - | W dimension | [1, 4096] + |Input Image Size | H dimension | [1, 4096] + | | W dimension | [1, 4096] ------------------|--------------|---------- - Filter Size | H dimension | [1, 255] - | W dimension | [1, 255] + |Filter Size | H dimension | [1, 255] + | | W dimension | [1, 255] ------------------|--------------|---------- - Stride Size | H dimension | [1, 63] - | W dimension | [1, 63] + |Stride Size | H dimension | [1, 63] + | | W dimension | [1, 63] ------------------|--------------|---------- - Padding Size | top side | [0, 255] - | bottom side | [0, 255] - | left side | [0, 255] - | right side | [0, 255] + |Padding Size | top side | [0, 255] + | | bottom side | [0, 255] + | | left side | [0, 255] + | | right side | [0, 255] ------------------|--------------|---------- - Dilation Size | H dimension | [1, 255] + |Dilation Size | H dimension | [1, 255] | W dimension | [1, 255] @endverbatim @@ -712,8 +710,8 @@ REG_OP(Conv3D) .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .ATTR(strides, ListInt, {1, 1, 1, 1, 1}) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3D) @@ -744,7 +742,7 @@ REG_OP(Conv3DBackpropInput) .INPUT(grads, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .REQUIRED_ATTR(strides, ListInt) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) + .REQUIRED_ATTR(pads, ListInt) .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3DBackpropInput) @@ -773,7 +771,7 @@ REG_OP(Conv3DBackpropInputD) .OUTPUT(y, TensorType({DT_FLOAT16})) .REQUIRED_ATTR(input_size, ListInt) .REQUIRED_ATTR(strides, ListInt) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) + .REQUIRED_ATTR(pads, ListInt) .ATTR(data_format, String, "NDHWC") .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) .OP_END_FACTORY_REG(Conv3DBackpropInputD) diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index 5dca8a9d..ceb92f7a 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -187,14 +187,15 @@ REG_OP(ROIAlignGrad) *@li features: A 5HD Tensor of type float32 or float16. *@li rois: ROI position. A 2D Tensor of float32 or float16 with shape (N, 5). "N" indicates the number of ROIs, the value "5" indicates the indexes of images where the ROIs are located, * "x0", "y0", "x1", and "y1". -*@li rois_n: An optional input, specifying the number of valid ROIs. This parameter is reserved. +*@li rois_n: An optional input of type int32, specifying the number of valid ROIs. This parameter is reserved. *@par Attributes: -*@li spatial_scale: A required attribute of type float, specifying the scaling ratio of "features" to the original image. -*@li pooled_height: A required attribute of type int, specifying the H dimension. -*@li pooled_width: A required attribute of type int, specifying the W dimension. 
-*@li sample_num: An optional attribute of type int, specifying the horizontal and vertical sampling frequency of each output. If this attribute is set to "0", +*@li spatial_scale: A required attribute of type float32, specifying the scaling ratio of "features" to the original image. +*@li pooled_height: A required attribute of type int32, specifying the H dimension. +*@li pooled_width: A required attribute of type int32, specifying the W dimension. +*@li sample_num: An optional attribute of type int32, specifying the horizontal and vertical sampling frequency of each output. If this attribute is set to "0", * the sampling frequency is equal to the rounded up value of "rois", which is a floating point number. Defaults to "2". +*@li roi_end_mode: An optional attribute of type int32. Defaults to "1". *@par Outputs: * output: Outputs the feature sample of each ROI position. The format is 5HD Tensor of type float32 or float16. The axis N is the number of input ROIs. Axes H, W, and C are consistent @@ -362,15 +363,15 @@ REG_OP(PSROIPooling) *@li im_info: An ND tensor of type float16 or float32, specifying the Image information. *@li actual_rois_num: An optional NCHW tensor of type int32, specifying the number of valid boxes per batch. *@par Attributes: -*@li batch_rois: An optional int32, specifying the number of images to be predicted. +*@li batch_rois: An optional int32, specifying the number of images to be predicted. Defaults to "1". *@li num_classes: A required int32, specifying the number of classes to be predicted. The value must be greater than 0. *@li score_threshold: A required float32, specifying the threshold for box filtering. The value range is [0.0, 1.0]. *@li iou_threshold: A required float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. The value range is (0.0, 1.0). *@par Outputs: -*@li box: An NCHW tensor of type float16 or float32, describing the information of each output box, including the coordinates, class, and confidence. Proposal of actual output, with output shape [batch, numBoxes,8], 8 means [x1, y1, x2, y2, score, label, batchID, NULL], the maximum value of numBoxes is 1024. +*@li box: A tensor of type float16 or float32 for proposal of actual output, with output shape [batch, numBoxes,8]. +* 8 means [x1, y1, x2, y2, score, label, batchID, NULL], the maximum value of numBoxes is 1024. That is, take min (the maximum number of input boxes, 1024) -*@li actual_bbox_num: An NCHW tensor of type int32 With shape [bacth, num_classes], specifying the number of output boxes. +*@li actual_bbox_num: A tensor of type int32 with shape [batch, num_classes], specifying the number of output boxes. *@attention Constraints:\n *@li totalnum < max_rois_num * batch_rois. @@ -414,9 +415,9 @@ REG_OP(FSRDetectionOutput) *@li confidence_threshold: An optional float32, specifying the topk filter threshold. Only consider detections with confidence greater than the threshold *@li kernel_name: An optional string, specifying the operator name. Defaults to "ssd_detection_output". *@par Outputs: -*@li out_boxnum: An NCHW tensor of type int32, specifying the number of output boxes. -*@li y: An NCHW tensor of type float16 or float32 with shape [batch,keep_top_k, 8], describing the information of each output box, including the coordinates, -* class, and confidence.
In output shape, 8 means (batchID, label(classID), score (class probability), xmin, ymin, xmax, ymax, null) +*@li out_boxnum: A tensor of type int32, specifying the number of output boxes. +*@li y: A tensor of type float16 or float32 with shape [batch,keep_top_k, 8], describing the information of each output box. +* In output shape, 8 means (batchID, label(classID), score (class probability), xmin, ymin, xmax, ymax, null) * It is a custom operator. It has no corresponding operator in Caffe. */ REG_OP(SSDDetectionOutput) @@ -447,10 +448,10 @@ REG_OP(SSDDetectionOutput) *@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. *@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). *@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. -*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". -*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". -*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". -*@li softmaxtree: A bool, Fixed to False, defined in Lite, but not used. +*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3".Defaults to "V3" +*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". Defaults to "false". +*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". Defaults to "false". +*@li softmaxtree: A bool, Fixed to False, defined in Lite, but not used. Defaults to "false". *@par Outputs: *@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. @@ -501,10 +502,10 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, including the coordinates, class, -and confidence. In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, -the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, +* the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. @@ -561,10 +562,10 @@ and the actual image height and width. 
*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -With shape [batch,6,post_nms_topn], 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -With shape [batch,8,1,1], means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, +* the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. @@ -621,11 +622,11 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box, including the coordinates, class, and confidence. -In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. -The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 - +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box. +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. +* The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +* *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. *@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. @@ -688,12 +689,11 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -With shape [batch,6,post_nms_topn], 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -With shape [batch,8,1,1], means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box. 
+* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. +* The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * - *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. *@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. diff --git a/third_party/fwkacllib/inc/ops/nn_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_norm_ops.h index d4db7cf0..d18a4fa4 100644 --- a/third_party/fwkacllib/inc/ops/nn_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_norm_ops.h @@ -291,8 +291,8 @@ REG_OP(BinaryCrossEntropyGrad) * double. Should be a Variable Tensor. *@par Attributes: -*axes: A list of ints. The dimension softmax would be performed on. Defaults -* to "{-1}". +*axes: A list of int. The dimension softmax would be performed on. Defaults +* to "[-1]". *@par Outputs: *y: A Tensor. Has the same dimensionality and shape as the "x" with values in @@ -632,7 +632,7 @@ REG_OP(DropOutDoMask) * Three inputs, including: *@li x: An ND tensor of type float16 or float32. *@li scale: An ND tensor of type float16 or float32. -*@li bias: An ND tensor of type float16 or float32. +*@li bias: An optional ND tensor of type float16 or float32. *@par Attributes: *@li axis: An optional int32 used to compute the shape of scale and bias input from the online bottoms. Defaults to "1". @@ -679,9 +679,9 @@ REG_OP(Scale) * depth_radius = (local_size - 1) / 2. local_size is the number of channels to sum over (for ACROSS_CHANNELS) * or the side length of the square region to sum over (for WITHIN_CHANNEL). *@li bias: An optional float32. An offset, usually > 0 to avoid dividing by 0. -* Defaults to "1". +* Defaults to "1.0". *@li alpha: An optional float32. A scaling factor, usually positive. -* Defaults to "1". +* Defaults to "1.0". *@li beta: An optional float32. An exponent. Defaults to "0.75" for the caffe framework, Defaults to "0.5" for others. *@li norm_region: An optional string. A mode option. "ACROSS_CHANNELS":0, "WITHIN_CHANNEL":1. Defaults to "ACROSS_CHANNELS". @@ -836,6 +836,56 @@ REG_OP(GroupNorm) .ATTR(num_groups, Int, 2) .OP_END_FACTORY_REG(GroupNorm) +/** +*@brief Performs instance normalization. + +*@par Inputs:\n +* Five inputs, including: (NC1HWC0, supported) +*@li x: A 5D Tensor of type float16 or float32, NC1HWC0. +*@li gamma: A Tensor of type float32. +A 5D Tensor for scaling factor, to scale the normalized x. +*@li beta: A Tensor of type float32. +A 5D Tensor for offset, to shift to the normalized x. +*@li mean: A Tensor of type float32. +A 5D Tensor Specifies the mean used for inference. Reserved. +*@li variance: A Tensor of type float32. +A 5D Tensor Specifies the variance used for inference. Reserved. + +*@par Attributes: +*@li is_training: An optional bool, specifying if the operation is used for \n +training or inference. Defaults to "True". +*@li momentum: An optional float32, \n +the value used for the running_mean and running_var computation. Default: "0.1". +*@li epsilon: An optional float32, specifying the small value added to \n +variance to avoid dividing by zero. Defaults to "0.00001". 
+
+*@par Outputs:\n
+* Three outputs, including: (NHWC, NCHW, NC1HWC0 supported)
+*@li y: A 5D tensor of type float16 or float32 for the normalized "x".
+*@li batch_mean: A Tensor of type float32.
+Specifies the mean of "x".
+*@li batch_variance: A Tensor of type float32.
+Specifies the variance of "x".
+
+*@par Third-party framework compatibility
+*@li Compatible with the PyTorch operator InstanceNorm.
+*/
+REG_OP(InstanceNormV2)
+    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT}))
+
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT}))
+    .OUTPUT(batch_mean, TensorType({DT_FLOAT}))
+    .OUTPUT(batch_variance, TensorType({DT_FLOAT}))
+
+    .ATTR(is_training, Bool, true)
+    .ATTR(momentum, Float, 0.1)
+    .ATTR(epsilon, Float, 0.00001)
+    .OP_END_FACTORY_REG(InstanceNormV2)
+
} // namespace ge
#endif //GE_OP_NN_NORM_OPS_H
diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
index 5eb11445..693e51d1 100644
--- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h
@@ -102,6 +102,42 @@ REG_OP(AvgPool)
 .OP_END_FACTORY_REG(AvgPool)
/**
+*@brief Performs average pooling on the input.
+
+*@par Inputs:
+*x: A 5-D Tensor of shape [batch, depth, height, width, channels] and type float16, float32, double.
+
+*@par Attributes:
+*@li ksize: List of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor.
+*@li strides: List of ints that has length 1, 3 or 5. The stride of the sliding window for each dimension of the input tensor.
+*@li pads: List of ints, implicit zero paddings on both sides of the input.
+*@li ceil_mode: When true, will use ceil instead of floor in the formula to compute the output shape.
+*@li count_include_pad: When true, will include the zero-padding in the averaging calculation.
+*@li divisor_override: If specified, it will be used as divisor, otherwise size of the pooling region will be used.
+*@li data_format: A string, format of input data.
+
+*@par Outputs:
+*y: The average pooled output tensor.
+
+*@attention Constraints:
+*@li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63].
+
+*@par Third-party framework compatibility
+* Compatible with the TensorFlow operator AvgPool3D.
+*/
+REG_OP(AvgPool3D)
+    .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE}))
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(ceil_mode, Bool, false)
+    .ATTR(count_include_pad, Bool, true)
+    .ATTR(divisor_override, Int, 0)
+    .ATTR(data_format, String, "NDHWC")
+    .OP_END_FACTORY_REG(AvgPool3D)
+
+/**
 *@brief Performs max_pool_ext2 on the input.
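For orientation, a minimal sketch of how a graph might instantiate the AvgPool3D operator registered above, using the C++ classes that REG_OP generates. The set_input_x/set_attr_* accessor names and the "all_ops.h" umbrella header are assumptions for illustration, not part of this patch:

#include "all_ops.h"  // assumed umbrella header for the generated ge::op classes

// Sketch only: builds one AvgPool3D node; "input" is an NDHWC feed.
ge::op::Data input("input");
auto pool = ge::op::AvgPool3D("avg_pool3d")
                .set_input_x(input)
                .set_attr_ksize({1, 2, 2, 2, 1})    // required window size
                .set_attr_strides({1, 2, 2, 2, 1})  // required window stride
                .set_attr_pads({0, 0, 0, 0, 0, 0})  // head/tail/top/bottom/left/right
                .set_attr_data_format("NDHWC");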
*@par Inputs:
@@ -184,17 +220,62 @@ REG_OP(MaxPool)
 .OP_END_FACTORY_REG(MaxPool)
REG_OP(MaxPool3D)
- .INPUT(x, TensorType({DT_FLOAT16}))
- .OUTPUT(y, TensorType({DT_FLOAT16}))
+ .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE}))
+ .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE}))
 .REQUIRED_ATTR(ksize, ListInt)
 .REQUIRED_ATTR(strides, ListInt)
 .REQUIRED_ATTR(padding, String)
 .ATTR(pads, ListInt, {0,0,0})
- .ATTR(dilation, ListInt, {0,0,0})
+ .ATTR(dilation, ListInt, {1,1,1})
 .ATTR(ceil_mode, Int, 0)
 .ATTR(data_format, String, "NDHWC")
 .OP_END_FACTORY_REG(MaxPool3D)
+
+/**
+* @brief Computes second-order gradients of the maxpooling3d function.
+
+* @par Inputs:
+* @li orig_x: Original forward input tensor (NDC1HWC0) of type float16.
+* @li orig_y: Original forward output tensor (NDC1HWC0) of type float16.
+* @li grads: Gradient tensor (NDC1HWC0) of type float16.
+* @li assist: Assist tensor (NDC1HWC0) of type float16.
+
+* @par Attributes:
+* @li ksize: A required list or tuple,
+* specifying the size of the sliding window.
+* @li strides: A required list or tuple,
+* specifying the stride of the sliding window.
+* @li pads: A required list or tuple.
+* @li padding: A required string, window sliding mode. Either SAME or VALID.
+* @li data_format: An optional string.
+* Format of the original input, either NCDHW or NDHWC. Defaults to NDHWC.
+
+* @attention Constraints:
+* @li Only the Ascend 910 platform is supported.
+* @li "orig_x" and "grads" must have the same shape.
+* @li "orig_y" and "y" must have the same shape. Otherwise, an error is reported.
+* @li "orig_x", "orig_y", "grads", and "y" must be NDC1HWC0 tensors.
+
+* @par Outputs:
+* @li y: Result tensor of type float16.
+
+* @par Third-party framework compatibility
+* @li Compatible with the TensorFlow operator MaxPool3DGradGrad.
+*/
+
+REG_OP(MaxPool3DGradGrad)
+    .INPUT(orig_x, TensorType::RealNumberType())
+    .INPUT(orig_y, TensorType::RealNumberType())
+    .INPUT(grads, TensorType::RealNumberType())
+    .OUTPUT(y, TensorType::RealNumberType())
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(data_format, String, "NDHWC")
+    .OP_END_FACTORY_REG(MaxPool3DGradGrad)
+
+
/**
* @brief Computes gradients of the maxpooling function.
@@ -239,9 +320,10 @@ REG_OP(MaxPoolGrad)
* @brief Computes second-order gradients of the maxpooling function.
* @par Inputs:
-* @li x1: Original forward input tensor of type RealNumberType
-* @li x2: Original forward output tensor of type RealNumberType
-* @li grad: Gradient tensor of type RealNumberType
+* @li x1: Original forward input tensor. Supported type: float, double, int32,
+ * uint8, int16, int8, int64, uint16, half, uint32, uint64.
+* @li x2: Has the same type and format as input "x1".
+* @li grad: Has the same type and format as input "x1".
* @par Attributes:
* @li ksize: A required list or tuple,
* specifying the size of the sliding window.
* @li strides: A required list or tuple,
* specifying the stride of the sliding window.
* @li padding: A required string, window sliding mode. Either SAME or VALID.
@@ -262,7 +344,7 @@ REG_OP(MaxPoolGrad)
* @li Other dimensions of ksize and strides is 1.
* @par Outputs:
-* @li y: Result tensor of type RealNumberType
+* @li y: Has the same type and format as input "x1".
* @par Third-party framework compatibility
* @li Compatible with the TensorFlow operator MaxPoolGradGrad.
@@ -398,18 +480,55 @@ REG_OP(MaxPoolGradWithArgmax)
 .OP_END_FACTORY_REG(MaxPoolGradWithArgmax)
/**
+*@brief Performs the transform from mask to argmax.
+
+*@par Inputs:
+* Two inputs:
+*x: An NC1HWC0 Tensor of type float16.
+*mask: An NC1HWC0 Tensor of type uint16.
+
+*@par Attributes:
+*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for each dimension of the input tensor. No default value.
+*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for each dimension of the input tensor. No default value.
+*@li padding: A required string. No default value.
+
+*@par Outputs:
+*argmax: An NC1HWC0 Tensor of type int32.
+
+*@attention Constraints:
+*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255.
+*@li "strides" is a list that has length 4: strides[0] = 1 or strides[3] = 1, strides[1] >= 1, strides[1] <= 63, strides[2] >= 1, strides[2] <= 63.
+*@li "padding" is either "SAME" or "VALID".
+
+*@par Third-party framework compatibility
+* Compatible with the TensorFlow operator Mask2Argmax.
+*/
+REG_OP(Mask2Argmax)
+    .INPUT(x, TensorType::RealNumberType())
+    .INPUT(mask, TensorType::IndexNumberType())
+    .OUTPUT(argmax, TensorType::IndexNumberType())
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(padding, String)
+    .REQUIRED_ATTR(originshape, ListInt)
+    .OP_END_FACTORY_REG(Mask2Argmax)
+
+/**
* @brief Computes second-order gradients of the maxpooling function.
* @par Inputs:
-* @li x: Original forward input tensor of type RealNumberType
-* @li grad: Gradient tensor of type RealNumberType
-* @li argmax: An tensor of type IndexNumberType
+* @li x: Original forward input tensor. Supported type: float, double, int32,
+ * uint8, int16, int8, int64, uint16, half, uint32, uint64.
+* @li grad: Gradient tensor. Supported type: float, double, int32,
+ * uint8, int16, int8, int64, uint16, half, uint32, uint64.
+* @li argmax: A tensor of type int32 or int64.
* @par Attributes:
* @li ksize: A required list, specifying the size of the sliding window.
* @li strides: A required list, specifying the stride of the sliding window.
* @li padding: A required string, window sliding mode. Either SAME or VALID.
* @par Outputs:
-* @li y:Result tensor of type RealNumberType
+* @li y: Result tensor. Supported type: float, double, int32,
+ * uint8, int16, int8, int64, uint16, half, uint32, uint64.
* @attention Constraints:
* @li Only the cloud platform is supported.
@@ -531,7 +650,7 @@ REG_OP(MaxPoolGradWithArgmaxCCE)
* one input, including:
*@li x: A tensor of type float16 or float32.
*@par Attributes:
-*@li scale: A optional float, scale factor of x. Defaults to "1.0".
+*@li scale: An optional float32, scale factor of x. Defaults to "1.0".
*@li stride_h: An optional int32, broadcast the axis of h. Defaults to "2".
*@li stride_w: An optional int32, broadcast the axis of w. Defaults to "2".
*@par Outputs:
@@ -749,7 +868,186 @@ REG_OP(DataFormatVecPermute)
 .ATTR(dst_format, String, "NCHW")
 .OP_END_FACTORY_REG(DataFormatVecPermute)
+/**
+* @brief Computes gradients of the MaxPool3D function.
+
+* @par Inputs:
+* @li orig_x: A mutable NDC1HWC0 tensor of type float16.
+* @li orig_y: A mutable NDC1HWC0 tensor of type float16.
+* @li grads: A mutable NDC1HWC0 tensor of type float16.
+
+* @par Attributes:
+* @li ksize: A required tuple or list, specifying the size of the window for
+* each dimension of the input tensor.
+* @li strides: A required tuple or list, specifying the stride of the sliding
+* window for each dimension of the input tensor.
+* @li pads: A list of 6 ints. Supports only padding along the D,
+* H and W dimensions in sequence of head, tail, top, bottom, left and right.
+* @li data_format: An optional string, specifying the data format of the input
+* and output data. Defaults to "NDHWC".

+* @par Outputs:
+* y: A mutable tensor. Has the same shape as "orig_x", but type is float32.

+* @par Third-party framework compatibility
+* Compatible with the TensorFlow operator MaxPool3DGrad.
+*/
+REG_OP(MaxPool3DGrad)
+    .INPUT(orig_x, TensorType::RealNumberType())
+    .INPUT(orig_y, TensorType::RealNumberType())
+    .INPUT(grads, TensorType::RealNumberType())
+    .OUTPUT(y, TensorType::RealNumberType())
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(data_format, String, "NDHWC")
+    .OP_END_FACTORY_REG(MaxPool3DGrad)
+
+/**
+*@brief Performs AvgPool1D on the input.
+
+*@par Inputs:
+*x: A Tensor. Must be one of the following types: int8, uint8, int16, int32, int64, float16, float32, float64.
+
+*@par Attributes:
+*@li ksize: A required int, specifying the size of the window.
+*@li strides: A required int.
+*@li pads: A required tuple or list.
+*@li ceil_mode: An optional bool. Defaults to False.
+*@li count_include_pad: An optional bool. Defaults to False.
+
+*@par Outputs:
+*y: A Tensor. Has the same type as x.
+
+*@par Third-party framework compatibility
+*@li Compatible with the PyTorch operator AvgPool1D.
+*/
+REG_OP(AvgPool1D)
+    .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .REQUIRED_ATTR(ksize, Int)
+    .REQUIRED_ATTR(strides, Int)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(ceil_mode, Bool, false)
+    .ATTR(count_include_pad, Bool, false)
+    .OP_END_FACTORY_REG(AvgPool1D)
+
+/**
+*@brief Performs AvgPool1D on the input.
+
+*@par Inputs:
+*x: A Tensor. Must be one of the following types: int8, uint8, int16, int32, int64, float16, float32, float64.
+
+*@par Attributes:
+*@li ksize: A required int, specifying the size of the window.
+*@li strides: A required int.
+*@li pads: A required tuple or list.
+*@li ceil_mode: An optional bool. Defaults to False.
+*@li count_include_pad: An optional bool. Defaults to False.
+
+*@par Outputs:
+*y: A Tensor. Has the same type as x.
+
+*@par Third-party framework compatibility
+*@li Compatible with the PyTorch operator AvgPool1D.
+*/
+REG_OP(AvgPool1DD)
+    .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .INPUT(assist_matrix, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE}))
+    .REQUIRED_ATTR(ksize, Int)
+    .REQUIRED_ATTR(strides, Int)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(ceil_mode, Bool, false)
+    .ATTR(count_include_pad, Bool, false)
+    .OP_END_FACTORY_REG(AvgPool1DD)
+/**
+*@brief Performs max pooling on the input and outputs both max values and indices.
+
+*@par Inputs:
+* One input:
+*x: An NC1HWC0 Tensor of type float16.
+*@par Attributes:
+*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for
+* each dimension of the input tensor. No default value.
+*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for
+* each dimension of the input tensor. No default value.
+*@li pads: A required list of int values. No default value.
+*@li dtype: An optional int. Defaults to "3".
+*@li dilation: An optional list of int8, int16, int32, or int64 values.
+*@li ceil_mode: An optional bool. Defaults to "false".
+
+*@par Outputs:
+*y: A Tensor. Has the same type and format as input "x".
+*argmax: A Tensor. Type: uint16, format: NC1HWC0.
+*@attention Constraints:
+*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255.
+*@li "strides" is a list that has length 4: strides[0] = 1 or strides[3] = 1, strides[1] >= 1, strides[1] <= 63,
+* strides[2] >= 1, strides[2] <= 63.
+*@li "dilation" is a list that has length 4.
+*@li "ceil_mode" is a bool, default is false.
+
+*@par Third-party framework compatibility
+* Compatible with the TensorFlow operator MaxPoolWithArgmax.
+*/
+REG_OP(MaxPoolWithArgmaxV2)
+    .INPUT(x, TensorType({DT_FLOAT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .OUTPUT(argmax, TensorType({DT_UINT16}))
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(dtype, Int, 3)
+    .ATTR(dilation, ListInt, {1, 1, 1, 1})
+    .ATTR(ceil_mode, Bool, false)
+    .OP_END_FACTORY_REG(MaxPoolWithArgmaxV2)
+
+/**
+*@brief Performs the backpropagation of MaxPoolWithArgmaxV2.
+
+*@par Inputs:
+* Three inputs, including:
+*@li x: An NC1HWC0 tensor of type float16.
+*@li grad: An NC1HWC0 tensor of type float16.
+*@li argmax: An NC1HWC0 tensor of type uint16 or int64.
+
+*@par Attributes:
+*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for
+ * each dimension of the input tensor. No default value.
+*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for
+ * each dimension of the input tensor. No default value.
+*@li pads: A required list of int values. No default value.
+*@li dtype: An optional int. Defaults to "3".
+*@li dilation: An optional list of int8, int16, int32, or int64 values.
+*@li ceil_mode: An optional bool. Defaults to "false".
+
+*@par Outputs:
+*y: A Tensor. Has the same type and format as input "x".
+
+*@attention Constraints:
+*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255.
+*@li "strides" is a list that has length 4: strides[0] = 1 or strides[3] = 1.
+*@li "dilation" is a list that has length 4.
+*@li "ceil_mode" is a bool, default is false.
+
+*@see max_pool_grad_with_argmaxv2
+*@par Third-party framework compatibility
+* Compatible with the TensorFlow operator MaxPoolGradWithArgmaxV2.
+*/
+
+REG_OP(MaxPoolGradWithArgmaxV2)
+    .INPUT(x, TensorType({DT_FLOAT16}))
+    .INPUT(grad, TensorType({DT_FLOAT16}))
+    .INPUT(argmax, TensorType({DT_UINT16}))
+    .OUTPUT(y, TensorType({DT_FLOAT16}))
+    .REQUIRED_ATTR(ksize, ListInt)
+    .REQUIRED_ATTR(strides, ListInt)
+    .REQUIRED_ATTR(pads, ListInt)
+    .ATTR(dtype, Int, 3)
+    .ATTR(dilation, ListInt, {1,1,1,1})
+    .ATTR(ceil_mode, Bool, false)
+    .OP_END_FACTORY_REG(MaxPoolGradWithArgmaxV2)
} // namespace ge
#endif // GE_OP_NN_POOLING_OPS_H
diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h
index 1c9aa516..368054f5 100644
--- a/third_party/fwkacllib/inc/ops/nn_training_ops.h
+++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h
@@ -1508,7 +1508,7 @@ REG_OP(ApplyProximalAdagradD)
*@par Attributes:
*use_locking: An optional bool. Defaults to "False".\n
* If "True", updating of the var and accum tensors will be protected by a lock; \n
-* If "False", the behavior is undefined, but may exhibit less contention.
+* If "False", the behavior is undefined, but may exhibit less contention.
*@par Outputs:
*var: A mutable Tensor. Has the same type as "var".
diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
index 1405fdb7..a01073cf 100644
--- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
+++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h
@@ -83,7 +83,7 @@ REG_OP(TanhGrad)
*@par Inputs:
*One input:
-*x: A Tensor. Must be one of the following types: float16, float32, complex64, complex128, int32, int64
+*x: A Tensor. Must be one of the following types: float16, float32, complex64, complex128, double.
*@par Outputs:
*y: A Tensor. Has the same type as "x".
@@ -184,7 +184,7 @@ REG_OP(Relu6Grad)
* @brief Compute sigmoid of "x" element-wise.
* @par Inputs:
-* A Tensor of type UnaryDataType.
+* A Tensor of type complex64, complex128, float16, float32 or double.
* @par Outputs:
* A Tensor. Has the same type as "x".
@@ -220,7 +220,7 @@ REG_OP(SigmoidGrad)
*if x>0, x+log(1+exp(-x)); otherwise log(1+exp(x)).
*@par Inputs:
-*x: A Tensor of type float16 or float32.
+*x: A Tensor of type double, float16 or float32.
*@par Outputs:
*y: A tensor. Has the same type and format as input "x".
@@ -442,7 +442,7 @@ REG_OP(PReluGrad)
*x: A float16, float32 or double, for the input data type.
*@par Attributes:
-*alpha: A float. Defines at which negative value the ELU saturates. Defaults to "1.0".
+*alpha: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0".
*@par Outputs:
*y: A float16, float32 or double, for the normalized result.
diff --git a/third_party/fwkacllib/inc/ops/reduce_ops.h b/third_party/fwkacllib/inc/ops/reduce_ops.h
index 8819d2d5..a8aed058 100644
--- a/third_party/fwkacllib/inc/ops/reduce_ops.h
+++ b/third_party/fwkacllib/inc/ops/reduce_ops.h
@@ -673,7 +673,7 @@ REG_OP(ReduceAnyD)
*@par Attributes:
*@li operation: An optional int32 from 1(SUM), 2(ASUM), 3(SUMSQ), and 4(MEAN),
-*specifying the reduction algorithm. Defaults to 1.
+*specifying the reduction algorithm. Defaults to "1".
*@li axis: An optional int32, specifying the first axis to reduce. Defaults to "0".
*The value range is [-N, N-1], where N is the input tensor rank.
*@li coeff: An optional float32, specifying the scale coefficient. Defaults to "1.0".
@@ -745,7 +745,190 @@ REG_OP(EuclideanNormD)
 .ATTR(keep_dims, Bool, false)
 .OP_END_FACTORY_REG(EuclideanNormD)
-} //namespace ge
+/**
+*@brief Performs instance normalization for inference.
+
+*@par Inputs:\n
+* Five inputs, including: (NC1HWC0 supported)
+*@li x: A Tensor of type float16 or float32.
+*@li gamma: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling gamma.
+*@li beta: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling beta.
+*@li mean: A [N, C1, 1, 1, C0] Tensor of type float32, for the mean.
+*@li variance: A [N, C1, 1, 1, C0] Tensor of type float32, for the variance.
+
+*@par Attributes:
+*epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero.
+Defaults to "0.00001".
+
+*@par Outputs:\n
+*y: A Tensor of type float16 or float32 for the normalized "x".
+*batch_mean: A Tensor of type float32 for the result mean.
+*batch_variance: A Tensor of type float32 for the result variance.
+
+*@attention Constraints:
+*For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
+*/
+REG_OP(INInferV2)
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT}))
+    .ATTR(epsilon, Float, 0.00001)
+    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(batch_mean, TensorType({DT_FLOAT}))
+    .OUTPUT(batch_variance, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(INInferV2)
+
+/**
+*@brief Performs reduced instance normalization.
+
+*@par Inputs:\n
+*x: A Tensor of type float16 or float32, with format NC1HWC0.
+
+*@par Outputs:
+*@li sum: A Tensor of type float32 for SUM reduced "x".
+*@li square_sum: A Tensor of type float32 for SUMSQ reduced "x".
+
+*@attention Constraints:\n
+* This operator is an InstanceNorm fusion operator for updating the moving averages for training. \n
+* This operator is used in conjunction with INTrainingUpdateV2.
+*/
+REG_OP(INTrainingReduceV2)
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(sum, TensorType({DT_FLOAT}))
+    .OUTPUT(square_sum, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(INTrainingReduceV2)
+
+
+/**
+*@brief Performs update instance normalization.
+
+*@par Inputs:\n
+* Seven inputs, including: (NC1HWC0 supported)
+*@li x: A Tensor of type float16 or float32.
+*@li sum: A [N, C1, 1, 1, C0] Tensor of type float32 for the output of operator INTrainingReduceV2.
+*@li square_sum: A [N, C1, 1, 1, C0] Tensor of type float32 for the output of operator INTrainingReduceV2.
+*@li gamma: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling gamma.
+*@li beta: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling beta.
+*@li mean: A [N, C1, 1, 1, C0] Tensor of type float32, for the updated mean.
+*@li variance: A [N, C1, 1, 1, C0] Tensor of type float32, for the updated variance.
+
+*@par Attributes:
+*@li momentum: An optional float32, specifying the momentum to update mean and var. Defaults to "0.1".
+*@li epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.00001".
+
+*@par Outputs:\n
+* Three outputs, including: (NC1HWC0 supported)
+*@li y: A Tensor of type float16 or float32, for normalized "x".
+*@li batch_mean: A Tensor of type float32, for the updated mean.
+*@li batch_variance: A Tensor of type float32, for the updated variance.
+
+*@attention Constraints:
+*@li This operator is an InstanceNorm fusion operator for updating the moving averages for training. \n
+* This operator is used in conjunction with INTrainingReduceV2 (see the sketch below).
+*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
+*/
+REG_OP(INTrainingUpdateV2)
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(sum, TensorType({DT_FLOAT}))
+    .INPUT(square_sum, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT}))
+    .ATTR(momentum, Float, 0.1)
+    .ATTR(epsilon, Float, 0.00001)
+    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(batch_mean, TensorType({DT_FLOAT}))
+    .OUTPUT(batch_variance, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(INTrainingUpdateV2)
+
+
+/**
+*@brief Performs reduced group normalization.
+
+*@par Inputs:\n
+*x: A Tensor of type float16 or float32, with format NCHW or NHWC.
+
+*@par Outputs:
+*@li sum: A Tensor of type float32 for SUM reduced "x".
+*@li square_sum: A Tensor of type float32 for SUMSQ reduced "x".
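To make the reduce/update pairing concrete, a hedged sketch of wiring INTrainingReduceV2 into INTrainingUpdateV2 as the constraint above describes; the generated set_input/set_attr accessor names are assumptions derived from the REG_OP declarations, not confirmed by this patch:

// Sketch only: the reduce stage produces "sum"/"square_sum", which the
// update stage consumes together with the original feature map "x".
ge::op::Data x("x");  // NC1HWC0 feature map
auto reduce = ge::op::INTrainingReduceV2("in_reduce")
                  .set_input_x(x);
auto update = ge::op::INTrainingUpdateV2("in_update")
                  .set_input_x(x)
                  .set_input_sum(reduce, "sum")
                  .set_input_square_sum(reduce, "square_sum")
                  .set_attr_momentum(0.1f)
                  .set_attr_epsilon(0.00001f);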
+
+
+*@par Attributes:
+*@li num_groups: An int32, specifying the number of groups. Must be the same as in GNTrainingUpdate. Defaults to "2".
+
+*@attention Constraints:\n
+* This operator is a GroupNorm fusion operator for updating the moving averages for training. \n
+* This operator is used in conjunction with GNTrainingUpdate.
+*/
+REG_OP(GNTrainingReduce)
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(sum, TensorType({DT_FLOAT}))
+    .OUTPUT(square_sum, TensorType({DT_FLOAT}))
+    .ATTR(num_groups, Int, 2)
+    .OP_END_FACTORY_REG(GNTrainingReduce)
+
+
+/**
+*@brief Performs update group normalization.
+
+*@par Inputs:\n
+* Eight inputs, including: (NCHW, NHWC supported)
+*@li x: A Tensor of type float16 or float32.
+*@li sum: A 5D Tensor of type float32,
+shape is [N, G, D, 1, 1] for NCHW, [N, 1, 1, G, D] for NHWC
+for the output of operator GNTrainingReduce.
+*@li square_sum: A 5D Tensor of type float32,
+shape is [N, G, D, 1, 1] for NCHW, [N, 1, 1, G, D] for NHWC
+for the output of operator GNTrainingReduce.
+*@li scale: A 5D Tensor of type float32,
+shape is [1, G, D, 1, 1] for NCHW, [1, 1, 1, G, D] for NHWC
+for the scaling gamma.
+*@li offset: A 5D Tensor of type float32,
+shape is [1, G, D, 1, 1] for NCHW, [1, 1, 1, G, D] for NHWC
+for the scaling beta.
+*@li mean: A 5D Tensor of type float32,
+shape is [N, G, D, 1, 1] for NCHW, [N, 1, 1, G, D] for NHWC
+for the updated mean.
+*@li variance: A 5D Tensor of type float32,
+shape is [N, G, D, 1, 1] for NCHW, [N, 1, 1, G, D] for NHWC
+for the updated variance.
+
+
+*@par Attributes:
+*@li epsilon: A float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.0001".
+*@li num_groups: An int32, specifying the number of groups. Must be the same as in GNTrainingReduce. Defaults to "2".
+
+*@par Outputs:\n
+* Three outputs, including: (NC1HWC0 supported)
+*@li y: A Tensor of type float16 or float32, for normalized "x".
+*@li batch_mean: A Tensor of type float32, for the updated mean.
+*@li batch_variance: A Tensor of type float32, for the updated variance.
+
+*@attention Constraints:
+*@li This operator is a GroupNorm fusion operator for updating the moving averages for training. \n
+* This operator is used in conjunction with GNTrainingReduce.
+*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction.
+*/
+REG_OP(GNTrainingUpdate)
+    .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .INPUT(sum, TensorType({DT_FLOAT}))
+    .INPUT(square_sum, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(scale, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(offset, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT}))
+    .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT}))
+    .ATTR(num_groups, Int, 2)
+    .ATTR(epsilon, Float, 0.0001)
+    .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT}))
+    .OUTPUT(batch_mean, TensorType({DT_FLOAT}))
+    .OUTPUT(batch_variance, TensorType({DT_FLOAT}))
+    .OP_END_FACTORY_REG(GNTrainingUpdate)
+
+} //namespace ge
+
#endif /* GE_OP_REDUCE_OPS_H */
diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h
index c4d64b0a..b72d9a79 100644
--- a/third_party/fwkacllib/inc/ops/rnn.h
+++ b/third_party/fwkacllib/inc/ops/rnn.h
@@ -67,6 +67,13 @@ REG_OP(BasicLSTMCell)
 .ATTR(activation, String, "tanh")
 .OP_END_FACTORY_REG(BasicLSTMCell)
+REG_OP(DynamicLSTM)
+    .INPUT(x, TensorType({DT_FLOAT32}))
+    .INPUT(w, TensorType({DT_FLOAT32}))
+    .INPUT(b, TensorType({DT_FLOAT32}))
+    .OUTPUT(output_h, TensorType({DT_FLOAT32}))
+    .OP_END_FACTORY_REG(DynamicLSTM)
+
/**
*@brief: Basic LSTM Cell backward calculation.Calculate the gradient of input and hidden state.
*@par Inputs:
@@ -87,7 +94,7 @@ REG_OP(BasicLSTMCellInputGrad)
 .INPUT(dgate, TensorType({DT_FLOAT16}))
 .INPUT(w, TensorType({DT_FLOAT16}))
 .OPTIONAL_INPUT(dropout_mask, TensorType({DT_UINT8}))
- .OUTPUT(dxt, TensorType({DT_FLOAT16}))
+ .OUTPUT(dxt, TensorType({DT_FLOAT16, DT_FLOAT32}))
 .OUTPUT(dht, TensorType({DT_FLOAT16, DT_FLOAT32}))
 .ATTR(keep_prob, Float, 1.0)
 .OP_END_FACTORY_REG(BasicLSTMCellInputGrad)
diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h
index aafcece0..bbe203cd 100644
--- a/third_party/fwkacllib/inc/ops/selection_ops.h
+++ b/third_party/fwkacllib/inc/ops/selection_ops.h
@@ -89,7 +89,8 @@ REG_OP(RangeD)
*@par Inputs:
*Two inputs, including:
-* @li x: A Tensor of type TensorType::BasicType().
+* @li x: A Tensor.
+* Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
* @li multiples: A 1D Tensor of type int32 or int64.
* The length must be the same as the number of dimensions in "input"
@@ -496,7 +497,7 @@ REG_OP(UnsortedSegmentSumD)
*@par Inputs:
* Two inputs, including:\n
*@li x: An ND Tensor (up to 8D). \n
-*Must be one of the following types: int8, uint8, int16, uint16, int32, int64, bool, float32, double
+*Must be one of the following types: int8, uint8, int16, uint16, int32, int64, bool, float16, float32, double, complex64, complex128, string.
*@li axis: A 1D Tensor.\n
*Must be one of the following types: int32, int64
@@ -1559,14 +1560,14 @@ REG_OP(ProposalD)
* If reverse=false: (N, H, W, C)->(N, H/stride, W/stride, C*(stride*stride))
*@par Inputs:
-*x: An (N, H, W, C) tensor. All types except double are supported.
+*x: An (N, H, W, C) tensor. Type is float16, float32, int8, uint8, int16, uint16, int32, uint32, int64 or uint64.
*@par Attributes:
*@li stride: An optional int32, specifying the plane or channel scaling factor. Defaults to "2".
*@li reverse: An optional bool, specifying the conversion mode. If "true", depth to space conversion is performed. If "false", space to depth conversion is performed. Defaults to "false".
*@par Outputs:
-*y: An (N, H, W, C) tensor. All types except double are supported.
+*y: An (N, H, W, C) tensor. Has the same type as "x".
*@attention Constraints:
*@li If reverse=true: C/(stride*stride) yields an integer result. If reverse=false: W/stride and H/stride yield integer results.
@@ -1593,7 +1594,7 @@ REG_OP(PassThrough)
* @li x: A required Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32,int64, uint64.
* @li size: A required Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64.
*@par Attributes:
-*@li axis: A required int32, specifying the first dimension to crop.
+*@li axis: An optional int32, specifying the first dimension to crop. Defaults to "2".
*@li offset: A required array, specifying the shift for all/each dimension to align the cropped bottom with the reference bottom. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64.
*@par Outputs:
*y: A required Tensor. Has the same type and shape as "size".
diff --git a/third_party/fwkacllib/inc/ops/split_combination_ops.h b/third_party/fwkacllib/inc/ops/split_combination_ops.h
index 700d34b7..7e4428d0 100644
--- a/third_party/fwkacllib/inc/ops/split_combination_ops.h
+++ b/third_party/fwkacllib/inc/ops/split_combination_ops.h
@@ -25,11 +25,11 @@ namespace ge {
*@par Inputs:
* Two inputs, including:
*@li x: An ND Tensor.
-*Must be one of the following types: float16, float32, int32, int8, int16, int64, uint8, uint16, uint32, uint64
+*Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
*@li split_dim: Must be the following type:int32. Specifies the dimension along which to split.
*@par Attributes:
-*num_split: A required int8, int16, int32, or int64. Specifies the number of output tensors. No default value.
+*num_split: A required int32. Specifies the number of output tensors. No default value.
*@par Outputs:
*y: Dynamic output.A list of output tensors. Has the same type and format as "x".
@@ -186,6 +186,7 @@ REG_OP(ParallelConcat)
*@par Attributes:
*concat_dim: A required int8, int16, int32, or int64. Specifies the dimension along which to concatenate. No default value.
+*N: An optional int8, int16, int32, or int64. Specifies the number of elements in "x". Defaults to "1".
*@par Outputs:
*y: A Tensor. Has the same type and format as "x".
@@ -267,7 +268,9 @@ REG_OP(ConcatD)
*@par Inputs:
* Two inputs, including:
*@li x: Dynamic input.An NC1HWC0 or ND Tensor.
-*Must be one of the following types: float16, float32, int32, int8, int16, int64, uint8, uint16, uint32, uint64
+*Must be one of the following types: float16, float32, double, int32,
+* uint8, int16, int8, complex64, int64, qint8, quint8, qint32, uint16,
+* complex128, uint32, uint64, qint16, quint16.
*@li concat_dim: An int32, or int64. Specifies the dimension along which to concatenate.
*@par Attributes:
diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h
index 69951da9..7b8a94f8 100644
--- a/third_party/fwkacllib/inc/ops/transformation_ops.h
+++ b/third_party/fwkacllib/inc/ops/transformation_ops.h
@@ -94,6 +94,13 @@ REG_OP(Transpose)
 .OUTPUT(y, TensorType::BasicType())
 .OP_END_FACTORY_REG(Transpose)
+REG_OP(TransData)
+    .INPUT(src, TensorType::BasicType())
+    .OUTPUT(dst, TensorType::BasicType())
+    .REQUIRED_ATTR(src_format, String)
+    .REQUIRED_ATTR(dst_format, String)
+    .OP_END_FACTORY_REG(TransData)
+
/**
*@brief Permutes the dimensions according to order.\n
The returned tensor's dimension i will correspond to the input dimension order[i].
*@par Inputs:
*x: A Tensor. Must be one of the following types: float16, float32.
*@par Attributes:
-*order: A permutation of the dimensions of "x".support any axis transformation
+*order: A permutation of the dimensions of "x". Type is int32. Supports any axis transformation. Defaults to "{0}".
*@par Outputs:
*y: A Tensor. Has the same type as "x".
@@ -291,7 +298,7 @@ REG_OP(DepthToSpace)
*@brief Permutes data into spatial data blocks and then prunes them.
*@par Inputs:
-*@li x: A 4D Tensor with format NC1HWC0.
+*@li x: A 4D Tensor with format NHWC.
*@li crops: A 1D list or tuple of int32 or int64.
*Must be one of the following types: float16, float32
*@par Attributes:
*block_size: A required int8, int16, int32, or int64. No default value.
*@par Outputs:
-*y: A 4D Tensor with format NC1HWC0,
+*y: A 4D Tensor with format NHWC,
* of type float16 or float32.
@@ -365,7 +372,7 @@ REG_OP(BatchToSpaceD)
*@par Inputs:
* Two inputs, including:
-*@li x: An NC1HWC0 Tensor. Must be one of the following types:
+*@li x: An NHWC Tensor. Must be one of the following types:
* float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8,
* int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
*@li paddings: A 2D tensor of type int, specifying the input.
@@ -389,7 +396,7 @@ REG_OP(SpaceToBatch)
*@brief Outputs a copy of the input tensor where values from the "height" and "width" dimensions are padded and rearranged to the "batch" dimension.
*@par Inputs:
-*x: An NC1HWC0 Tensor. Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
+*x: An NHWC Tensor. Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32.
*@par Attributes:
@@ -598,6 +605,13 @@ REG_OP(Compress)
 .OUTPUT(compress_index, TensorType({DT_INT8}))
 .REQUIRED_ATTR(compress_parameters, ListInt)
 .OP_END_FACTORY_REG(Compress)
+
+REG_OP(CompressFcOp)
+    .INPUT(weight, TensorType({DT_INT8}))
+    .OUTPUT(weight_compress, TensorType({DT_INT8}))
+    .OUTPUT(compress_index, TensorType({DT_INT8}))
+    .REQUIRED_ATTR(compress_parameters, ListInt)
+    .OP_END_FACTORY_REG(CompressFcOp)
} // namespace ge
#endif // GE_OP_TRANSFORMATION_OPS_H
diff --git a/third_party/fwkacllib/inc/register/op_registry.h b/third_party/fwkacllib/inc/register/op_registry.h
index 1fcdf9de..1dc14b8b 100644
--- a/third_party/fwkacllib/inc/register/op_registry.h
+++ b/third_party/fwkacllib/inc/register/op_registry.h
@@ -35,6 +35,7 @@ enum RemoveInputType {
 OMG_MOVE_TYPE_SCALAR_VALUE,
 OMG_REMOVE_TYPE_WITH_COND = 1000,
 OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE,
+ OMG_INPUT_REORDER,
};
struct RemoveInputConfigure {
@@ -43,6 +44,7 @@ struct RemoveInputConfigure {
 RemoveInputType moveType;
 bool attrValue = false;
 std::string originalType;
+ std::vector<int> input_order;
};
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry {
@@ -57,11 +59,11 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry {
 void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType &imply_type);
- domi::ParseParamFunc GetParseParamFunc(const std::string &op_type);
+ domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type);
- domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &op_type);
+ domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &ori_type);
- domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type);
+ domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type);
 domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type);
@@ -72,14 +74,13 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry {
 bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type);
private:
- std::unordered_map> op_ori_optype_map_;
 std::unordered_map op_run_mode_map_;
- std::unordered_map opParseParamsFnMap_;
+ std::unordered_map<std::string, domi::ParseParamFunc> op_parse_params_fn_map_;
 std::unordered_map parse_params_by_op_func_map_;
- std::unordered_map fusionOpParseParamsFnMap_;
+ std::unordered_map<std::string, domi::FusionParseParamFunc> fusion_op_parse_params_fn_map_;
 std::unordered_map op_types_to_parse_subgraph_post_func_;
 std::unordered_map> remove_input_configure_map_;
- std::unordered_map originOpType2OmOpType_;
+ std::unordered_map<std::string, std::string> origin_type_to_om_type_;
};
} // namespace domi
#endif // INC_REGISTER_OP_REGISTRY_H_
diff --git a/third_party/fwkacllib/inc/register/op_tiling.h b/third_party/fwkacllib/inc/register/op_tiling.h
new file mode 100644
index 00000000..92067a20
--- /dev/null
+++ b/third_party/fwkacllib/inc/register/op_tiling.h
@@ -0,0 +1,130 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef INC_OP_TILING_H_
+#define INC_OP_TILING_H_
+
+#include "external/register/register_types.h"
+#include "graph/debug/ge_attr_define.h"
+#include "graph/node.h"
+
+#include <functional>
+#include <map>
+#include <sstream>
+#include <string>
+#include <tuple>
+#include <vector>
+#include "graph/node.h"
+
+#define REGISTER_OP_TILING_FUNC(optype, opfunc) \
+  REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, __COUNTER__)
+
+#define REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, counter) \
+  REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter)
+
+#define REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) \
+  static OpTilingInterf g_##optype##TilingInterf##counter(#optype, opfunc)
+
+namespace optiling {
+
+enum TensorArgType {
+  TA_NONE,
+  TA_SINGLE,
+  TA_LIST,
+};
+
+
+using ByteBuffer = std::stringstream;
+
+struct TeOpTensor {
+  std::vector<int64_t> shape;
+  std::vector<int64_t> ori_shape;
+  std::string format;
+  std::string ori_format;
+  std::string dtype;
+  std::map<std::string, std::string> attrs;
+};
+
+
+struct TeOpTensorArg {
+  TensorArgType arg_type;
+  std::vector<TeOpTensor> tensor;
+};
+
+struct OpRunInfo {
+  uint32_t block_dim;
+  std::vector<int64_t> workspaces;
+  ByteBuffer tiling_data;
+};
+
+
+using TeOpAttrArgs = std::vector<std::string>;
+using TeConstTensorData = std::tuple<const uint8_t*, size_t, ge::Tensor>;
+
+struct TeOpParas {
+  std::vector<TeOpTensorArg> inputs;
+  std::vector<TeOpTensorArg> outputs;
+  std::map<std::string, TeConstTensorData> const_inputs;
+  TeOpAttrArgs attrs;
+};
+
+using OpTilingFunc = std::function<bool(const std::string&, const TeOpParas&, const std::string&, OpRunInfo&)>;
+
+using OpTilingFuncPtr = bool(*)(const std::string&, const TeOpParas&, const std::string&, OpRunInfo&);
+
+class FMK_FUNC_HOST_VISIBILITY OpTilingInterf
+{
+public:
+  OpTilingInterf(std::string op_type, OpTilingFunc func);
+  ~OpTilingInterf() = default;
+  static std::map<std::string, OpTilingFunc> &RegisteredOpInterf();
+};
+
+
+template <typename T>
+ByteBuffer& ByteBufferPut(ByteBuffer &buf, const T &value)
+{
+  buf.write(reinterpret_cast<const char *>(&value), sizeof(value));
+  buf.flush();
+  return buf;
+}
+
+template <typename T>
+ByteBuffer& ByteBufferGet(ByteBuffer &buf, T &value)
+{
+  buf.read(reinterpret_cast<char *>(&value), sizeof(value));
+  return buf;
+}
+
+inline size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len)
+{
+  size_t nread = 0;
+  size_t rn = 0;
+  do {
+    rn = buf.readsome(dest + nread, dest_len - nread);
+    nread += rn;
+  } while (rn > 0 && dest_len > nread);
+
+  return nread;
+}
+
+
+extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info);
+
+}
+
+#endif // INC_OP_TILING_H_
diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h
index 49c9de6a..7539a549 100644
--- a/third_party/fwkacllib/inc/runtime/base.h
+++ b/third_party/fwkacllib/inc/runtime/base.h
@@ -68,6 +68,8 @@ typedef enum tagRtError {
 RT_ERROR_NO_STREAM_CB_REG = 0x96, // no callback register info for stream
 RT_ERROR_DATA_DUMP_LOAD_FAILED = 0x97, // data dump load info fail
 RT_ERROR_CALLBACK_THREAD_UNSUBSTRIBE = 0x98, // callback thread unsubstribe
+ RT_ERROR_DEBUG_REGISTER_FAILED = 0x99, // debug register fail
+ RT_ERROR_DEBUG_UNREGISTER_FAILED = 0x9A, // debug unregister fail
 RT_ERROR_RESERVED
} rtError_t;
@@ -186,14 +188,6 @@ RTS_API rtError_t rtPeekAtLastError();
/**
 * @ingroup dvrt_base
- * @brief set polling receive mode for task report
- * @param [out] NA
- * @return RT_ERROR_NONE for ok
- */
-RTS_API rtError_t rtSetPollingMode();
-
-/**
- * @ingroup dvrt_base
 * @brief register callback for error code
 * @param [out] NA
 * @return RT_ERROR_NONE for ok
diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h
index 2e48cc57..3dad53c5 100644
--- a/third_party/fwkacllib/inc/runtime/config.h
+++ b/third_party/fwkacllib/inc/runtime/config.h
@@ -41,8 +41,7 @@ typedef enum tagRtChipType {
 CHIP_CLOUD,
 CHIP_MDC,
 CHIP_LHISI,
- CHIP_OTHER_PHN,
- CHIP_OTHER_OLD,
+ CHIP_DC,
 CHIP_END,
} rtChipType_t;
diff --git a/third_party/fwkacllib/inc/runtime/context.h b/third_party/fwkacllib/inc/runtime/context.h
index 54621e86..b059268e 100644
--- a/third_party/fwkacllib/inc/runtime/context.h
+++ b/third_party/fwkacllib/inc/runtime/context.h
@@ -106,16 +106,6 @@ RTS_API rtError_t rtCtxGetCurrent(rtContext_t *ctx);
 */
RTS_API rtError_t rtCtxGetDevice(int32_t *device);
-/**
- * @ingroup rt_context
- * @brief set ctx run mode: normal or dryrun
- * @param [in] ctx: context
- * @param [in] enable: set true means enable dryrun mode
- * @param [in] flag: reserved
- * @return RT_ERROR_NONE for ok
- */
-RTS_API rtError_t rtCtxSetDryRun(rtContext_t ctx, rtDryRunFlag_t enable, uint32_t flag);
-
#ifdef __cplusplus
}
#endif
diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h
index 928f2822..60928202 100644
--- a/third_party/fwkacllib/inc/runtime/dev.h
+++ b/third_party/fwkacllib/inc/runtime/dev.h
@@ -32,6 +32,7 @@ typedef struct tagRTDeviceInfo {
 uint32_t ts_cpu_core_num;
 uint32_t ai_cpu_core_num;
 uint32_t ai_core_num;
+ uint32_t ai_core_freq;
 uint32_t ai_cpu_core_id;
 uint32_t ai_core_id;
 uint32_t aicpu_occupy_bitmap;
@@ -46,6 +47,13 @@ typedef enum tagRtRunMode {
 RT_RUN_MODE_RESERVED
} rtRunMode;
+typedef enum tagRtAicpuDeployType {
+  AICPU_DEPLOY_CROSS_OS = 0x0,
+  AICPU_DEPLOY_CROSS_PROCESS = 0x1,
+  AICPU_DEPLOY_CROSS_THREAD = 0x2,
+  AICPU_DEPLOY_RESERVED
+} rtAicpuDeployType_t;
+
/**
 * @ingroup dvrt_dev
 * @brief get total device number.
@@ -62,15 +70,40 @@ RTS_API rtError_t rtGetDeviceCount(int32_t *count);
 * @return RT_ERROR_DRV_ERR for error
 */
RTS_API rtError_t rtGetDeviceIDs(uint32_t *devices, uint32_t len);
+
/**
 * @ingroup dvrt_dev
- * @brief get total device infomation.
+ * @brief get device information.
 * @param [in] device the device id
- * @param [out] info the device info
+ * @param [in] moduleType module type
+   typedef enum {
+       MODULE_TYPE_SYSTEM = 0,  system info
+       MODULE_TYPE_AICPU,       aicpu info
+       MODULE_TYPE_CCPU,        ccpu_info
+       MODULE_TYPE_DCPU,        dcpu info
+       MODULE_TYPE_AICORE,      AI CORE info
+       MODULE_TYPE_TSCPU,       tscpu info
+       MODULE_TYPE_PCIE,        PCIE info
+   } DEV_MODULE_TYPE;
+ * @param [in] infoType info type
+   typedef enum {
+       INFO_TYPE_ENV = 0,
+       INFO_TYPE_VERSION,
+       INFO_TYPE_MASTERID,
+       INFO_TYPE_CORE_NUM,
+       INFO_TYPE_OS_SCHED,
+       INFO_TYPE_IN_USED,
+       INFO_TYPE_ERROR_MAP,
+       INFO_TYPE_OCCUPY,
+       INFO_TYPE_ID,
+       INFO_TYPE_IP,
+       INFO_TYPE_ENDIAN,
+   } DEV_INFO_TYPE;
+ * @param [out] value the device info
 * @return RT_ERROR_NONE for ok
 * @return RT_ERROR_NO_DEVICE for can not find any device
 */
-RTS_API rtError_t rtGetDeviceInfo(int32_t device, rtDeviceInfo_t *info);
+RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value);
/**
 * @ingroup dvrt_dev
@@ -132,6 +165,25 @@ RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc);
/**
 * @ingroup dvrt_dev
+ * @brief get P2P status between two devices
+ * @param [in] devIdDes the logical device id
+ * @param [in] phyIdSrc the physical device id
+ * @param [in|out] status status value
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_NO_DEVICE for can not find any device
+ */
+RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status);
+
+/**
+ * @ingroup dvrt_dev
+ * @brief get the process id (tgid) of the current thread
+ * @param [in|out] pid the process id
+ * @return RT_ERROR_NONE for ok
+ */
+RTS_API rtError_t rtDeviceGetBareTgid(uint32_t *pid);
+
+/**
+ * @ingroup dvrt_dev
 * @brief get target device of current thread
 * @param [in|out] device the device id
 * @return RT_ERROR_NONE for ok
@@ -214,6 +266,15 @@ RTS_API rtError_t rtGetRunMode(rtRunMode *mode);
/**
 * @ingroup dvrt_dev
+ * @brief get the aicpu deploy type
+ * @param [out] deployType the aicpu deploy type
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_DRV_ERR for can not get aicpu deploy
+ */
+RTS_API rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deployType);
+
+/**
+ * @ingroup dvrt_dev
 * @brief set chipType
 * @return RT_ERROR_NONE for ok
 */
@@ -225,6 +286,17 @@ RTS_API rtError_t rtSetSocVersion(const char *version);
 * @return RT_ERROR_NONE for ok
 */
rtError_t rtGetSocVersion(char *version, const uint32_t maxLen);
+
+/**
+ * @ingroup dvrt_dev
+ * @brief get info of a device pair
+ * @param [in] devId the logical device id
+ * @param [in] otherDevId the other logical device id
+ * @param [in] infoType info type
+ * @param [in|out] value pair info
+ * @return RT_ERROR_NONE for ok
+ */
+RTS_API rtError_t rtGetPairDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *value);
#ifdef __cplusplus
}
#endif
diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h
index 790492fc..5c85a3d7 100644
--- a/third_party/fwkacllib/inc/runtime/rt_model.h
+++ b/third_party/fwkacllib/inc/runtime/rt_model.h
@@ -65,6 +65,13 @@ typedef enum tagModelQueueFlag {
#define EXECUTOR_TS ((uint32_t)0x01)
#define EXECUTOR_AICPU ((uint32_t)0x02)
+/*
+ * @ingroup rt_model
+ * @brief debug flag for kernel exception dump
+ */
+#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1 << 0)
+#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1 << 1)
+
/**
 * @ingroup
 * @brief the type defination of aicpu model task command
@@ -403,6 +410,26 @@ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQue
 */
RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId);
diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h
index 790492fc..5c85a3d7 100644
--- a/third_party/fwkacllib/inc/runtime/rt_model.h
+++ b/third_party/fwkacllib/inc/runtime/rt_model.h
@@ -65,6 +65,13 @@ typedef enum tagModelQueueFlag {
 #define EXECUTOR_TS ((uint32_t)0x01)
 #define EXECUTOR_AICPU ((uint32_t)0x02)

+/*
+ * @ingroup rt_model
+ * @brief debug flag for kernel exception dump
+ */
+#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1 << 0)
+#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1 << 1)
+
 /**
  * @ingroup
  * @brief the type defination of aicpu model task command
@@ -403,6 +410,26 @@ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQueueFlag_t flag);
  */
 RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId);

+/*
+ * @ingroup rt_model
+ * @brief enable debug for dump overflow exception
+ * @param [in] addr: ddr address where the kernel exception is dumped
+ * @param [in] model: model handle
+ * @param [in] flag: debug flag
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for error input handle
+ */
+rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId);
+
+/*
+ * @ingroup rt_model
+ * @brief disable debug for dump overflow exception
+ * @param [in] model: model handle
+ * @return RT_ERROR_NONE for ok
+ * @return RT_ERROR_INVALID_VALUE for error input handle
+ */
+RTS_API rtError_t rtDebugUnRegister(rtModel_t model);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h
index f77df225..261fe866 100644
--- a/third_party/fwkacllib/inc/toolchain/slog.h
+++ b/third_party/fwkacllib/inc/toolchain/slog.h
@@ -91,6 +91,10 @@ extern "C" {
  * max log length
  */
 #define MSG_LENGTH 1024
+#define DEBUG_LOG_MASK (0x00010000)
+#define SECURITY_LOG_MASK (0x00100000)
+#define RUN_LOG_MASK (0x01000000)
+#define OPERATION_LOG_MASK (0x10000000)

 typedef struct tagDCODE {
     const char *cName;
@@ -169,83 +173,11 @@ enum {
     PROCMGR,  // Process Manager, Base Platform
     BBOX,
     AIVECTOR,
+    TBE,
+    FV,
     INVLID_MOUDLE_ID
 };

-#ifdef MODULE_ID_NAME
-
-/**
- * @ingroup slog
- *
- * set module id to map
- */
-#define SET_MOUDLE_ID_MAP_NAME(x) \
-  { #x, x }
-
-static DCODE g_moduleIdName[] = {SET_MOUDLE_ID_MAP_NAME(SLOG),
-                                 SET_MOUDLE_ID_MAP_NAME(IDEDD),
-                                 SET_MOUDLE_ID_MAP_NAME(IDEDH),
-                                 SET_MOUDLE_ID_MAP_NAME(HCCL),
-                                 SET_MOUDLE_ID_MAP_NAME(FMK),
-                                 SET_MOUDLE_ID_MAP_NAME(HIAIENGINE),
-                                 SET_MOUDLE_ID_MAP_NAME(DVPP),
-                                 SET_MOUDLE_ID_MAP_NAME(RUNTIME),
-                                 SET_MOUDLE_ID_MAP_NAME(CCE),
-#if (OS_TYPE == LINUX)
-                                 SET_MOUDLE_ID_MAP_NAME(HDC),
-#else
-                                 SET_MOUDLE_ID_MAP_NAME(HDCL),
-#endif // OS_TYPE
-                                 SET_MOUDLE_ID_MAP_NAME(DRV),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCFUSION),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCLOCATION),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCPERCEPTION),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCFSM),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCCOMMON),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCMONITOR),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCBSWP),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCDEFAULT),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCSC),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCPNC),
-                                 SET_MOUDLE_ID_MAP_NAME(MLL),
-                                 SET_MOUDLE_ID_MAP_NAME(DEVMM),
-                                 SET_MOUDLE_ID_MAP_NAME(KERNEL),
-                                 SET_MOUDLE_ID_MAP_NAME(LIBMEDIA),
-                                 SET_MOUDLE_ID_MAP_NAME(CCECPU),
-                                 SET_MOUDLE_ID_MAP_NAME(ASCENDDK),
-                                 SET_MOUDLE_ID_MAP_NAME(ROS),
-                                 SET_MOUDLE_ID_MAP_NAME(HCCP),
-                                 SET_MOUDLE_ID_MAP_NAME(ROCE),
-                                 SET_MOUDLE_ID_MAP_NAME(TEFUSION),
-                                 SET_MOUDLE_ID_MAP_NAME(PROFILING),
-                                 SET_MOUDLE_ID_MAP_NAME(DP),
-                                 SET_MOUDLE_ID_MAP_NAME(APP),
-                                 SET_MOUDLE_ID_MAP_NAME(TS),
-                                 SET_MOUDLE_ID_MAP_NAME(TSDUMP),
-                                 SET_MOUDLE_ID_MAP_NAME(AICPU),
-                                 SET_MOUDLE_ID_MAP_NAME(LP),
-                                 SET_MOUDLE_ID_MAP_NAME(TDT),
-                                 SET_MOUDLE_ID_MAP_NAME(FE),
-                                 SET_MOUDLE_ID_MAP_NAME(MD),
-                                 SET_MOUDLE_ID_MAP_NAME(MB),
-                                 SET_MOUDLE_ID_MAP_NAME(ME),
-                                 SET_MOUDLE_ID_MAP_NAME(IMU),
-                                 SET_MOUDLE_ID_MAP_NAME(IMP),
-                                 SET_MOUDLE_ID_MAP_NAME(GE),
-                                 SET_MOUDLE_ID_MAP_NAME(MDCFUSA),
-                                 SET_MOUDLE_ID_MAP_NAME(CAMERA),
-                                 SET_MOUDLE_ID_MAP_NAME(ASCENDCL),
-                                 SET_MOUDLE_ID_MAP_NAME(TEEOS),
-                                 SET_MOUDLE_ID_MAP_NAME(ISP),
-                                 SET_MOUDLE_ID_MAP_NAME(SIS),
-                                 SET_MOUDLE_ID_MAP_NAME(HSM),
-                                 SET_MOUDLE_ID_MAP_NAME(DSS),
-                                 SET_MOUDLE_ID_MAP_NAME(PROCMGR),
-                                 SET_MOUDLE_ID_MAP_NAME(BBOX),
-                                 SET_MOUDLE_ID_MAP_NAME(AIVECTOR),
-                                 { NULL, -1 }};
-
-#endif // MODULE_ID_NAME
-
 #if (OS_TYPE == LINUX)
 /**
  * @ingroup slog
@@ -386,6 +318,11 @@ extern int CheckLogLevel(int moduleId, int logLevel);
     DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \
   } while (0)

+/**
+ * @ingroup slog
+ * @brief DlogFlush: flush log buffer to file
+ */
+void DlogFlush(void);

 /**
  * @ingroup slog
diff --git a/third_party/prebuild/aarch64/liberror_manager.so b/third_party/prebuild/aarch64/liberror_manager.so
new file mode 100755
index 00000000..759d8e30
Binary files /dev/null and b/third_party/prebuild/aarch64/liberror_manager.so differ
diff --git a/third_party/prebuild/aarch64/libslog.so b/third_party/prebuild/aarch64/libslog.so
new file mode 100755
index 00000000..700fc118
Binary files /dev/null and b/third_party/prebuild/aarch64/libslog.so differ
diff --git a/third_party/prebuild/x86_64/liberror_manager.so b/third_party/prebuild/x86_64/liberror_manager.so
new file mode 100755
index 00000000..cd9ad8bc
Binary files /dev/null and b/third_party/prebuild/x86_64/liberror_manager.so differ
diff --git a/third_party/prebuild/x86_64/libslog.so b/third_party/prebuild/x86_64/libslog.so
index b476618d..01b75e40 100755
Binary files a/third_party/prebuild/x86_64/libslog.so and b/third_party/prebuild/x86_64/libslog.so differ