@@ -60,6 +60,25 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||||
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); | ||||
/// | /// | ||||
/// @ingroup client | |||||
/// @brief add a copy graph with a specific graphId | |||||
/// @param [in] graphId graph id | |||||
/// @param [in] graph the graph | |||||
/// @return Status result of function | |||||
/// | |||||
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph); | |||||
/// | |||||
/// @ingroup client | |||||
/// @brief add a copy graph with a specific graphId and graphOptions | |||||
/// @param [in] graphId graph id | |||||
/// @param [in] graph the graph | |||||
/// @param [in] options graph options | |||||
/// @return Status result of function | |||||
/// | |||||
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map<AscendString, AscendString> &options); | |||||
/// | |||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief remove a graph of the session with specific session id | /// @brief remove a graph of the session with specific session id | ||||
/// @param [in] graphId graph id | /// @param [in] graphId graph id | ||||
@@ -245,6 +245,12 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
// 0: close debug; 1: open TBE compiler; 2: open ccec compiler | // 0: close debug; 1: open TBE compiler; 2: open ccec compiler | ||||
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | ||||
// Configure model bank path | |||||
const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; | |||||
// Configure op bank path | |||||
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; | |||||
// Graph run mode | // Graph run mode | ||||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
@@ -315,13 +321,28 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c | |||||
static const char *const DEBUG_DIR = ge::DEBUG_DIR; | static const char *const DEBUG_DIR = ge::DEBUG_DIR; | ||||
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | ||||
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | ||||
static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str(); | |||||
static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str(); | |||||
// for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
const std::set<std::string> ir_builder_suppported_options = { | |||||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, | |||||
DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS, | |||||
INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, | |||||
INPUT_FP16_NODES, LOG_LEVEL}; | |||||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, | |||||
INPUT_SHAPE, | |||||
OP_NAME_MAP, | |||||
DYNAMIC_BATCH_SIZE, | |||||
DYNAMIC_IMAGE_SIZE, | |||||
DYNAMIC_DIMS, | |||||
INSERT_OP_FILE, | |||||
PRECISION_MODE, | |||||
EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE, | |||||
OUTPUT_TYPE, | |||||
OUT_NODES, | |||||
INPUT_FP16_NODES, | |||||
LOG_LEVEL, | |||||
OP_DEBUG_LEVEL, | |||||
DEBUG_DIR, | |||||
OP_COMPILER_CACHE_DIR, | |||||
OP_COMPILER_CACHE_MODE}; | |||||
// for interface: aclgrphParse | // for interface: aclgrphParse | ||||
const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | ||||
@@ -336,7 +357,9 @@ const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, | |||||
OUT_NODES, | OUT_NODES, | ||||
COMPRESS_WEIGHT_CONF, | COMPRESS_WEIGHT_CONF, | ||||
ENABLE_SCOPE_FUSION_PASSES, | ENABLE_SCOPE_FUSION_PASSES, | ||||
LOG_LEVEL}; | |||||
LOG_LEVEL, | |||||
MDL_BANK_PATH_FLAG, | |||||
OP_BANK_PATH_FLAG}; | |||||
// for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
const std::set<std::string> global_options = {CORE_TYPE, | const std::set<std::string> global_options = {CORE_TYPE, | ||||
@@ -31,6 +31,18 @@ class AscendString { | |||||
const char* GetString() const; | const char* GetString() const; | ||||
bool operator<(const AscendString& d) const; | |||||
bool operator>(const AscendString& d) const; | |||||
bool operator<=(const AscendString& d) const; | |||||
bool operator>=(const AscendString& d) const; | |||||
bool operator==(const AscendString& d) const; | |||||
bool operator!=(const AscendString& d) const; | |||||
private: | private: | ||||
std::shared_ptr<std::string> name_; | std::shared_ptr<std::string> name_; | ||||
}; | }; | ||||
@@ -94,6 +94,7 @@ using FusionParseParamFunc = | |||||
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | ||||
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>; | using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>; | ||||
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | ||||
using ParseOpToGraphFunc = std::function<Status(const ge::Operator &, ge::Graph &)>; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
public: | public: | ||||
@@ -125,6 +126,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
OpRegistrationData &InputReorderVector(const vector<int> &input_order); | OpRegistrationData &InputReorderVector(const vector<int> &input_order); | ||||
OpRegistrationData &ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn); | |||||
domi::ImplyType GetImplyType() const; | domi::ImplyType GetImplyType() const; | ||||
std::string GetOmOptype() const; | std::string GetOmOptype() const; | ||||
std::set<std::string> GetOriginOpTypeSet() const; | std::set<std::string> GetOriginOpTypeSet() const; | ||||
@@ -134,6 +137,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
FusionParseParamFunc GetFusionParseParamFn() const; | FusionParseParamFunc GetFusionParseParamFn() const; | ||||
FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; | FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; | ||||
ParseSubgraphFunc GetParseSubgraphPostFn() const; | ParseSubgraphFunc GetParseSubgraphPostFn() const; | ||||
ParseOpToGraphFunc GetParseOpToGraphFn() const; | |||||
private: | private: | ||||
std::shared_ptr<OpRegistrationDataImpl> impl_; | std::shared_ptr<OpRegistrationDataImpl> impl_; | ||||
@@ -18,10 +18,12 @@ | |||||
#define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | ||||
#include <string> | #include <string> | ||||
#include <sstream> | |||||
#include "runtime/rt.h" | #include "runtime/rt.h" | ||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "common/util/error_manager/error_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
@@ -253,4 +255,29 @@ | |||||
exec_expr1; \ | exec_expr1; \ | ||||
} | } | ||||
#define GE_ERRORLOG_AND_ERRORMSG(_status, errormsg) \ | |||||
{ \ | |||||
GELOGE(_status, "%s", errormsg); \ | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ | |||||
} | |||||
#define GE_CHK_LOG_AND_ERRORMSG(expr, _status, errormsg) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGE(_status, "%s", errormsg); \ | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0) | |||||
template <typename T> | |||||
std::string FmtToStr(const T &t) { | |||||
std::string fmt; | |||||
std::stringstream st; | |||||
st << "[" << t << "]"; | |||||
fmt = st.str(); | |||||
return fmt; | |||||
} | |||||
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ |
@@ -70,6 +70,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_VALUE; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_VALUE; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_MODEL_ID; | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | ||||
@@ -270,6 +270,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
static ge::Status ReleaseSingleOpResource(void *stream); | static ge::Status ReleaseSingleOpResource(void *stream); | ||||
static ge::Status GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id); | |||||
ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); | ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); | ||||
ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); | ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); | ||||
ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector<InputOutputDims> &input_dims, | ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector<InputOutputDims> &input_dims, | ||||
@@ -1115,6 +1115,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYN | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; | ||||
// atc user def dtype&format | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES; | |||||
// for fusion op plugin | // for fusion op plugin | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; | ||||
@@ -42,6 +42,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"detail/*.cc" | "detail/*.cc" | ||||
"debug/*.cc" | "debug/*.cc" | ||||
"option/*.cc" | "option/*.cc" | ||||
"transformer/src/*cc" | |||||
) | ) | ||||
# include directories | # include directories | ||||
@@ -30,4 +30,66 @@ const char* AscendString::GetString() const { | |||||
return (*name_).c_str(); | return (*name_).c_str(); | ||||
} | } | ||||
bool AscendString::operator<(const AscendString& d) const { | |||||
if (name_ == nullptr && d.name_ == nullptr) { | |||||
return false; | |||||
} else if (name_ == nullptr) { | |||||
return true; | |||||
} else if (d.name_ == nullptr) { | |||||
return false; | |||||
} | |||||
return (*name_ < *(d.name_)); | |||||
} | |||||
bool AscendString::operator>(const AscendString& d) const { | |||||
if (name_ == nullptr && d.name_ == nullptr) { | |||||
return false; | |||||
} else if (name_ == nullptr) { | |||||
return false; | |||||
} else if (d.name_ == nullptr) { | |||||
return true; | |||||
} | |||||
return (*name_ > *(d.name_)); | |||||
} | |||||
bool AscendString::operator==(const AscendString& d) const { | |||||
if (name_ == nullptr && d.name_ == nullptr) { | |||||
return true; | |||||
} else if (name_ == nullptr) { | |||||
return false; | |||||
} else if (d.name_ == nullptr) { | |||||
return false; | |||||
} | |||||
return (*name_ == *(d.name_)); | |||||
} | |||||
bool AscendString::operator<=(const AscendString& d) const { | |||||
if (name_ == nullptr) { | |||||
return true; | |||||
} else if (d.name_ == nullptr) { | |||||
return false; | |||||
} | |||||
return (*name_ <= *(d.name_)); | |||||
} | |||||
bool AscendString::operator>=(const AscendString& d) const { | |||||
if (d.name_ == nullptr) { | |||||
return true; | |||||
} else if (name_ == nullptr) { | |||||
return false; | |||||
} | |||||
return (*name_ >= *(d.name_)); | |||||
} | |||||
bool AscendString::operator!=(const AscendString& d) const { | |||||
if (name_ == nullptr && d.name_ == nullptr) { | |||||
return false; | |||||
} else if (name_ == nullptr) { | |||||
return true; | |||||
} else if (d.name_ == nullptr) { | |||||
return true; | |||||
} | |||||
return (*name_ != *(d.name_)); | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -384,12 +384,15 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor | |||||
continue; | continue; | ||||
} | } | ||||
for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { | for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { | ||||
if (input_desc != nullptr) { | |||||
// single op support private format set, its origin format should not be override | |||||
auto ori_format = input_desc->GetOriginFormat(); | |||||
if (input_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) { | |||||
input_desc->SetOriginFormat(input_desc->GetFormat()); | input_desc->SetOriginFormat(input_desc->GetFormat()); | ||||
} | } | ||||
} | } | ||||
for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { | for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { | ||||
if (output_desc != nullptr) { | |||||
auto ori_format = output_desc->GetOriginFormat(); | |||||
if (output_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) { | |||||
output_desc->SetOriginFormat(output_desc->GetFormat()); | output_desc->SetOriginFormat(output_desc->GetFormat()); | ||||
} | } | ||||
} | } | ||||
@@ -1078,6 +1078,9 @@ const std::string ATTR_NAME_DYNAMIC_INPUT_END = "_dynamic_input_index_end"; | |||||
const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; | const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; | ||||
const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; | const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; | ||||
// atc user def dtype&format | |||||
const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES = "_user_defined_output_nodes"; | |||||
// for fusion op plugin | // for fusion op plugin | ||||
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | ||||
@@ -46,6 +46,10 @@ COMMON_LOCAL_SRC_FILES := \ | |||||
option/ge_local_context.cc \ | option/ge_local_context.cc \ | ||||
./runtime_inference_context.cc \ | ./runtime_inference_context.cc \ | ||||
./utils/node_utils.cc \ | ./utils/node_utils.cc \ | ||||
../third_party/transformer/src/axis_util.cpp \ | |||||
../third_party/transformer/src/transfer_shape_according_to_format.cpp \ | |||||
./utils/transformer_utils.cc \ | |||||
COMMON_LOCAL_C_INCLUDES := \ | COMMON_LOCAL_C_INCLUDES := \ | ||||
proto/om.proto \ | proto/om.proto \ | ||||
@@ -57,13 +61,19 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
proto/op_mapping_info.proto \ | proto/op_mapping_info.proto \ | ||||
proto/dump_task.proto \ | proto/dump_task.proto \ | ||||
inc \ | inc \ | ||||
metadef/inc \ | |||||
graphengine/inc \ | |||||
inc/external \ | inc/external \ | ||||
inc/external/graph \ | |||||
inc/graph \ | |||||
inc/common \ | |||||
common \ | |||||
common/graph \ | |||||
metadef/inc/external \ | |||||
graphengine/inc/external \ | |||||
metadef/inc/external/graph \ | |||||
metadef/inc/graph \ | |||||
metadef/inc/common \ | |||||
metadef \ | |||||
metadef/graph \ | |||||
third_party/protobuf/include \ | third_party/protobuf/include \ | ||||
$(TOPDIR)metadef/third_party \ | |||||
$(TOPDIR)metadef/third_party/transformer/inc \ | |||||
libc_sec/include \ | libc_sec/include \ | ||||
ops/built-in/op_proto/inc \ | ops/built-in/op_proto/inc \ | ||||
cann/ops/built-in/op_proto/inc \ | cann/ops/built-in/op_proto/inc \ | ||||
@@ -27,6 +27,7 @@ | |||||
#include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
#include "graph/utils/ge_ir_utils.h" | #include "graph/utils/ge_ir_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/transformer_utils.h" | |||||
#include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
using std::make_pair; | using std::make_pair; | ||||
@@ -1301,11 +1302,24 @@ graphStatus OpDesc::CallInferFunc(Operator &op) { | |||||
return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
} | } | ||||
} | } | ||||
std::unique_ptr<NodeShapeTransUtils> transformer(new (std::nothrow) NodeShapeTransUtils(shared_from_this())); | |||||
if (transformer == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Memory alloc failed"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (!transformer->CatchFormatAndShape()) { | |||||
GELOGE(GRAPH_FAILED, "catch format and shape info failed!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
graphStatus graph_status = (graphStatus)infer_func_(op); | graphStatus graph_status = (graphStatus)infer_func_(op); | ||||
if (graph_status != GRAPH_SUCCESS) { | if (graph_status != GRAPH_SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); | GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (!transformer->UpdateFormatAndShape()) { | |||||
GELOGE(GRAPH_FAILED, "catch format and shape info failed!"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | ||||
@@ -1425,7 +1425,10 @@ class GraphBuilderImpl { | |||||
const string name = node->GetName(); | const string name = node->GetName(); | ||||
for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { | for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { | ||||
const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); | const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); | ||||
GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); | |||||
if (builder == nullptr) { | |||||
GELOGW("Node: %s, Has no builder.", name.c_str()); | |||||
continue; | |||||
} | |||||
Graph graph = builder(); // Build subgraph from user define builder. | Graph graph = builder(); // Build subgraph from user define builder. | ||||
const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); | const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); | ||||
@@ -26,6 +26,7 @@ | |||||
#include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
#include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
#include "debug/ge_util.h" | |||||
#include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
#include "external/graph/operator_factory.h" | #include "external/graph/operator_factory.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
@@ -41,7 +42,6 @@ const uint32_t kWhileBodySubGraphIdx = 1; | |||||
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { | graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { | ||||
GELOGD("Enter reverse brush while body subgraph process!"); | GELOGD("Enter reverse brush while body subgraph process!"); | ||||
auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); | auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); | ||||
if (sub_graph_body == nullptr) { | if (sub_graph_body == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Get while body graph failed!"); | GELOGE(GRAPH_FAILED, "Get while body graph failed!"); | ||||
@@ -661,10 +661,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||||
if (!is_unknown_graph) { | if (!is_unknown_graph) { | ||||
auto inference_context = CreateInferenceContext(context_map, node); | auto inference_context = CreateInferenceContext(context_map, node); | ||||
if (inference_context == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "inference context is null"); | |||||
return GRAPH_FAILED; | |||||
} | |||||
GE_CHECK_NOTNULL(inference_context); | |||||
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | ||||
op.SetInferenceContext(inference_context); | op.SetInferenceContext(inference_context); | ||||
} | } | ||||
@@ -678,8 +675,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | ||||
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | ||||
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||||
output_tensor->SetOriginShape(output_tensor->GetShape()); | |||||
if (output_tensor->MutableShape().GetDims().empty()) { | |||||
output_tensor->SetOriginShape(output_tensor->GetShape()); | |||||
} | |||||
ge::TensorUtils::SetRealDimCnt(*output_tensor, | |||||
static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size())); | |||||
output_tensor->SetOriginDataType(output_tensor->GetDataType()); | output_tensor->SetOriginDataType(output_tensor->GetDataType()); | ||||
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | ||||
@@ -0,0 +1,144 @@ | |||||
/** | |||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file axis_util.h | |||||
* \brief get the axis value | |||||
*/ | |||||
#ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ | |||||
#define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ | |||||
#include <memory.h> | |||||
#include <functional> | |||||
#include <vector> | |||||
#include "external/graph/ge_error_codes.h" | |||||
#include "external/graph/types.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace common { | |||||
namespace transformer { | |||||
const int32_t DIM_DEFAULT_SIZE = 4; | |||||
const uint32_t NCHW_DIMENSION_NUM = 4; | |||||
const int32_t AXIS_NCHW_DIM_N = 0; | |||||
const int32_t AXIS_NCHW_DIM_C = 1; | |||||
const int32_t AXIS_NCHW_DIM_H = 2; | |||||
const int32_t AXIS_NCHW_DIM_W = 3; | |||||
const int32_t AXIS_NHWC_DIM_N = 0; | |||||
const int32_t AXIS_NHWC_DIM_H = 1; | |||||
const int32_t AXIS_NHWC_DIM_W = 2; | |||||
const int32_t AXIS_NHWC_DIM_C = 3; | |||||
const int32_t AXIS_NC1HWC0_DIM_N = 0; | |||||
const int32_t AXIS_NC1HWC0_DIM_C1 = 1; | |||||
const int32_t AXIS_NC1HWC0_DIM_C0 = 4; | |||||
const int32_t AXIS_NC1HWC0_DIM_H = 2; | |||||
const int32_t AXIS_NC1HWC0_DIM_W = 3; | |||||
const int32_t AXIS_HWCN_DIM_H = 0; | |||||
const int32_t AXIS_HWCN_DIM_W = 1; | |||||
const int32_t AXIS_HWCN_DIM_C = 2; | |||||
const int32_t AXIS_HWCN_DIM_N = 3; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_H = 1; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_W = 2; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_N = 3; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; | |||||
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5; | |||||
#define CHECK_NOTNULL(val) \ | |||||
do { \ | |||||
if ((val) == nullptr) { \ | |||||
GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \ | |||||
return false; \ | |||||
} \ | |||||
} while (0) | |||||
#define CHECK(cond, log_func, return_expr) \ | |||||
do { \ | |||||
if (cond) { \ | |||||
log_func; \ | |||||
return_expr; \ | |||||
} \ | |||||
} while (0) | |||||
enum AxisValueType { | |||||
AXIS_N = 0, | |||||
AXIS_C = 1, | |||||
AXIS_H = 2, | |||||
AXIS_W = 3, | |||||
AXIS_C1 = 4, | |||||
AXIS_C0 = 5, | |||||
AXIS_Co = 6, | |||||
AXIS_D = 7, | |||||
AXIS_BOTTOM = 8 | |||||
}; | |||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor); | |||||
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */ | |||||
/* The first parameter is old shape's dimension, | |||||
* second is c0 and third is axis value. */ | |||||
using GetAxisValueInfoByFormat = | |||||
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>; | |||||
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>; | |||||
class AxisUtil { | |||||
public: | |||||
AxisUtil(); | |||||
~AxisUtil(){}; | |||||
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
bool HasAxisValueFunc(const ge::Format& format); | |||||
private: | |||||
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0, | |||||
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue); | |||||
/* map of GetAxisValueInfoByFormat, get axis value by different original | |||||
* formats. */ | |||||
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap; | |||||
}; | |||||
} // namespace transformer | |||||
} // namespace common | |||||
#endif // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_ |
@@ -0,0 +1,122 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file transfer_shape_according_to_format.h | |||||
* \brief set shape according to original format and current format | |||||
*/ | |||||
#ifndef COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||||
#define COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ | |||||
#include "transformer/inc/axis_util.h" | |||||
#include <memory.h> | |||||
#include <functional> | |||||
#include <vector> | |||||
#include "graph/types.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
namespace common { | |||||
namespace transformer { | |||||
enum OpImplType { | |||||
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op | |||||
EN_IMPL_CUSTOM_TIK, // custom tik op | |||||
EN_IMPL_CUSTOM_TBE, // custom tbe op | |||||
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op | |||||
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op | |||||
EN_IMPL_HW_TIK, // Huawei built-in tik op | |||||
EN_IMPL_HW_TBE, // Huawei built-in tbe op | |||||
EN_IMPL_RL, // RL op | |||||
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op | |||||
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op | |||||
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op | |||||
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op | |||||
EN_RESERVED // reserved value | |||||
}; | |||||
const uint32_t SHAPE_NUMBER_16 = 16; | |||||
const uint32_t SHAPE_NUMBER_32 = 32; | |||||
const uint32_t SHAPE_DIM_VALUE_C04 = 4; | |||||
const uint32_t NI = 16; | |||||
const uint32_t MINUS_VALUE_ONE = 1; | |||||
const uint32_t MINUS_VALUE_TWO = 2; | |||||
const uint32_t SIZE_OF_CN = 2; | |||||
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2; | |||||
/* The first parameter is axis value, second is new shape and third is | |||||
* op implementation type. */ | |||||
using GetNewShapeByAxisValueAndFormat = | |||||
std::function<bool(vector<int64_t> &, const int64_t &, vector<int64_t> &, vector<int64_t> &)>; | |||||
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>; | |||||
struct ShapeAndFormatInfo { | |||||
const std::vector<int64_t> &oldShape; | |||||
std::vector<int64_t> &newShape; | |||||
const ge::Format &oldFormat; | |||||
const ge::Format &newFormat; | |||||
const ge::DataType ¤tDataType; | |||||
const int64_t &opImplType; | |||||
}; | |||||
using ShapeAndFormat = struct ShapeAndFormatInfo; | |||||
class ShapeTransferAccordingToFormat { | |||||
public: | |||||
ShapeTransferAccordingToFormat(); | |||||
~ShapeTransferAccordingToFormat(){}; | |||||
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat &) = delete; | |||||
ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat &) = delete; | |||||
bool GetShapeAccordingToFormat(ShapeAndFormat &inputAndOutputInfo, int64_t *c = nullptr); | |||||
/* ----------Below is the function of getting new shape---------------------- */ | |||||
static bool GetNCHWShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetNHWCShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetNC1HWC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetFzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetHWCNShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetC1HWNCoC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
static bool GetNzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType, | |||||
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue); | |||||
private: | |||||
/* map of GetAxisValueInfoByFormat, get axis value by different original | |||||
* formats. */ | |||||
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap; | |||||
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0; | |||||
}; | |||||
} // namespace transformer | |||||
} // namespace common | |||||
#endif // COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ |
@@ -0,0 +1,198 @@ | |||||
/** | |||||
* Copyright 2019 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file axis_util.cpp | |||||
* \brief get the axis value | |||||
*/ | |||||
#include "transformer/inc/axis_util.h" | |||||
#include "graph/types.h" | |||||
namespace common { | |||||
namespace transformer { | |||||
using namespace ge; | |||||
using namespace std; | |||||
AxisUtil::AxisUtil() { | |||||
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)}, | |||||
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)}, | |||||
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)}, | |||||
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)}, | |||||
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)}, | |||||
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}}; | |||||
} | |||||
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { | |||||
if (divisor == 0) { | |||||
return 0; | |||||
} else { | |||||
return (dividend + divisor - 1) / divisor; | |||||
} | |||||
} | |||||
bool AxisUtil::GetAxisValueByOriginFormat(const Format &format, const vector<int64_t> &dimVec, const uint32_t &c0, | |||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) { | |||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
GELOGI("Can not get axis value of old format %u!", format); | |||||
return false; | |||||
} | |||||
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second; | |||||
CHECK_NOTNULL(getAxisFunc); | |||||
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue); | |||||
} | |||||
bool AxisUtil::HasAxisValueFunc(const Format &format) { | |||||
auto iterGetAxisFunc = getAxisValueFuncMap.find(format); | |||||
if (iterGetAxisFunc == getAxisValueFuncMap.end()) { | |||||
GELOGI("Can not get axis value of format %u!", format); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::CheckParams(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue, | |||||
vector<int64_t> &ndValue) { | |||||
ndValue = originalDimVec; | |||||
auto dimSize = originalDimVec.size(); | |||||
if (dimSize < DIM_DEFAULT_SIZE) { | |||||
/* Before this funcion, we should call function PadDimensionTo4. */ | |||||
GELOGI("Dimension size %zu is invalid.", dimSize); | |||||
return false; | |||||
} | |||||
if (c0 == 0) { | |||||
GELOGE(GRAPH_FAILED, "[ERROR]c0 is zero!"); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByND(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue, | |||||
vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
ndValue = originalDimVec; | |||||
/* To differentiate the input datatype of int8 and others */ | |||||
axisValue[AXIS_C0] = c0; | |||||
if (originalDimVec.size() == NCHW_DIMENSION_NUM) { | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
} | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue, | |||||
vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NCHW to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue, | |||||
vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t> &originalDimVec, const uint32_t &c0, | |||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
auto dimSize = originalDimVec.size(); | |||||
if (dimSize == DIM_DEFAULT_SIZE + 1) { | |||||
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1]; | |||||
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0]; | |||||
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0]; | |||||
} else { | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_C0] = c0; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C]; | |||||
} | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W]; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue, | |||||
vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C]; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W]; | |||||
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0); | |||||
axisValue[AXIS_Co] = c0; | |||||
return true; | |||||
} | |||||
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t> &originalDimVec, const uint32_t &c0, | |||||
vector<int64_t> &axisValue, vector<int64_t> &ndValue) { | |||||
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true); | |||||
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true); | |||||
/* C0 Must be set for case ND or 2D-NHWC to NZ */ | |||||
axisValue[AXIS_C0] = c0; | |||||
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"), | |||||
return false); | |||||
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N]; | |||||
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0; | |||||
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H]; | |||||
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W]; | |||||
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1]; | |||||
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co]; | |||||
return true; | |||||
} | |||||
} // namespace transformer | |||||
} // namespace common |
@@ -0,0 +1,242 @@ | |||||
/** | |||||
* Copyright 2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file transfer_shape_according_to_format.cpp | |||||
* \brief set shape according to original format and current format | |||||
*/ | |||||
#include "transformer/inc/transfer_shape_according_to_format.h" | |||||
namespace common { | |||||
namespace transformer { | |||||
using namespace ge; | |||||
using namespace std; | |||||
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) { | |||||
getNewShapeFuncMap = { | |||||
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)}, | |||||
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)}, | |||||
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)}, | |||||
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)}, | |||||
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)}, | |||||
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)}, | |||||
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}}; | |||||
mapOfDtypeAndC0 = { | |||||
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32}, | |||||
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16}, | |||||
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16}, | |||||
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}}; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_C]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
newShape.push_back(axisValue[AXIS_C]); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_C1]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
newShape.push_back(axisValue[AXIS_C0]); | |||||
} else { | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_C]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
} | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
if (ndValue.size() == SIZE_OF_CN) { | |||||
auto sizeOfOriginalVec = ndValue.size(); | |||||
newShape = ndValue; | |||||
/* sizeOfOriginalVec - 1 mean the last value of original vec | |||||
* sizeOfOriginalVec - 2 mean the second last value of original vec */ | |||||
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16); | |||||
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]); | |||||
newShape.push_back(SHAPE_NUMBER_16); | |||||
newShape.push_back(axisValue[AXIS_C0]); | |||||
} else { | |||||
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) { | |||||
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W]; | |||||
newShape.push_back(hwc1); | |||||
newShape.push_back(DivisionCeiling(axisValue[AXIS_N], NI)); | |||||
newShape.push_back(NI); | |||||
newShape.push_back(axisValue[AXIS_C0]); | |||||
} else { | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_C]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
newShape.push_back(axisValue[AXIS_C]); | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true); | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
newShape.push_back(axisValue[AXIS_C1]); | |||||
newShape.push_back(axisValue[AXIS_H]); | |||||
newShape.push_back(axisValue[AXIS_W]); | |||||
newShape.push_back(axisValue[AXIS_N]); | |||||
newShape.push_back(axisValue[AXIS_Co]); | |||||
newShape.push_back(axisValue[AXIS_C0]); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType, | |||||
const vector<int64_t>& axisValue, | |||||
const vector<int64_t>& ndValue) { | |||||
CHECK(ndValue.empty(), GELOGD("ndValue is empty!"), return true); | |||||
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0, | |||||
GELOGD("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true); | |||||
uint32_t sizeOfOriginalVec = ndValue.size(); | |||||
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) { | |||||
GELOGD("ndValue's dim num is less than 2!"); | |||||
return true; | |||||
} | |||||
/* axisValue is initialized as a size 6 vector. */ | |||||
newShape = ndValue; | |||||
/* sizeOfOriginalVec - 1 mean the last value of original vec | |||||
* sizeOfOriginalVec - 2 mean the second last value of original vec */ | |||||
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16); | |||||
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] = | |||||
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]); | |||||
newShape.push_back(SHAPE_NUMBER_16); | |||||
newShape.push_back(axisValue[AXIS_C0]); | |||||
return true; | |||||
} | |||||
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) { | |||||
/* The default new shape is old shape */ | |||||
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape; | |||||
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) { | |||||
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat, | |||||
shapeAndFormatInfo.newFormat); | |||||
return false; | |||||
} | |||||
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) { | |||||
GELOGE(GRAPH_FAILED, "currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType); | |||||
return false; | |||||
} | |||||
AxisUtil* axisutil_object = new AxisUtil(); | |||||
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) { | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat); | |||||
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) { | |||||
GELOGD("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat); | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
GELOGD("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat); | |||||
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second; | |||||
CHECK_NOTNULL(getNewShapeFunc); | |||||
std::vector<int64_t> axisValue; | |||||
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) { | |||||
axisValue.push_back(1); | |||||
} | |||||
std::vector<int64_t> ndValue; | |||||
uint32_t c0; | |||||
if (mapOfDtypeAndC0.empty()) { | |||||
c0 = SHAPE_NUMBER_16; | |||||
} else { | |||||
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType); | |||||
if (iterGetC0 == mapOfDtypeAndC0.end()) { | |||||
GELOGE(GRAPH_FAILED, "Dtype is not support."); | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
c0 = iterGetC0->second; | |||||
} | |||||
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 | |||||
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) { | |||||
c0 = SHAPE_DIM_VALUE_C04; | |||||
} | |||||
bool status = axisutil_object->GetAxisValueByOriginFormat( | |||||
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue); | |||||
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) { | |||||
delete axisutil_object; | |||||
return true; | |||||
} | |||||
delete axisutil_object; | |||||
shapeAndFormatInfo.newShape.clear(); | |||||
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue); | |||||
if (c != nullptr) { | |||||
*c = axisValue[AXIS_C]; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace transformer | |||||
} // namespace common |
@@ -0,0 +1,160 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "transformer_utils.h" | |||||
#include "external/ge/ge_api_types.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/utils/type_utils.h" | |||||
namespace ge { | |||||
bool NodeShapeTransUtils::CatchFormatAndShape() { | |||||
inputs_ = op_desc_->GetAllInputName(); | |||||
outputs_ = op_desc_->GetAllOutputName(); | |||||
for (auto &ele : inputs_) { | |||||
auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first); | |||||
if (tensor_desc_input == nullptr) { | |||||
continue; | |||||
} | |||||
auto format = tensor_desc_input->GetFormat(); | |||||
auto ori_format = tensor_desc_input->GetOriginFormat(); | |||||
if (format == ori_format) { | |||||
GELOGD("Node is %s, input tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!", | |||||
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(), | |||||
TypeUtils::FormatToSerialString(format).c_str()); | |||||
continue; | |||||
} | |||||
map_format_in_.insert(std::pair<std::string, Format>(ele.first, format)); | |||||
map_ori_format_in_.insert(std::pair<std::string, Format>(ele.first, ori_format)); | |||||
map_dtype_in_.insert(std::pair<std::string, DataType>(ele.first, tensor_desc_input->GetDataType())); | |||||
tensor_desc_input->SetFormat(ori_format); | |||||
tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape()); | |||||
} | |||||
for (auto &ele : outputs_) { | |||||
auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first); | |||||
if (tensor_desc_output == nullptr) { | |||||
continue; | |||||
} | |||||
auto format = tensor_desc_output->GetFormat(); | |||||
auto ori_format = tensor_desc_output->GetOriginFormat(); | |||||
if (format == ori_format) { | |||||
GELOGD("Node is %s, output tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!", | |||||
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(), | |||||
TypeUtils::FormatToSerialString(format).c_str()); | |||||
continue; | |||||
} | |||||
map_format_out_.insert(std::pair<std::string, Format>(ele.first, format)); | |||||
map_ori_format_out_.insert(std::pair<std::string, Format>(ele.first, ori_format)); | |||||
map_dtype_out_.insert(std::pair<std::string, DataType>(ele.first, tensor_desc_output->GetDataType())); | |||||
if (format == ori_format) { | |||||
continue; | |||||
} | |||||
tensor_desc_output->SetFormat(ori_format); | |||||
} | |||||
return true; | |||||
} | |||||
bool NodeShapeTransUtils::UpdateFormatAndShape() { | |||||
for (auto &ele : inputs_) { | |||||
auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first); | |||||
if (tensor_desc_input == nullptr) { | |||||
continue; | |||||
} | |||||
// if can not find saved info, it says format and origin format is same when catched | |||||
if (map_format_in_.find(ele.first) == map_format_in_.end()) { | |||||
GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!", | |||||
op_desc_->GetName().c_str(), ele.first.c_str()); | |||||
tensor_desc_input->SetOriginFormat(tensor_desc_input->GetFormat()); | |||||
tensor_desc_input->SetOriginShape(tensor_desc_input->GetShape()); | |||||
continue; | |||||
} | |||||
auto ori_format = tensor_desc_input->GetFormat(); | |||||
auto ori_shape = tensor_desc_input->GetShape(); | |||||
auto curr_format = map_format_in_[ele.first]; | |||||
if (ori_format == curr_format) { | |||||
continue; | |||||
} | |||||
std::unique_ptr<common::transformer::ShapeTransferAccordingToFormat> shape_transfer( | |||||
new (std::nothrow) common::transformer::ShapeTransferAccordingToFormat()); | |||||
if (shape_transfer == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Memory alloc failed"); | |||||
return false; | |||||
} | |||||
std::vector<int64_t> ori_shape_dims = ori_shape.GetDims(); | |||||
std::vector<int64_t> out_dims; | |||||
ge::DataType dtype = map_dtype_in_[ele.first]; | |||||
common::transformer::ShapeAndFormat shape_and_format_info{ | |||||
ori_shape_dims, out_dims, ori_format, curr_format, dtype, common::transformer::EN_IMPL_CUSTOM_TBE}; | |||||
shape_transfer->GetShapeAccordingToFormat(shape_and_format_info); | |||||
tensor_desc_input->SetFormat(curr_format); | |||||
tensor_desc_input->SetShape(GeShape(out_dims)); | |||||
} | |||||
for (auto &ele : outputs_) { | |||||
auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first); | |||||
if (tensor_desc_output == nullptr) { | |||||
continue; | |||||
} | |||||
// if can not find saved info, it says format and origin format is same when catched | |||||
if (map_ori_format_out_.find(ele.first) == map_ori_format_out_.end()) { | |||||
GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!", | |||||
op_desc_->GetName().c_str(), ele.first.c_str()); | |||||
tensor_desc_output->SetOriginFormat(tensor_desc_output->GetFormat()); | |||||
tensor_desc_output->SetOriginShape(tensor_desc_output->GetShape()); | |||||
continue; | |||||
} | |||||
auto ori_shape = tensor_desc_output->GetShape(); | |||||
auto curr_format = tensor_desc_output->GetFormat(); | |||||
if (curr_format != map_ori_format_out_[ele.first]) { | |||||
GELOGE(GRAPH_FAILED, "Node is %s, out tensor name is %s. format: %s, recorded origin format: %s is not same", | |||||
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), | |||||
TypeUtils::FormatToSerialString(map_ori_format_out_[ele.first]).c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
tensor_desc_output->SetOriginShape(ori_shape); | |||||
auto saved_format = map_format_out_[ele.first]; | |||||
if (curr_format == saved_format) { | |||||
GELOGD("Nodeis %s, out tensor name is %s. ori format: %s, recorded format: %s is same! No need to transfer", | |||||
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), | |||||
TypeUtils::FormatToSerialString(saved_format).c_str()); | |||||
continue; | |||||
} | |||||
tensor_desc_output->SetFormat(saved_format); | |||||
std::unique_ptr<common::transformer::ShapeTransferAccordingToFormat> shape_transfer( | |||||
new (std::nothrow) common::transformer::ShapeTransferAccordingToFormat()); | |||||
if (shape_transfer == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "Memory alloc failed"); | |||||
return false; | |||||
} | |||||
std::vector<int64_t> ori_shape_dims = ori_shape.GetDims(); | |||||
std::vector<int64_t> out_dims; | |||||
ge::DataType dtype = tensor_desc_output->GetDataType(); | |||||
common::transformer::ShapeAndFormat shape_and_format_info{ | |||||
ori_shape_dims, out_dims, curr_format, saved_format, dtype, common::transformer::EN_IMPL_CUSTOM_TBE}; | |||||
shape_transfer->GetShapeAccordingToFormat(shape_and_format_info); | |||||
tensor_desc_output->SetShape(GeShape(out_dims)); | |||||
GELOGD("Node is %s, out tensor name is %s. Update format and shape success,ori format: %s, format: %s", | |||||
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(), | |||||
TypeUtils::FormatToSerialString(saved_format).c_str()); | |||||
} | |||||
GELOGD("Node is %s. Update format and shape success", op_desc_->GetName().c_str()); | |||||
return true; | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,50 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ | |||||
#define COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ | |||||
#include <string> | |||||
#include <map> | |||||
#include "external/graph/types.h" | |||||
#include "graph/op_desc.h" | |||||
#include "graph/ge_tensor.h" | |||||
#include "transformer/inc/transfer_shape_according_to_format.h" | |||||
namespace ge { | |||||
class NodeShapeTransUtils { | |||||
public: | |||||
bool CatchFormatAndShape(); | |||||
bool UpdateFormatAndShape(); | |||||
explicit NodeShapeTransUtils(OpDescPtr op_desc) : op_desc_(op_desc) {} | |||||
~NodeShapeTransUtils() {} | |||||
private: | |||||
std::map<std::string, Format> map_format_in_; | |||||
std::map<std::string, Format> map_ori_format_in_; | |||||
std::map<std::string, DataType> map_dtype_in_; | |||||
std::map<std::string, Format> map_format_out_; | |||||
std::map<std::string, Format> map_ori_format_out_; | |||||
std::map<std::string, DataType> map_dtype_out_; | |||||
std::map<std::string, uint32_t> inputs_; | |||||
std::map<std::string, uint32_t> outputs_; | |||||
OpDescPtr op_desc_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ |
@@ -260,6 +260,33 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s | |||||
return ret; | return ret; | ||||
} | } | ||||
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) { | |||||
std::map<AscendString, AscendString> options; | |||||
return AddGraphWithCopy(graph_id, graph, options); | |||||
} | |||||
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, | |||||
const std::map<AscendString, AscendString> &options) { | |||||
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); | |||||
return FAILED; | |||||
} | |||||
std::map<std::string, std::string> str_options; | |||||
for (auto it = options.begin(); it != options.end(); ++it) { | |||||
str_options.insert({it->first.GetString(), it->second.GetString()}); | |||||
} | |||||
GELOGD("Adding graph to session"); | |||||
Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "AddGraph failed in Session."); | |||||
return FAILED; | |||||
} | |||||
GELOGD("AddGraph finished in Session."); | |||||
return ret; | |||||
} | |||||
Status Session::RemoveGraph(uint32_t graph_id) { | Status Session::RemoveGraph(uint32_t graph_id) { | ||||
GELOGT(TRACE_INIT, "Session RemoveGraph start"); | GELOGT(TRACE_INIT, "Session RemoveGraph start"); | ||||
@@ -24,6 +24,7 @@ | |||||
#include "common/fp16_t.h" | #include "common/fp16_t.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "securec.h" | #include "securec.h" | ||||
@@ -123,21 +124,25 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); | ||||
auto iter = trans_mode_map.find(trans_info); | auto iter = trans_mode_map.find(trans_info); | ||||
if (iter == trans_mode_map.end()) { | if (iter == trans_mode_map.end()) { | ||||
GELOGE(PARAM_INVALID, "Trans data type from %s to %s is not supported.", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to trans data from datatype " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to calc size from data type" + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { | ||||
GELOGE(PARAM_INVALID, "args.src_data_size %zu or data type size %d too big.", args.src_data_size, size); | |||||
std::string error = | |||||
"args.src_data_size" + FmtToStr(args.src_data_size) + " or data type size" + FmtToStr(size) + " is too big"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
@@ -154,9 +159,11 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
} | } | ||||
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to cast data from %s to %s, data size %zu", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str(), args.src_data_size); | |||||
std::string error = "Failed to cast data from datatype " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " + | |||||
FmtToStr(std::to_string(args.src_data_size)); | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -35,14 +36,16 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type %s", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | |||||
std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { | ||||
@@ -58,8 +61,9 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||||
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | ||||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | ||||
src_shape.at(kC1hwncoc0C0) != cube_size) { | src_shape.at(kC1hwncoc0C0) != cube_size) { | ||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
std::string error = "Failed to check relationship between src and dst shape, src shape" + | |||||
FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -149,11 +149,7 @@ Status FormatTransferDhwcnFractalZ3D::TransFormat(const TransArgs &args, TransRe | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -150,11 +150,7 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransFormat(const TransArgs &args | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -39,8 +40,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
default: | default: | ||||
GELOGE(PARAM_INVALID, "Trans format between %s and FORMAT_FRACTAL_NZ is not supported.", | |||||
TypeUtils::FormatToSerialString(format).c_str()); | |||||
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
" and FORMAT_FRACTAL_NZ is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -103,11 +105,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (args.src_shape != expect_src_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid relationship between src shape %s and dst %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str()); | |||||
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -279,11 +277,7 @@ Status FormatTransferFractalNz::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return TransFormatFromNdToFracNz(args, result, hw_shape); | return TransFormatFromNdToFracNz(args, result, hw_shape); | ||||
@@ -23,6 +23,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -159,8 +160,9 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | ||||
} else { | } else { | ||||
if (protected_size < size) { | if (protected_size < size) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||||
protected_size, size); | |||||
std::string error = "Failed to operate the dst memory, protected_size is " + FmtToStr(protected_size) + | |||||
" and size is " + FmtToStr(size); | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | ||||
@@ -345,11 +347,7 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -39,8 +40,9 @@ bool CheckShape(Format format, const ShapeVector &shape) { | |||||
case FORMAT_NHWC: | case FORMAT_NHWC: | ||||
return CheckShapeValid(shape, kDimSize4D); | return CheckShapeValid(shape, kDimSize4D); | ||||
default: | default: | ||||
GELOGE(PARAM_INVALID, "Not support trans format between %s and FORMAT_FRACTAL_ZZ.", | |||||
TypeUtils::FormatToSerialString(format).c_str()); | |||||
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + | |||||
" and FORMAT_FRACTAL_ZZ is not supported."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -103,12 +105,7 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (args.src_shape != expect_src_shape) { | |||||
GELOGE(PARAM_INVALID, | |||||
"Failed to trans format from %s to %s, invalid relationship between src shape %s and dst shape %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.src_shape).c_str(), | |||||
ShapeToString(args.dst_shape).c_str()); | |||||
if (!IsTransShapeSrcCorrect(args, expect_src_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -289,11 +286,7 @@ Status FormatTransferFractalZz::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return TransFormatFromNdToFracZz(args, result, hw_shape); | return TransFormatFromNdToFracZz(args, result, hw_shape); | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -33,9 +34,10 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_HWCN) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -59,8 +61,9 @@ Status CheckArgsForFracZToHwcn(const TransArgs &args) { | |||||
int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | int64_t n0 = Ceil(dst_shape.at(kHwcnN), static_cast<int64_t>(kNiSize)); | ||||
if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || | if (src_shape.at(kFracZHWC1) != dst_shape.at(kHwcnH) * dst_shape.at(kHwcnW) * c1 || src_shape.at(kFracZC0) != c0 || | ||||
src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | src_shape.at(kFracZNi) != kNiSize || src_shape.at(kFracZN0) != n0) { | ||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||||
std::string error = "Failed to check relationship between src shape" + FmtToStr(ShapeToString(src_shape)) + | |||||
" and dst shape" + FmtToStr(ShapeToString(dst_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -33,9 +34,10 @@ Status CheckArgsForFracZToNchw(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NCHW) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NCHW) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -33,9 +34,10 @@ Status CheckArgsForFracZToNhwc(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NHWC) { | if (args.src_format != FORMAT_FRACTAL_Z || args.dst_format != FORMAT_NHWC) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -50,9 +51,10 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in | |||||
Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | Status CheckArgsForHwcnToC1hwncoc0(const TransArgs &args) { | ||||
if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_C1HWNCoC0) { | if (args.src_format != FORMAT_HWCN || args.dst_format != FORMAT_C1HWNCoC0) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -33,9 +34,10 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NCHW) { | if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NCHW) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -33,9 +34,10 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
auto src_shape = args.src_shape; | auto src_shape = args.src_shape; | ||||
auto dst_shape = args.dst_shape; | auto dst_shape = args.dst_shape; | ||||
if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NHWC) { | if (args.src_format != FORMAT_NC1HWC0 || args.dst_format != FORMAT_NHWC) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -280,11 +280,7 @@ Status FormatTransferNchwToFZC04::TransFormat(const TransArgs &args, TransResult | |||||
return ret; | return ret; | ||||
} | } | ||||
if (!args_tmp.dst_shape.empty() && args_tmp.dst_shape != expect_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, the dst shape %s is invalid, expect %s", | |||||
TypeUtils::FormatToSerialString(args_tmp.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args_tmp.dst_format).c_str(), ShapeToString(args_tmp.dst_shape).c_str(), | |||||
ShapeToString(expect_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args_tmp, expect_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -53,9 +54,10 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | ||||
if (args.src_format != FORMAT_NCHW || args.dst_format != FORMAT_NC1HWC0) { | if (args.src_format != FORMAT_NCHW || args.dst_format != FORMAT_NC1HWC0) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
std::vector<int64_t> expect_5d_shape; | std::vector<int64_t> expect_5d_shape; | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -51,9 +52,10 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { | ||||
if (args.src_format != FORMAT_NHWC || args.dst_format != FORMAT_NC1HWC0) { | if (args.src_format != FORMAT_NHWC || args.dst_format != FORMAT_NC1HWC0) { | ||||
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Dose not support trans format from " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (!CheckDataTypeSupported(args.src_data_type)) { | if (!CheckDataTypeSupported(args.src_data_type)) { | ||||
@@ -48,28 +48,31 @@ std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | |||||
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
if (src_shape.empty()) { | if (src_shape.empty()) { | ||||
std::string error = "Failed to transpose, empty src shape"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); | ||||
return false; | return false; | ||||
} | } | ||||
for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
if (dim < 0) { | if (dim < 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
if (perm_arg.size() != src_shape.size()) { | if (perm_arg.size() != src_shape.size()) { | ||||
GELOGE(PARAM_INVALID, | |||||
"Failed to transpose, the size of src shape(%zu) and" | |||||
" perm arg(%zu) are different", | |||||
src_shape.size(), perm_arg.size()); | |||||
std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + " and perm arg" + | |||||
FmtToStr(perm_arg.size()) + " are different"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
std::vector<int64_t> exists(perm_arg.size()); | std::vector<int64_t> exists(perm_arg.size()); | ||||
for (auto perm : perm_arg) { | for (auto perm : perm_arg) { | ||||
if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | if (perm < 0 || static_cast<size_t>(perm) >= perm_arg.size() || ++exists[perm] > 1) { | ||||
GELOGE(PARAM_INVALID, "Failed to transpose, duplicated perm arg %ld, perm arg %s", perm, | |||||
JoinToString(perm_arg).c_str()); | |||||
std::string error = | |||||
"Failed to transpose, duplicated perm arg " + FmtToStr(perm) + ", perm arg " + FmtToStr(JoinToString(perm_arg)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -192,9 +195,10 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
} | } | ||||
auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | auto expected_shape = TransShapeByPerm(src_shape, perm_arg); | ||||
if (dst_shape != expected_shape) { | if (dst_shape != expected_shape) { | ||||
GELOGE(PARAM_INVALID, "Failed to trans axis for perm_arg %s, invalid dst shape %s, expect %s", | |||||
ShapeToString(perm_arg).c_str(), ShapeToString(dst_shape).c_str(), ShapeToString(expected_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
std::string error = "Failed to trans axis for perm_arg" + FmtToStr(ShapeToString(perm_arg)) + | |||||
", invalid dst shape" + FmtToStr(ShapeToString(dst_shape)) + ", expect" + | |||||
FmtToStr(ShapeToString(expected_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
} | } | ||||
return Transpose(data, src_shape, src_data_type, perm_arg, result); | return Transpose(data, src_shape, src_data_type, perm_arg, result); | ||||
@@ -203,14 +207,18 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | ||||
auto dst_iter = perm_args.find(src_format); | auto dst_iter = perm_args.find(src_format); | ||||
if (dst_iter == perm_args.end()) { | if (dst_iter == perm_args.end()) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
auto iter = dst_iter->second.find(dst_format); | auto iter = dst_iter->second.find(dst_format); | ||||
if (iter == dst_iter->second.end()) { | if (iter == dst_iter->second.end()) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
std::string error = "Failed to trans shape, do not support transpose from format " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
perm = iter->second; | perm = iter->second; | ||||
@@ -223,11 +231,7 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
if (args.dst_shape != expected_shape) { | |||||
GELOGE(PARAM_INVALID, "Failed to trans format from %s to %s, invalid dst shape %s, expect %s", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str(), ShapeToString(args.dst_shape).c_str(), | |||||
ShapeToString(expected_shape).c_str()); | |||||
if (!IsTransShapeDstCorrect(args, expected_shape)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -26,6 +26,7 @@ | |||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -34,9 +35,10 @@ namespace formats { | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArgs &args, TransResult &result) { | ||||
auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Failed to trans data from format " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -58,9 +60,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
args.dst_format = dst_format; | args.dst_format = dst_format; | ||||
auto transfer = BuildFormatTransfer(args); | auto transfer = BuildFormatTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from format %s to %s, unsupport now", | |||||
TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | |||||
std::string error = "Failed to trans data from format " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " + | |||||
FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -70,9 +73,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransShape(Format src_form | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastArgs &args, TransResult &result) { | ||||
auto transfer = BuildDataTypeTransfer(args); | auto transfer = BuildDataTypeTransfer(args); | ||||
if (transfer == nullptr) { | if (transfer == nullptr) { | ||||
GELOGE(UNSUPPORTED, "Failed to trans data from datatype %s to %s, unsupport now", | |||||
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
std::string error = "Failed to trans data from datatype " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)); | |||||
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
@@ -20,6 +20,7 @@ | |||||
#include "common/formats/utils/formats_definitions.h" | #include "common/formats/utils/formats_definitions.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -29,8 +30,9 @@ int64_t GetCubeSizeByDataType(DataType data_type) { | |||||
// Current cube does not support 4 bytes and longer data | // Current cube does not support 4 bytes and longer data | ||||
auto size = GetSizeByDataType(data_type); | auto size = GetSizeByDataType(data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to get cube size, the data type %s is invalid", | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
std::string error = "Failed to get cube size, the data type " + | |||||
FmtToStr(TypeUtils::DataTypeToSerialString(data_type)) + " is invalid"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return -1; | return -1; | ||||
} else if (size == 1) { | } else if (size == 1) { | ||||
return kCubeSize * 2; // 32 bytes cube size | return kCubeSize * 2; // 32 bytes cube size | ||||
@@ -57,7 +59,8 @@ int64_t GetItemNumByShape(const std::vector<int64_t> &shape) { | |||||
bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | bool CheckShapeValid(const std::vector<int64_t> &shape, const int64_t expect_dims) { | ||||
if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | if (expect_dims <= 0 || shape.size() != static_cast<size_t>(expect_dims)) { | ||||
GELOGE(PARAM_INVALID, "Invalid shape, dims num %zu, expect %ld", shape.size(), expect_dims); | |||||
std::string error = "Invalid shape, dims num " + FmtToStr(shape.size()) + ", expect " + FmtToStr(expect_dims); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
return IsShapeValid(shape); | return IsShapeValid(shape); | ||||
@@ -70,11 +73,13 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : shape) { | for (auto dim : shape) { | ||||
if (dim < 0) { | if (dim < 0) { | ||||
GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
std::string error = "Invalid negative dims in the shape " + FmtToStr(ShapeToString(shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
if (dim != 0 && kShapeItemNumMAX / dim < num) { | if (dim != 0 && kShapeItemNumMAX / dim < num) { | ||||
GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | |||||
std::string error = "Shape overflow, the total count should be less than " + FmtToStr(kShapeItemNumMAX); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
num *= dim; | num *= dim; | ||||
@@ -94,5 +99,29 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst) { | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
if (args.src_shape != expect_shape) { | |||||
std::string error = "Failed to trans format from" + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + | |||||
" to " + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + | |||||
", invalid relationship between src shape " + FmtToStr(ShapeToString(args.src_shape)) + | |||||
" and dst " + FmtToStr(ShapeToString(args.dst_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape) { | |||||
if (!args.dst_shape.empty() && args.dst_shape != expect_shape) { | |||||
std::string error = "Failed to trans format from " + FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + | |||||
" to " + FmtToStr(TypeUtils::FormatToSerialString(args.dst_format)) + ", the dst shape" + | |||||
FmtToStr(ShapeToString(args.dst_shape)) + " is invalid, expect" + | |||||
FmtToStr(ShapeToString(expect_shape)); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace formats | } // namespace formats | ||||
} // namespace ge | } // namespace ge |
@@ -23,6 +23,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "register/register_format_transfer.h" | |||||
namespace ge { | namespace ge { | ||||
namespace formats { | namespace formats { | ||||
@@ -61,6 +62,10 @@ bool IsShapeValid(const std::vector<int64_t> &shape); | |||||
bool IsShapeEqual(const GeShape &src, const GeShape &dst); | bool IsShapeEqual(const GeShape &src, const GeShape &dst); | ||||
bool IsTransShapeSrcCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
bool IsTransShapeDstCorrect(const TransArgs &args, std::vector<int64_t> &expect_shape); | |||||
template <typename T> | template <typename T> | ||||
T Ceil(T n1, T n2) { | T Ceil(T n1, T n2) { | ||||
if (n1 == 0) { | if (n1 == 0) { | ||||
@@ -61,17 +61,18 @@ GE_COMMON_LOCAL_C_INCLUDES := \ | |||||
proto/tensorflow/types.proto \ | proto/tensorflow/types.proto \ | ||||
proto/tensorflow/resource_handle.proto \ | proto/tensorflow/resource_handle.proto \ | ||||
$(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
$(TOPDIR)metadef/inc \ | |||||
$(TOPDIR)graphengine/inc \ | |||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
$(TOPDIR)inc/external/graph \ | |||||
$(TOPDIR)inc/framework \ | |||||
$(TOPDIR)inc/common/util \ | |||||
$(TOPDIR)metadef/inc/external \ | |||||
$(TOPDIR)graphengine/inc/external \ | |||||
$(TOPDIR)metadef/inc/external/graph \ | |||||
$(TOPDIR)graphengine/inc/framework \ | |||||
$(TOPDIR)metadef/inc/common/util \ | |||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
$(TOPDIR)third_party/json/include \ | $(TOPDIR)third_party/json/include \ | ||||
$(TOPDIR)third_party/protobuf/include \ | $(TOPDIR)third_party/protobuf/include \ | ||||
$(TOPDIR)third_party/openssl/include/x86/include \ | $(TOPDIR)third_party/openssl/include/x86/include \ | ||||
$(TOPDIR)framework/domi \ | |||||
$(TOPDIR)framework/domi/common \ | |||||
$(TOPDIR)framework/domi/common/op \ | |||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
$(TOPDIR)graphengine/ge/common \ | $(TOPDIR)graphengine/ge/common \ | ||||
$(TOPDIR)graphengine/ge/common/op \ | $(TOPDIR)graphengine/ge/common/op \ | ||||
@@ -21,6 +21,7 @@ | |||||
#include "framework/common/string_util.h" | #include "framework/common/string_util.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "runtime/base.h" | #include "runtime/base.h" | ||||
#include "graph/load/new_model_manager/davinci_model.h" | |||||
namespace { | namespace { | ||||
const char *const kJobID = "jobID"; | const char *const kJobID = "jobID"; | ||||
@@ -39,10 +40,12 @@ const std::string kConfigNumsdev = "devNums"; | |||||
const std::string kConfigDevIdList = "devIdList"; | const std::string kConfigDevIdList = "devIdList"; | ||||
const std::string kProfStart = "prof_start"; | const std::string kProfStart = "prof_start"; | ||||
const std::string kProfStop = "prof_stop"; | const std::string kProfStop = "prof_stop"; | ||||
const std::string kProfModelSubscribe = "prof_model_subscribe"; | |||||
const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
ProfilingManager::ProfilingManager() {} | |||||
ProfilingManager::ProfilingManager() : subscribe_count_(0) {} | |||||
ProfilingManager::~ProfilingManager() {} | ProfilingManager::~ProfilingManager() {} | ||||
@@ -54,6 +57,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager &ProfilingMana | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::Init(const Options &options) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::Init(const Options &options) { | ||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
vector<int32_t>().swap(device_id_); | vector<int32_t>().swap(device_id_); | ||||
subscribe_count_ = 0; | |||||
job_id_ = options.job_id; | job_id_ = options.job_id; | ||||
GELOGI("ProfilingManager::Init job_id:%s", job_id_.c_str()); | GELOGI("ProfilingManager::Init job_id:%s", job_id_.c_str()); | ||||
@@ -380,7 +384,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingTaskDescInfo( | ||||
const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { | |||||
uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | ||||
if (reporter == nullptr) { | if (reporter == nullptr) { | ||||
@@ -400,6 +404,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
.append(std::to_string(task_id)) | .append(std::to_string(task_id)) | ||||
.append(" ") | .append(" ") | ||||
.append(std::to_string(stream_id)) | .append(std::to_string(stream_id)) | ||||
.append(" ") | |||||
.append(std::to_string(model_id)) | |||||
.append("\n")); | .append("\n")); | ||||
Msprof::Engine::ReporterData reporter_data{}; | Msprof::Engine::ReporterData reporter_data{}; | ||||
@@ -424,7 +430,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ProfilingGraphDescInfo( | ||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { | |||||
uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, const int32_t &device_id) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); | ||||
GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); | GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return;); | ||||
@@ -482,6 +488,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||||
data.append("\""); | data.append("\""); | ||||
} | } | ||||
data.append(" model_id:").append(std::to_string(model_id)); | |||||
data.append("\n"); | data.append("\n"); | ||||
Msprof::Engine::ReporterData reporter_data{}; | Msprof::Engine::ReporterData reporter_data{}; | ||||
@@ -536,7 +544,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::PluginUn | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportProfilingData( | ||||
const std::vector<TaskDescInfo> &task_desc_info, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | |||||
uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | |||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, bool check_device) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
int32_t logic_device_id = 0; | int32_t logic_device_id = 0; | ||||
rtError_t rt_ret = rtGetDevice(&logic_device_id); | rtError_t rt_ret = rtGetDevice(&logic_device_id); | ||||
@@ -545,7 +554,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr | |||||
return; | return; | ||||
} | } | ||||
GELOGI("current logic_device_id:%d", logic_device_id); | GELOGI("current logic_device_id:%d", logic_device_id); | ||||
if (!is_acl_api_mode_) { | |||||
if (check_device) { | |||||
auto ret = std::find(device_id_.begin(), device_id_.end(), logic_device_id); | auto ret = std::find(device_id_.begin(), device_id_.end(), logic_device_id); | ||||
if (ret == device_id_.end()) { | if (ret == device_id_.end()) { | ||||
GELOGE(FAILED, "get valid phy_device_id failed, profiling report failed."); | GELOGE(FAILED, "get valid phy_device_id failed, profiling report failed."); | ||||
@@ -553,9 +562,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::ReportPr | |||||
} | } | ||||
} | } | ||||
GELOGI("start ProfilingTaskDescInfo."); | GELOGI("start ProfilingTaskDescInfo."); | ||||
ProfilingTaskDescInfo(task_desc_info, logic_device_id); | |||||
ProfilingTaskDescInfo(model_id, task_desc_info, logic_device_id); | |||||
GELOGI("start ProfilingGraphDescInfo."); | GELOGI("start ProfilingGraphDescInfo."); | ||||
ProfilingGraphDescInfo(compute_graph_desc_info, logic_device_id); | |||||
ProfilingGraphDescInfo(model_id, compute_graph_desc_info, logic_device_id); | |||||
GELOGI("Report profiling data for GE end."); | GELOGI("Report profiling data for GE end."); | ||||
#endif | #endif | ||||
} | } | ||||
@@ -573,6 +582,102 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t ProfilingManager::GetP | |||||
return module; | return module; | ||||
} | } | ||||
void ProfilingManager::UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | |||||
if (prof_type == kProfModelSubscribe) { | |||||
if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { | |||||
subs_dev_module_[device_id].subscribe_count++; | |||||
} else { | |||||
DeviceSubsInfo dev_info; | |||||
dev_info.module = module; | |||||
dev_info.subscribe_count = 1; | |||||
subs_dev_module_[device_id] = dev_info; | |||||
} | |||||
} else if (prof_type == kProfModelUnsubscribe) { | |||||
if (subs_dev_module_.find(device_id) != subs_dev_module_.end()) { | |||||
if (subs_dev_module_[device_id].subscribe_count > 0) { | |||||
subs_dev_module_[device_id].subscribe_count--; | |||||
} | |||||
} | |||||
} else { | |||||
GELOGI("No need to update device_id module map."); | |||||
} | |||||
#endif | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelSubscribe(uint64_t module, | |||||
void *model) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | |||||
std::lock_guard<std::mutex> lock(mutex_); | |||||
uint64_t model_load_mask = module & PROF_MODEL_LOAD_MASK; | |||||
if ((subscribe_count_ == 0) && (model_load_mask == PROF_MODEL_LOAD_MASK)) { | |||||
// register framework to profiling | |||||
int32_t result = Msprof::Engine::Init(GE_PROFILING_MODULE, &engine_); | |||||
if (result != SUCCESS) { | |||||
GELOGE(FAILED, "Register profiling engine failed."); | |||||
return FAILED; | |||||
} | |||||
GELOGI("Prof subscribe: model load profiling on."); | |||||
} | |||||
subscribe_count_++; | |||||
auto davinci_model = static_cast<DavinciModel *>(model); | |||||
int32_t device_num = 1; | |||||
uint32_t device[1]; | |||||
device[0] = davinci_model->GetDeviceId(); | |||||
rtError_t rt_ret = rtProfilerStart(module, device_num, device); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(FAILED, "Runtime profiler start failed."); | |||||
return FAILED; | |||||
} | |||||
UpdateSubscribeDeviceModuleMap(kProfModelSubscribe, device[0], module); | |||||
// Report profiling data | |||||
Status p_ret = davinci_model->ReportProfilingData(false); | |||||
if (p_ret != SUCCESS) { | |||||
GELOGE(p_ret, "Report profiling data failed."); | |||||
return p_ret; | |||||
} | |||||
#endif | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfModelUnsubscribe(void *model) { | |||||
#ifdef DAVINCI_SUPPORT_PROFILING | |||||
std::lock_guard<std::mutex> lock(mutex_); | |||||
if (subscribe_count_ == 0) { | |||||
GELOGW("The profiler has not been subscribed, you do not need to cannel the subscription."); | |||||
return SUCCESS; | |||||
} | |||||
auto davinci_model = static_cast<DavinciModel *>(model); | |||||
int32_t dev_num = 1; | |||||
uint32_t device[1]; | |||||
device[0] = davinci_model->GetDeviceId(); | |||||
auto iter = subs_dev_module_.find(device[0]); | |||||
if (iter != subs_dev_module_.end()) { | |||||
if (subs_dev_module_[device[0]].subscribe_count == 1) { | |||||
rtError_t rt_ret = rtProfilerStop(subs_dev_module_[device[0]].module, dev_num, device); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(FAILED, "Runtime profiler stop failed."); | |||||
return FAILED; | |||||
} | |||||
} | |||||
UpdateSubscribeDeviceModuleMap(kProfModelUnsubscribe, device[0], subs_dev_module_[device[0]].module); | |||||
} | |||||
subscribe_count_--; | |||||
if (subscribe_count_ == 0) { | |||||
int32_t ret = Msprof::Engine::UnInit(GE_PROFILING_MODULE); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Profiling plugin uninit failed, ret:%d", ret); | |||||
return ret; | |||||
} | |||||
} | |||||
#endif | |||||
return SUCCESS; | |||||
} | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfInit(uint64_t module) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfInit(uint64_t module) { | ||||
#ifdef DAVINCI_SUPPORT_PROFILING | #ifdef DAVINCI_SUPPORT_PROFILING | ||||
std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
@@ -740,6 +845,7 @@ ProfilingManager::ProfStartProfiling(uint64_t module, const std::map<std::string | |||||
device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | ||||
} | } | ||||
GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); | GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); | ||||
rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); | rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(FAILED, "Runtime profiler config proc failed."); | GELOGE(FAILED, "Runtime profiler config proc failed."); | ||||
@@ -39,6 +39,10 @@ namespace { | |||||
const std::string GE_PROFILING_MODULE = "Framework"; | const std::string GE_PROFILING_MODULE = "Framework"; | ||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
struct DeviceSubsInfo { | |||||
uint64_t module; | |||||
uint32_t subscribe_count; | |||||
}; | |||||
// register Plugin | // register Plugin | ||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PluginImpl : public Msprof::Engine::PluginIntf { | ||||
public: | public: | ||||
@@ -73,6 +77,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
ge::Status InitFromOptions(const Options &options); | ge::Status InitFromOptions(const Options &options); | ||||
ge::Status InitFromAclCfg(const std::string &config); | ge::Status InitFromAclCfg(const std::string &config); | ||||
ge::Status StartProfiling(int32_t iter, int32_t device_id); | ge::Status StartProfiling(int32_t iter, int32_t device_id); | ||||
void UpdateSubscribeDeviceModuleMap(std::string prof_type, uint32_t device_id, uint64_t module); | |||||
ge::Status ProfModelSubscribe(uint64_t module, void *model); | |||||
ge::Status ProfModelUnsubscribe(void *model); | |||||
ge::Status ProfInit(uint64_t module); | ge::Status ProfInit(uint64_t module); | ||||
ge::Status ProfFinalize(); | ge::Status ProfFinalize(); | ||||
ge::Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | ge::Status ProfStartProfiling(uint64_t module, const std::map<std::string, std::string> &config_para); | ||||
@@ -84,13 +91,15 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
bool ProfilingModelLoadOn() const { return is_load_profiling_; } | bool ProfilingModelLoadOn() const { return is_load_profiling_; } | ||||
bool ProfilingModelExecuteOn() const; | bool ProfilingModelExecuteOn() const; | ||||
bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // only used by command pattern | bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } // only used by command pattern | ||||
bool IsAclApiMode() const { return is_acl_api_mode_; } | |||||
int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; } | int32_t GetOpTraceIterNum() const { return op_trace_iter_num_; } | ||||
void ReportProfilingData(const std::vector<TaskDescInfo> &task_desc_info, | |||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info); | |||||
void ReportProfilingData(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | |||||
const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, bool check_device); | |||||
void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | void Report(const int32_t &device_id, const string &data, Msprof::Engine::Reporter &reporter, | ||||
Msprof::Engine::ReporterData &reporter_data); | Msprof::Engine::ReporterData &reporter_data); | ||||
void ProfilingTaskDescInfo(const std::vector<TaskDescInfo> &task_desc_info, const int32_t &device_id); | |||||
void ProfilingGraphDescInfo(const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | |||||
void ProfilingTaskDescInfo(uint32_t model_id, const std::vector<TaskDescInfo> &task_desc_info, | |||||
const int32_t &device_id); | |||||
void ProfilingGraphDescInfo(uint32_t model_id, const std::vector<ComputeGraphDescInfo> &compute_graph_desc_info, | |||||
const int32_t &device_id); | const int32_t &device_id); | ||||
void SetProfilingConfig(const string &profiling_cfg); | void SetProfilingConfig(const string &profiling_cfg); | ||||
vector<int32_t> GetProfilingDeviceId() const { return device_id_; } | vector<int32_t> GetProfilingDeviceId() const { return device_id_; } | ||||
@@ -121,7 +130,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
string system_trace_conf_; | string system_trace_conf_; | ||||
string task_trace_conf_; | string task_trace_conf_; | ||||
const ProfilingEngineImpl engine_; | const ProfilingEngineImpl engine_; | ||||
map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module | |||||
map<int32_t, uint64_t> device_id_module_map_; // key: device_id, value: profiling on module | |||||
map<uint32_t, DeviceSubsInfo> subs_dev_module_; // key: device_id, value: profiling on module | |||||
uint32_t subscribe_count_; | |||||
std::mutex mutex_; | std::mutex mutex_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -54,6 +54,7 @@ const std::map<std::string, std::string> PROFILE_COMPONENT_MAP{ | |||||
{"runtime", RTS_PROFILE}, | {"runtime", RTS_PROFILE}, | ||||
}; | }; | ||||
const std::string PROFILE_CONFIG = "config"; | const std::string PROFILE_CONFIG = "config"; | ||||
const std::string PROFILE_MODEL_ID = "modelId"; | |||||
REGISTER_OPTYPE_DEFINE(DATA, "Data"); | REGISTER_OPTYPE_DEFINE(DATA, "Data"); | ||||
REGISTER_OPTYPE_DEFINE(AIPPDATA, "AippData"); | REGISTER_OPTYPE_DEFINE(AIPPDATA, "AippData"); | ||||
@@ -1060,6 +1060,19 @@ Status GeExecutor::ReleaseSingleOpResource(void *stream) { | |||||
return SingleOpManager::GetInstance().ReleaseResource(stream); | return SingleOpManager::GetInstance().ReleaseResource(stream); | ||||
} | } | ||||
Status GeExecutor::GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id) { | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
GE_CHECK_NOTNULL(model_manager); | |||||
auto davinci_model = model_manager->GetModel(model_id); | |||||
if (davinci_model == nullptr) { | |||||
GELOGE(FAILED, "Model id: %d is invaild or model is not loaded.", model_id); | |||||
return FAILED; | |||||
} | |||||
device_id = davinci_model->GetDeviceId(); | |||||
return SUCCESS; | |||||
} | |||||
Status GeExecutor::GetBatchInfoSize(uint32_t model_id, size_t &shape_count) { | Status GeExecutor::GetBatchInfoSize(uint32_t model_id, size_t &shape_count) { | ||||
std::vector<std::vector<int64_t>> batch_info; | std::vector<std::vector<int64_t>> batch_info; | ||||
int32_t dynamic_type = static_cast<int32_t>(FIXED); | int32_t dynamic_type = static_cast<int32_t>(FIXED); | ||||
@@ -72,9 +72,13 @@ local_ge_executor_c_include := \ | |||||
proto/task.proto \ | proto/task.proto \ | ||||
proto/om.proto \ | proto/om.proto \ | ||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
$(TOPDIR)inc/external/graph \ | |||||
$(TOPDIR)inc/framework \ | |||||
$(TOPDIR)metadef/inc/external \ | |||||
$(TOPDIR)graphengine/inc/external \ | |||||
$(TOPDIR)metadef/inc/external/graph \ | |||||
$(TOPDIR)graphengine/inc/framework \ | |||||
$(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
$(TOPDIR)metadef/inc \ | |||||
$(TOPDIR)graphengine/inc \ | |||||
$(LOCAL_PATH)/../ \ | $(LOCAL_PATH)/../ \ | ||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
@@ -287,11 +287,15 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
proto/tensorflow/versions.proto \ | proto/tensorflow/versions.proto \ | ||||
$(LOCAL_PATH) ./ \ | $(LOCAL_PATH) ./ \ | ||||
$(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
$(TOPDIR)metadef/inc \ | |||||
$(TOPDIR)graphengine/inc \ | |||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
$(TOPDIR)inc/external/graph \ | |||||
$(TOPDIR)inc/framework \ | |||||
$(TOPDIR)inc/framework/common \ | |||||
$(TOPDIR)inc/common \ | |||||
$(TOPDIR)metadef/inc/external \ | |||||
$(TOPDIR)graphengine/inc/external \ | |||||
$(TOPDIR)metadef/inc/external/graph \ | |||||
$(TOPDIR)graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/inc/framework/common \ | |||||
$(TOPDIR)metadef/inc/common \ | |||||
$(TOPDIR)inc/runtime \ | $(TOPDIR)inc/runtime \ | ||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
$(TOPDIR)ops/built-in/op_proto/inc \ | $(TOPDIR)ops/built-in/op_proto/inc \ | ||||
@@ -301,7 +305,7 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
third_party/opencv/include \ | third_party/opencv/include \ | ||||
ANALYZER_LOCAL_INCLUDES := \ | ANALYZER_LOCAL_INCLUDES := \ | ||||
$(TOPDIR)framework/domi/analyzer \ | |||||
$(TOPDIR)graphengine/ge/analyzer \ | |||||
NEW_OMG_HOST_SRC_FILES := \ | NEW_OMG_HOST_SRC_FILES := \ | ||||
graph/preprocess/insert_op/util_insert_aipp_op.cc \ | graph/preprocess/insert_op/util_insert_aipp_op.cc \ | ||||
@@ -341,15 +345,18 @@ DEVICE_LOCAL_C_INCLUDES := \ | |||||
proto/tensorflow/versions.proto \ | proto/tensorflow/versions.proto \ | ||||
$(LOCAL_PATH) ./ \ | $(LOCAL_PATH) ./ \ | ||||
$(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
$(TOPDIR)metadef/inc \ | |||||
$(TOPDIR)graphengine/inc \ | |||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
$(TOPDIR)inc/external/graph \ | |||||
$(TOPDIR)inc/common/util \ | |||||
$(TOPDIR)inc/framework \ | |||||
$(TOPDIR)inc/framework/common \ | |||||
$(TOPDIR)metadef/inc/external \ | |||||
$(TOPDIR)graphengine/inc/external \ | |||||
$(TOPDIR)metadef/inc/external/graph \ | |||||
$(TOPDIR)metadef/inc/common/util \ | |||||
$(TOPDIR)graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/inc/framework/common \ | |||||
$(TOPDIR)inc/runtime \ | $(TOPDIR)inc/runtime \ | ||||
$(TOPDIR)ops/built-in/op_proto/inc \ | $(TOPDIR)ops/built-in/op_proto/inc \ | ||||
$(TOPDIR)framework/domi \ | |||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
$(TOPDIR)toolchain/ide/ide-daemon/external \ | $(TOPDIR)toolchain/ide/ide-daemon/external \ | ||||
third_party/json/include \ | third_party/json/include \ | ||||
@@ -17,12 +17,15 @@ ops_kernel_builder_src_files := ops_kernel_store/ge_local_ops_kernel_builder.cc | |||||
local_lib_inc_path := proto/task.proto \ | local_lib_inc_path := proto/task.proto \ | ||||
${LOCAL_PATH} \ | ${LOCAL_PATH} \ | ||||
${TOPDIR}inc \ | ${TOPDIR}inc \ | ||||
${TOPDIR}metadef/inc \ | |||||
${TOPDIR}graphengine/inc \ | |||||
${TOPDIR}inc/external \ | ${TOPDIR}inc/external \ | ||||
${TOPDIR}inc/external/graph \ | |||||
${TOPDIR}metadef/inc/external \ | |||||
${TOPDIR}graphengine/inc/external \ | |||||
${TOPDIR}metadef/inc/external/graph \ | |||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
${TOPDIR}third_party/protobuf/include \ | ${TOPDIR}third_party/protobuf/include \ | ||||
${TOPDIR}inc/framework \ | |||||
$(TOPDIR)framework/domi \ | |||||
${TOPDIR}graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
#compiler for host | #compiler for host | ||||
@@ -300,6 +300,8 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
hybrid/hybrid_davinci_model.cc \ | hybrid/hybrid_davinci_model.cc \ | ||||
executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
ir_build/ge_ir_build.cc \ | |||||
ir_build/atc_ir_common.cc \ | |||||
LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
@@ -311,16 +313,20 @@ RUNNER_LOCAL_C_INCLUDES := \ | |||||
$(LOCAL_PATH)/../ \ | $(LOCAL_PATH)/../ \ | ||||
$(LOCAL_PATH)/../../ \ | $(LOCAL_PATH)/../../ \ | ||||
$(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
$(TOPDIR)inc/common \ | |||||
$(TOPDIR)metadef/inc \ | |||||
$(TOPDIR)graphengine/inc \ | |||||
$(TOPDIR)metadef/inc/common \ | |||||
$(TOPDIR)inc/external \ | $(TOPDIR)inc/external \ | ||||
$(TOPDIR)inc/external/graph \ | |||||
$(TOPDIR)inc/framework \ | |||||
$(TOPDIR)inc/framework/common \ | |||||
$(TOPDIR)inc/graph \ | |||||
$(TOPDIR)metadef/inc/external \ | |||||
$(TOPDIR)graphengine/inc/external \ | |||||
$(TOPDIR)metadef/inc/external/graph \ | |||||
$(TOPDIR)graphengine/inc/external/ge \ | |||||
$(TOPDIR)graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/inc/framework/common \ | |||||
$(TOPDIR)metadef/inc/graph \ | |||||
$(TOPDIR)inc/runtime \ | $(TOPDIR)inc/runtime \ | ||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
$(TOPDIR)ops/built-in/op_proto/inc \ | $(TOPDIR)ops/built-in/op_proto/inc \ | ||||
$(TOPDIR)framework/domi/analyzer \ | |||||
$(TOPDIR)graphengine/ge/analyzer \ | $(TOPDIR)graphengine/ge/analyzer \ | ||||
$(TOPDIR)toolchain/ide/ide-daemon/external \ | $(TOPDIR)toolchain/ide/ide-daemon/external \ | ||||
proto/fwk_adapter.proto \ | proto/fwk_adapter.proto \ | ||||
@@ -403,6 +409,7 @@ LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) | |||||
LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ | LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc \ | ||||
../../out/ge/lib64/stub/ge_prof.cc \ | ../../out/ge/lib64/stub/ge_prof.cc \ | ||||
../../out/ge/lib64/stub/ge_ir_build.cc \ | |||||
LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
@@ -26,7 +26,7 @@ | |||||
namespace ge { | namespace ge { | ||||
LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} | LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} | ||||
Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
Status LabelAllocator::AssignFunctionalLabels() { | |||||
if (compute_graph_ == nullptr) { | if (compute_graph_ == nullptr) { | ||||
GELOGE(INTERNAL_ERROR, "ComputeGraph not set, Assign labels failed."); | GELOGE(INTERNAL_ERROR, "ComputeGraph not set, Assign labels failed."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -42,7 +42,7 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
} | } | ||||
// Add label for functional op. | // Add label for functional op. | ||||
label_index = 0; | |||||
uint32_t label_index = 0; | |||||
for (auto node : functional_nodes) { | for (auto node : functional_nodes) { | ||||
LabelMakerPtr maker = LabelMakerFactory::Instance().Create(node->GetType(), compute_graph_, node); | LabelMakerPtr maker = LabelMakerFactory::Instance().Create(node->GetType(), compute_graph_, node); | ||||
if (maker == nullptr) { | if (maker == nullptr) { | ||||
@@ -56,7 +56,8 @@ Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { | |||||
} | } | ||||
} | } | ||||
GELOGI("AssignFunctionalLabels success."); | |||||
(void)AttrUtils::SetInt(*compute_graph_, ATTR_MODEL_LABEL_NUM, label_index); | |||||
GELOGI("AssignFunctionalLabels success, Num: %u.", label_index); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -66,13 +67,29 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::set<Node | |||||
return false; | return false; | ||||
} | } | ||||
NodePtr parent = graph->GetParentNode(); | |||||
if (parent == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", graph->GetName().c_str()); | |||||
if (graph->GetGraphUnknownFlag()) { | |||||
GELOGD("Graph[%s] is unknown graph, skip label allocator.", graph->GetName().c_str()); | |||||
return true; | |||||
} | |||||
NodePtr func_node = graph->GetParentNode(); | |||||
if (func_node == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "Parent functional node not set: %s.", graph->GetName().c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
(void)functional_nodes.insert(parent); // unique functional node. | |||||
ComputeGraphPtr owner_graph = func_node->GetOwnerComputeGraph(); | |||||
if (owner_graph == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "ComputeGraph owner not set: %s.", func_node->GetName().c_str()); | |||||
return false; | |||||
} | |||||
if (owner_graph->GetGraphUnknownFlag()) { | |||||
GELOGD("Graph[%s] is unknown graph, skip label allocator.", owner_graph->GetName().c_str()); | |||||
return true; | |||||
} | |||||
(void)functional_nodes.insert(func_node); // unique functional node. | |||||
return true; | return true; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -28,7 +28,7 @@ class LabelAllocator { | |||||
explicit LabelAllocator(const ComputeGraphPtr &graph); | explicit LabelAllocator(const ComputeGraphPtr &graph); | ||||
~LabelAllocator() = default; | ~LabelAllocator() = default; | ||||
Status AssignFunctionalLabels(uint32_t &label_index); | |||||
Status AssignFunctionalLabels(); | |||||
private: | private: | ||||
bool CollectFunctionalNode(ComputeGraphPtr &graph, std::set<NodePtr> &functional_nodes); | bool CollectFunctionalNode(ComputeGraphPtr &graph, std::set<NodePtr> &functional_nodes); | ||||
@@ -348,7 +348,11 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr | |||||
auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | auto compute_graph = subgraph->subgraph_info.GetSubGraph(); | ||||
for (NodePtr &node : compute_graph->GetDirectNode()) { | for (NodePtr &node : compute_graph->GetDirectNode()) { | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) { | |||||
if (node->GetOpDesc()->HasAttr(ATTR_NAME_RTS_LABEL_NODE)) { | |||||
node->GetOpDesc()->SetStreamId(context.default_stream); | |||||
GELOGD("Node %s of type %s in subgraph %s is assigned parent stream %ld (engine: %s).", node->GetName().c_str(), | |||||
node->GetType().c_str(), subgraph->name.c_str(), context.default_stream, engine_name.c_str()); | |||||
} else if (IsEngineSkip(*subgraph) && node->GetInNodes().empty()) { | |||||
GELOGD("Node %s of type %s in subgraph %s doesn't need to assign a stream (engine: %s).", | GELOGD("Node %s of type %s in subgraph %s doesn't need to assign a stream (engine: %s).", | ||||
node->GetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str()); | node->GetName().c_str(), node->GetType().c_str(), subgraph->name.c_str(), engine_name.c_str()); | ||||
} else { | } else { | ||||
@@ -885,6 +885,15 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
GELOGI("Unreusable block."); | GELOGI("Unreusable block."); | ||||
continue; | continue; | ||||
} | } | ||||
std::string batch_label; | |||||
if (reusable_block->IsSameLabel(batch_label)) { | |||||
std::string op_label; | |||||
(void)ge::AttrUtils::GetStr(node_op_desc, ATTR_NAME_BATCH_LABEL, op_label); | |||||
if (batch_label != op_label) { | |||||
GELOGI("label diff, op name %s", node_op_desc->GetName().c_str()); | |||||
continue; | |||||
} | |||||
} | |||||
// A node can reuse blocks of the same stream and preorder streams | // A node can reuse blocks of the same stream and preorder streams | ||||
if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { | ||||
@@ -11,12 +11,15 @@ local_lib_src_files := memory_assigner.cc \ | |||||
local_lib_inc_path := ${LOCAL_PATH} \ | local_lib_inc_path := ${LOCAL_PATH} \ | ||||
${TOPDIR}inc \ | ${TOPDIR}inc \ | ||||
${TOPDIR}metadef/inc \ | |||||
${TOPDIR}graphengine/inc \ | |||||
${TOPDIR}inc/external \ | ${TOPDIR}inc/external \ | ||||
${TOPDIR}inc/external/graph \ | |||||
${TOPDIR}metadef/inc/external \ | |||||
${TOPDIR}graphengine/inc/external \ | |||||
${TOPDIR}metadef/inc/external/graph \ | |||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
${TOPDIR}third_party/protobuf/include \ | ${TOPDIR}third_party/protobuf/include \ | ||||
${TOPDIR}inc/framework \ | |||||
$(TOPDIR)framework/domi \ | |||||
${TOPDIR}graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
#compiler for host | #compiler for host | ||||
@@ -24,7 +24,6 @@ | |||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
#include "graph/buffer.h" | #include "graph/buffer.h" | ||||
#include "graph/build/label_allocator.h" | |||||
#include "graph/build/stream_allocator.h" | #include "graph/build/stream_allocator.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
@@ -43,7 +42,6 @@ | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/passes/memcpy_addr_async_pass.h" | |||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "memory/memory_assigner.h" | #include "memory/memory_assigner.h" | ||||
#include "omg/version.h" | #include "omg/version.h" | ||||
@@ -419,6 +417,14 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { | |||||
return FAILED); | return FAILED); | ||||
GELOGI("For model, max_mem_offset_: %zu, p2p_mem_size: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, | GELOGI("For model, max_mem_offset_: %zu, p2p_mem_size: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, | ||||
p2p_mem_offset_, zero_copy_mem_size_); | p2p_mem_offset_, zero_copy_mem_size_); | ||||
string fp_ceiling_mode; | |||||
if (ge::GetContext().GetOption("ge.fpCeilingMode", fp_ceiling_mode) == SUCCESS) { | |||||
if (!ge::AttrUtils::SetStr(&model, ATTR_FP_CEILING_MODE, fp_ceiling_mode)) { | |||||
GELOGE(FAILED, "Failed to set attr ATTR_FP_CEILING_MODE"); | |||||
return FAILED; | |||||
} | |||||
GELOGI("Set attr ATTR_FP_CEILING_MODE to model, value is %s.", fp_ceiling_mode.c_str()); | |||||
} | |||||
string ge_core_type; | string ge_core_type; | ||||
Status ret = ge::GetContext().GetOption(kCoreType, ge_core_type); | Status ret = ge::GetContext().GetOption(kCoreType, ge_core_type); | ||||
@@ -695,25 +701,8 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { | |||||
GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); | GE_TIMESTAMP_END(AssignLogicalStreams, "GraphBuilder::AssignLogicalStreams"); | ||||
// Assign functional op labels. | // Assign functional op labels. | ||||
GE_TIMESTAMP_START(AssignFunctionalLabels); | |||||
LabelAllocator label_allocator(compute_graph_); | |||||
GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); | |||||
GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); | |||||
// Add memcpy_addr_async node. | |||||
rtFeatureType_t feature_type = FEATURE_TYPE_MEMCPY; | |||||
int32_t feature_info = MEMCPY_INFO_SUPPORT_ZEROCOPY; | |||||
int64_t value = 0; | |||||
rtError_t rt_ret = rtGetRtCapability(feature_type, feature_info, &value); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtGetRtCapability failed."); | |||||
return RT_FAILED; | |||||
} else { | |||||
GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode); | |||||
MemcpyAddrAsyncPass memcpy_addr; | |||||
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph_), "Add memcpy_addr_async node failed."); | |||||
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | |||||
} | |||||
auto root_graph = GraphUtils::FindRootGraph(compute_graph_); | |||||
(void)AttrUtils::GetInt(*root_graph, ATTR_MODEL_LABEL_NUM, label_num_); | |||||
GE_TIMESTAMP_START(AssignMemory); | GE_TIMESTAMP_START(AssignMemory); | ||||
MemoryAssigner mem_assigner(compute_graph_); | MemoryAssigner mem_assigner(compute_graph_); | ||||
@@ -80,4 +80,13 @@ bool TransOpUtil::CheckPrecisionLoss(const ge::NodePtr &src_node) { | |||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
std::string TransOpUtil::TransopMapToString() { | |||||
std::string buffer; | |||||
for (auto &key : Instance().transop_index_map_) { | |||||
buffer += key.first + " "; | |||||
} | |||||
return buffer; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -35,6 +35,8 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY TransOpUtil { | |||||
static bool CheckPrecisionLoss(const NodePtr &src_node); | static bool CheckPrecisionLoss(const NodePtr &src_node); | ||||
static std::string TransopMapToString(); | |||||
private: | private: | ||||
TransOpUtil(); | TransOpUtil(); | ||||
@@ -23,75 +23,65 @@ | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
namespace { | |||||
const int64_t kInvalidStreamId = -1; | |||||
} // namespace | |||||
namespace ge { | namespace ge { | ||||
/** | /** | ||||
* @ingroup ge | * @ingroup ge | ||||
* @brief Set stream id for head node. | |||||
* @brief Link node to graph head. | |||||
* @param [in] graph: graph for add node. | * @param [in] graph: graph for add node. | ||||
* @param [in] op_desc: OpDesc for set logical stream id. | |||||
* @param [in] node: Node add to graph head. | |||||
* @return: void | * @return: void | ||||
*/ | */ | ||||
void LabelMaker::SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
int64_t stream_id = kInvalidStreamId; | |||||
const auto &node_list = graph->GetDirectNode(); | |||||
for (size_t i = 0; i < node_list.size(); ++i) { | |||||
const auto &node = node_list.at(i); | |||||
GE_CHECK_NOTNULL_EXEC(node, continue); | |||||
void LabelMaker::LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
static const std::set<std::string> non_calc_types = {DATA, CONSTANT, CONSTANTOP, VARIABLE}; | |||||
for (auto &n : graph->GetDirectNode()) { | |||||
if (non_calc_types.count(n->GetType()) > 0) { | |||||
continue; | |||||
} | |||||
stream_id = node->GetOpDesc()->GetStreamId(); | |||||
if (stream_id != kInvalidStreamId) { | |||||
break; | |||||
const auto nodes = n->GetInDataNodes(); | |||||
if (nodes.empty()) { | |||||
continue; | |||||
} | } | ||||
} | |||||
GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
op_desc->SetStreamId(stream_id); | |||||
} | |||||
bool is_head_node = true; | |||||
for (auto &in_node : nodes) { | |||||
if (non_calc_types.count(in_node->GetType()) == 0) { | |||||
is_head_node = false; | |||||
break; | |||||
} | |||||
} | |||||
/** | |||||
* @ingroup ge | |||||
* @brief Set stream id for tail node. | |||||
* @param [in] graph: graph for add node. | |||||
* @param [in] op_desc: OpDesc for set logical stream id. | |||||
* @return: void | |||||
*/ | |||||
void LabelMaker::SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
int64_t stream_id = kInvalidStreamId; | |||||
const auto &node_list = graph->GetDirectNode(); | |||||
for (size_t i = node_list.size(); i > 0; --i) { | |||||
const auto &node = node_list.at(i - 1); // i from list size, need shift 1. | |||||
GE_CHECK_NOTNULL_EXEC(node, continue); | |||||
if (!is_head_node) { | |||||
continue; | |||||
} | |||||
stream_id = node->GetOpDesc()->GetStreamId(); | |||||
if (stream_id != kInvalidStreamId) { | |||||
break; | |||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), n->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", node->GetName().c_str(), n->GetName().c_str()); | |||||
} | } | ||||
} | } | ||||
GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
op_desc->SetStreamId(stream_id); | |||||
} | } | ||||
/** | /** | ||||
* @ingroup ge | * @ingroup ge | ||||
* @brief Set stream id for parent node. | |||||
* @brief Link node to graph tail. | |||||
* @param [in] graph: graph for add node. | * @param [in] graph: graph for add node. | ||||
* @param [in] op_desc: OpDesc for set logical stream id. | |||||
* @param [in] node: Node add to graph tail. | |||||
* @return: void | * @return: void | ||||
*/ | */ | ||||
void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||||
int64_t stream_id = kInvalidStreamId; | |||||
const auto &node = graph->GetParentNode(); | |||||
if (node != nullptr) { | |||||
stream_id = node->GetOpDesc()->GetStreamId(); | |||||
} | |||||
void LabelMaker::LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
auto tail = graph->FindFirstNodeMatchType(NETOUTPUT); | |||||
while (tail != nullptr) { | |||||
auto nodes = tail->GetOutControlNodes(); | |||||
if (!nodes.empty()) { | |||||
tail = nodes.at(0); | |||||
continue; | |||||
} | |||||
GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
op_desc->SetStreamId(stream_id); | |||||
if (GraphUtils::AddEdge(tail->GetOutControlAnchor(), node->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", tail->GetName().c_str(), node->GetName().c_str()); | |||||
} | |||||
return; | |||||
} | |||||
} | } | ||||
/** | /** | ||||
@@ -112,7 +102,7 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE); | OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMACTIVE); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str()); | GELOGI("StreamActive: Create node %s.", op_desc->GetName().c_str()); | ||||
vector<uint32_t> active_streams; | vector<uint32_t> active_streams; | ||||
@@ -122,6 +112,7 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str | |||||
NodePtr stream_active = graph->AddNodeFront(op_desc); | NodePtr stream_active = graph->AddNodeFront(op_desc); | ||||
GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | ||||
LinkToGraphHead(graph, stream_active); | |||||
return stream_active; | return stream_active; | ||||
} | } | ||||
@@ -146,7 +137,7 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
@@ -173,19 +164,9 @@ NodePtr LabelMaker::AddLabelSetEnter(const ComputeGraphPtr &graph, const std::st | |||||
NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | ||||
GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
const auto &node_list = graph->GetDirectNode(); | |||||
auto it = node_list.end(); | |||||
if (it == node_list.begin()) { | |||||
GELOGE(INTERNAL_ERROR, "LabelSet: Graph %s node is empty.", graph->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
--it; | |||||
const NodePtr &node = *it; | |||||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSET); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
@@ -194,11 +175,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st | |||||
GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | ||||
// Link control edge to graph tail. | // Link control edge to graph tail. | ||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_set->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "LabelSet: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
LinkToGraphTail(graph, label_set); | |||||
return label_set; | return label_set; | ||||
} | } | ||||
@@ -222,7 +199,7 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
@@ -246,32 +223,17 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::string &name, uint32_t index) { | ||||
GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
const auto &node_list = graph->GetDirectNode(); | |||||
auto it = node_list.end(); | |||||
if (it == node_list.begin()) { | |||||
GELOGE(INTERNAL_ERROR, "LabelGoto: Graph %s node is empty.", graph->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
--it; | |||||
const NodePtr &node = *it; | |||||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdLeave(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelGoto: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
NodePtr label_goto = graph->AddNode(op_desc); | NodePtr label_goto = graph->AddNode(op_desc); | ||||
GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_goto, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
// Link control edge to graph tail. | // Link control edge to graph tail. | ||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_goto->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "LabelGoto: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
LinkToGraphTail(graph, label_goto); | |||||
return label_goto; | return label_goto; | ||||
} | } | ||||
@@ -297,7 +259,7 @@ NodePtr LabelMaker::AddLabelSwitchEnter(const ComputeGraphPtr &graph, const std: | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | ||||
if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
@@ -332,19 +294,9 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: | |||||
const std::vector<uint32_t> &labels) { | const std::vector<uint32_t> &labels) { | ||||
GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | ||||
const auto &node_list = graph->GetDirectNode(); | |||||
auto it = node_list.end(); | |||||
if (it == node_list.begin()) { | |||||
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Graph %s node is empty.", graph->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
--it; | |||||
const NodePtr &node = *it; | |||||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELSWITCHBYINDEX); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdOwner(graph, op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, true); | |||||
GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSwitchByIndex: Create node %s.", op_desc->GetName().c_str()); | ||||
if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
@@ -361,11 +313,7 @@ NodePtr LabelMaker::AddLabelSwitchLeave(const ComputeGraphPtr &graph, const std: | |||||
GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_switch, return nullptr); | ||||
// Link control edge to graph tail. | // Link control edge to graph tail. | ||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), label_switch->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "LabelSwitchByIndex: Add ctrl edge to %s failed.", node->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
LinkToGraphTail(graph, label_switch); | |||||
return label_switch; | return label_switch; | ||||
} | } | ||||
@@ -385,7 +333,6 @@ NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std: | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
op_desc->SetStreamId(kInvalidStreamId); | |||||
GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | ||||
if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(desc) != GRAPH_SUCCESS) { | ||||
@@ -60,9 +60,8 @@ class LabelMaker { | |||||
ComputeGraphPtr parent_graph_; | ComputeGraphPtr parent_graph_; | ||||
private: | private: | ||||
void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||||
void LinkToGraphHead(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
void LinkToGraphTail(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ | #endif // GE_GRAPH_PASSES_LABEL_MAKER_H_ |
@@ -86,6 +86,7 @@ class DataDumper { | |||||
void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } | void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } | ||||
const DumpProperties &GetDumpProperties() const { return dump_properties_; } | const DumpProperties &GetDumpProperties() const { return dump_properties_; } | ||||
bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; | bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; | ||||
const std::vector<OpDescInfo> &GetAllOpDescInfo() const { return op_desc_info_; } | |||||
// Dump exception info | // Dump exception info | ||||
Status DumpExceptionInput(const OpDescInfo &op_desc_info, const string &dump_file); | Status DumpExceptionInput(const OpDescInfo &op_desc_info, const string &dump_file); | ||||
@@ -88,6 +88,9 @@ const uint32_t kDataMemAlignSizeCompare = 64; | |||||
const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; | const uint32_t kDumpL1FusionOpMByteSize = 2 * 1024 * 1024; | ||||
const uint32_t kDumpFlagOfL1Fusion = 0; | const uint32_t kDumpFlagOfL1Fusion = 0; | ||||
const char *const kDefaultBatchLable = "Batch_default"; | const char *const kDefaultBatchLable = "Batch_default"; | ||||
const int32_t kInvalidStream = -1; | |||||
const uint32_t kEndOfSequence = 0x0704000a; | |||||
const uint32_t kEndOfSequenceNew = 507005; | |||||
inline bool IsDataOp(const std::string &node_type) { | inline bool IsDataOp(const std::string &node_type) { | ||||
return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; | return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; | ||||
@@ -259,7 +262,6 @@ Status DavinciModel::Assign(const GeModelPtr &ge_model) { | |||||
/// | /// | ||||
void DavinciModel::Shrink() { | void DavinciModel::Shrink() { | ||||
ge_model_.reset(); // delete object. | ge_model_.reset(); // delete object. | ||||
op_list_.clear(); | |||||
} | } | ||||
Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | ||||
@@ -612,7 +614,9 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
GE_DISMISS_GUARD(stream); | GE_DISMISS_GUARD(stream); | ||||
stream_list_.push_back(stream); | stream_list_.push_back(stream); | ||||
GELOGD("Stream index:%u, stream:%p.", i, stream); | |||||
int32_t rt_stream_id = kInvalidStream; | |||||
(void)rtGetStreamId(stream, &rt_stream_id); | |||||
GELOGI("Logical stream index:%u, stream:%p, rtstream: %d.", i, stream, rt_stream_id); | |||||
} | } | ||||
for (uint32_t i = 0; i < EventNum(); i++) { | for (uint32_t i = 0; i < EventNum(); i++) { | ||||
@@ -654,18 +658,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
GE_IF_BOOL_EXEC(IsBroadCastOpData(node), | GE_IF_BOOL_EXEC(IsBroadCastOpData(node), | ||||
(void)ge::AttrUtils::SetStr(op_desc, VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); | (void)ge::AttrUtils::SetStr(op_desc, VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); | ||||
} | } | ||||
// for profiling | |||||
op_name_map_ = compute_graph->GetGraphOpName(); | |||||
vector<string> op_name; | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_TASK_INDEX_OP_NAME, op_name), | |||||
GELOGI("get str of task_index_op_name")); | |||||
if (op_name_map_.empty()) { | |||||
for (size_t idx = 0; idx < op_name.size(); idx++) { | |||||
op_name_map_[idx] = op_name[idx]; | |||||
} | |||||
GELOGI("Infer profiling: op_name_size(%zu)", op_name.size()); | |||||
} | |||||
GE_CHK_STATUS_RET(InitNodes(compute_graph), "Init nodes failed"); | GE_CHK_STATUS_RET(InitNodes(compute_graph), "Init nodes failed"); | ||||
@@ -677,7 +669,9 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
auto all_dump_model = GetDumpProperties().GetAllDumpModel(); | auto all_dump_model = GetDumpProperties().GetAllDumpModel(); | ||||
bool findByOmName = all_dump_model.find(om_name_) != all_dump_model.end(); | bool findByOmName = all_dump_model.find(om_name_) != all_dump_model.end(); | ||||
bool findByModelName = all_dump_model.find(name_) != all_dump_model.end(); | bool findByModelName = all_dump_model.find(name_) != all_dump_model.end(); | ||||
if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || findByOmName || findByModelName) { | |||||
bool dump_l1fusion_op = | |||||
(all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end()) || findByOmName || findByModelName; | |||||
if (dump_l1fusion_op) { | |||||
// malloc 2M for dump l1fusion op | // malloc 2M for dump l1fusion op | ||||
GE_CHK_RT_RET(rtMalloc(&l1_fusion_addr_, kDumpL1FusionOpMByteSize, RT_MEMORY_DDR)); | GE_CHK_RT_RET(rtMalloc(&l1_fusion_addr_, kDumpL1FusionOpMByteSize, RT_MEMORY_DDR)); | ||||
@@ -691,16 +685,21 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer(); | need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer(); | ||||
(void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name_); | (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name_); | ||||
string fp_ceiling_mode; | |||||
if (ge::AttrUtils::GetStr(ge_model_, ATTR_FP_CEILING_MODE, fp_ceiling_mode)) { | |||||
GELOGI("Get attr ATTR_FP_CEILING_MODE from model, value is %s.", fp_ceiling_mode.c_str()); | |||||
// mode 0: Do not perform saturation processing. By default, IEEE754 is used. | |||||
GE_CHK_RT_RET(rtSetCtxINFMode((fp_ceiling_mode != "0"))); | |||||
} | |||||
// collect profiling for ge | // collect profiling for ge | ||||
if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | |||||
std::vector<ComputeGraphDescInfo> compute_graph_desc_info; | |||||
Status ret1 = GetComputeGraphInfo(compute_graph, compute_graph_desc_info); | |||||
if (ret1 != SUCCESS) { | |||||
GELOGE(ret1, "GetComputeGraphInfo failed."); | |||||
return ret1; | |||||
auto &profiling_manager = ProfilingManager::Instance(); | |||||
if (profiling_manager.ProfilingModelLoadOn()) { | |||||
Status p_ret = ReportProfilingData(!profiling_manager.IsAclApiMode()); | |||||
if (p_ret != SUCCESS) { | |||||
GELOGE(p_ret, "Report profiling data failed."); | |||||
return p_ret; | |||||
} | } | ||||
ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); | |||||
GE_CHK_STATUS(SinkModelProfile(), "Sink model profile failed."); | |||||
} | } | ||||
Shrink(); | Shrink(); | ||||
@@ -708,6 +707,20 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
return ret; | return ret; | ||||
} | } | ||||
Status DavinciModel::ReportProfilingData(bool check_device) { | |||||
std::vector<ComputeGraphDescInfo> compute_graph_desc_info; | |||||
Status ret = GetComputeGraphInfo(compute_graph_desc_info); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "GetComputeGraphInfo failed."); | |||||
return ret; | |||||
} | |||||
ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo(), compute_graph_desc_info, check_device); | |||||
GE_CHK_STATUS(SinkModelProfile(), "Sink model profiler failed."); | |||||
op_list_.clear(); | |||||
return SUCCESS; | |||||
} | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Travel all nodes and determine if destruction is required. | /// @brief Travel all nodes and determine if destruction is required. | ||||
@@ -2572,7 +2585,7 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
GE_TIMESTAMP_START(rtStreamSynchronize); | GE_TIMESTAMP_START(rtStreamSynchronize); | ||||
GELOGI("rtStreamSynchronize start."); | GELOGI("rtStreamSynchronize start."); | ||||
rt_ret = rtStreamSynchronize(model->rt_model_stream_); | rt_ret = rtStreamSynchronize(model->rt_model_stream_); | ||||
if (rt_ret == RT_ERROR_END_OF_SEQUENCE) { | |||||
if (rt_ret == kEndOfSequence || rt_ret == kEndOfSequenceNew) { | |||||
seq_end_flag = true; | seq_end_flag = true; | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
@@ -2901,34 +2914,25 @@ Status DavinciModel::DistributeTask() { | |||||
SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); | SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); | ||||
} | } | ||||
} | } | ||||
// get op_name by task_index | |||||
if (task->GetCtx() != nullptr) { | |||||
auto iter = op_name_map_.find(task_index); | |||||
if (iter == op_name_map_.end()) { | |||||
continue; | |||||
} | |||||
// else task index is found in op_name_map_ | |||||
TaskDescInfo task_desc_info; | |||||
string op_name = op_name_map_[task_index]; | |||||
if (!om_name_.empty()) { | |||||
task_desc_info.model_name = om_name_; | |||||
} else { | |||||
task_desc_info.model_name = name_; | |||||
} | |||||
task_desc_info.op_name = op_name; | |||||
task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); | |||||
task_desc_info.task_id = task->GetTaskID(); | |||||
task_desc_info.stream_id = task->GetStreamId(); | |||||
task_desc_info_.emplace_back(task_desc_info); | |||||
if (flag) { | |||||
if (task->GetSktTaskID() != 0xFFFFFFFF) { | |||||
TaskDescInfo task_desc_info; | |||||
string op_name = "super_kernel_" + to_string(task_index); | |||||
task_desc_info.op_name = op_name; | |||||
task_desc_info.task_id = task->GetSktTaskID(); | |||||
task_desc_info_.emplace_back(task_desc_info); | |||||
} | |||||
// Load task info for profiling | |||||
TaskDescInfo task_desc_info; | |||||
if (!om_name_.empty()) { | |||||
task_desc_info.model_name = om_name_; | |||||
} else { | |||||
task_desc_info.model_name = name_; | |||||
} | |||||
task_desc_info.op_name = op->GetName(); | |||||
task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); | |||||
task_desc_info.task_id = task->GetTaskID(); | |||||
task_desc_info.stream_id = task->GetStreamId(); | |||||
task_desc_info_.emplace_back(task_desc_info); | |||||
if (flag) { | |||||
if (task->GetSktTaskID() != 0xFFFFFFFF) { | |||||
TaskDescInfo task_desc_info; | |||||
string op_name = "super_kernel_" + to_string(task_index); | |||||
task_desc_info.op_name = op_name; | |||||
task_desc_info.task_id = task->GetSktTaskID(); | |||||
task_desc_info_.emplace_back(task_desc_info); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -3818,50 +3822,31 @@ void DavinciModel::SaveHcclFollowStream(int64_t main_stream_id, rtStream_t strea | |||||
main_follow_stream_mapping_[main_stream_id].emplace_back(stream); | main_follow_stream_mapping_[main_stream_id].emplace_back(stream); | ||||
} | } | ||||
Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector<ComputeGraphDescInfo> &graph_desc_info) { | |||||
Status DavinciModel::GetComputeGraphInfo(vector<ComputeGraphDescInfo> &graph_desc_info) { | |||||
GELOGI("GetComputeGraphInfo start."); | GELOGI("GetComputeGraphInfo start."); | ||||
for (auto &node : graph->GetAllNodes()) { | |||||
auto &all_op_desc = data_dumper_.GetAllOpDescInfo(); | |||||
for (auto &op_desc : all_op_desc) { | |||||
ComputeGraphDescInfo compute_graph_info; | ComputeGraphDescInfo compute_graph_info; | ||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(PARAM_INVALID, "op_desc is nullptr."); | |||||
return PARAM_INVALID; | |||||
if (!om_name_.empty()) { | |||||
compute_graph_info.model_name = om_name_; | |||||
} else { | |||||
compute_graph_info.model_name = name_; | |||||
} | } | ||||
compute_graph_info.op_name = op_desc.op_name; | |||||
compute_graph_info.op_type = op_desc.op_type; | |||||
compute_graph_info.input_format = op_desc.input_format; | |||||
compute_graph_info.input_shape = op_desc.input_shape; | |||||
compute_graph_info.input_data_type = op_desc.input_data_type; | |||||
compute_graph_info.output_format = op_desc.output_format; | |||||
compute_graph_info.output_shape = op_desc.output_shape; | |||||
compute_graph_info.output_data_type = op_desc.output_data_type; | |||||
auto op_mode = static_cast<uint32_t>(domi::ImplyType::INVALID); | |||||
if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, op_mode) && | |||||
op_mode == static_cast<uint32_t>(domi::ImplyType::TVM)) { | |||||
if (!om_name_.empty()) { | |||||
compute_graph_info.model_name = om_name_; | |||||
} else { | |||||
compute_graph_info.model_name = name_; | |||||
} | |||||
compute_graph_info.op_name = op_desc->GetName(); | |||||
compute_graph_info.op_type = op_desc->GetType(); | |||||
for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | |||||
GeTensorDescPtr input_desc = op_desc->MutableInputDesc(i); | |||||
if (input_desc == nullptr) { | |||||
continue; | |||||
} | |||||
compute_graph_info.input_format.emplace_back(input_desc->GetFormat()); | |||||
compute_graph_info.input_shape.emplace_back(input_desc->GetShape().GetDims()); | |||||
compute_graph_info.input_data_type.emplace_back(input_desc->GetDataType()); | |||||
} | |||||
for (size_t j = 0; j < op_desc->GetOutputsSize(); ++j) { | |||||
GeTensorDesc output_desc = op_desc->GetOutputDesc(j); | |||||
compute_graph_info.output_format.emplace_back(output_desc.GetFormat()); | |||||
compute_graph_info.output_shape.emplace_back(output_desc.GetShape().GetDims()); | |||||
compute_graph_info.output_data_type.emplace_back(output_desc.GetDataType()); | |||||
} | |||||
graph_desc_info.emplace_back(compute_graph_info); | |||||
} | |||||
graph_desc_info.emplace_back(compute_graph_info); | |||||
} | } | ||||
GELOGI("GetComputeGraphInfo end."); | GELOGI("GetComputeGraphInfo end."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void DavinciModel::SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size) { | void DavinciModel::SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size) { | ||||
if (tensor_name_to_fixed_addr_size_.find(tensor_name) == tensor_name_to_fixed_addr_size_.end()) { | if (tensor_name_to_fixed_addr_size_.find(tensor_name) == tensor_name_to_fixed_addr_size_.end()) { | ||||
tensor_name_to_fixed_addr_size_[tensor_name] = total_fixed_addr_size_; | tensor_name_to_fixed_addr_size_[tensor_name] = total_fixed_addr_size_; | ||||
@@ -439,6 +439,8 @@ class DavinciModel { | |||||
Status SinkTimeProfile(const InputData ¤t_data); | Status SinkTimeProfile(const InputData ¤t_data); | ||||
Status ReportProfilingData(bool check_device = true); | |||||
void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | void SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr &op, uint32_t task_id, uint32_t stream_id) { | ||||
data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); | data_dumper_.SaveDumpOpInfo(model_param, op, task_id, stream_id); | ||||
} | } | ||||
@@ -828,7 +830,7 @@ class DavinciModel { | |||||
Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); | Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); | ||||
// get desc info of graph for profiling | // get desc info of graph for profiling | ||||
Status GetComputeGraphInfo(const ComputeGraphPtr &graph, vector<ComputeGraphDescInfo> &graph_desc_info); | |||||
Status GetComputeGraphInfo(vector<ComputeGraphDescInfo> &graph_desc_info); | |||||
void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); | void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); | ||||
@@ -947,7 +949,6 @@ class DavinciModel { | |||||
std::map<std::string, uint32_t> used_tbe_handle_map_; | std::map<std::string, uint32_t> used_tbe_handle_map_; | ||||
// for profiling task and graph info | // for profiling task and graph info | ||||
std::map<uint32_t, std::string> op_name_map_; | |||||
std::vector<TaskDescInfo> task_desc_info_; | std::vector<TaskDescInfo> task_desc_info_; | ||||
int64_t maxDumpOpNum_; | int64_t maxDumpOpNum_; | ||||
@@ -43,6 +43,8 @@ const std::string kCmdTypeProfInit = "prof_init"; | |||||
const std::string kCmdTypeProfFinalize = "prof_finalize"; | const std::string kCmdTypeProfFinalize = "prof_finalize"; | ||||
const std::string kCmdTypeProfStart = "prof_start"; | const std::string kCmdTypeProfStart = "prof_start"; | ||||
const std::string kCmdTypeProfStop = "prof_stop"; | const std::string kCmdTypeProfStop = "prof_stop"; | ||||
const std::string kCmdTypeProfModelSubscribe = "prof_model_subscribe"; | |||||
const std::string kCmdTypeProfModelUnsubscribe = "prof_model_cancel_subscribe"; | |||||
const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; | const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; | ||||
const char *const kDeleteCustOp = "deleteCustOp"; | const char *const kDeleteCustOp = "deleteCustOp"; | ||||
struct CustAicpuSoBuf { | struct CustAicpuSoBuf { | ||||
@@ -334,11 +336,9 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | |||||
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
} | |||||
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
} while (0); | } while (0); | ||||
GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); | GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); | ||||
@@ -562,10 +562,15 @@ Status ModelManager::Stop(uint32_t model_id) { | |||||
/// | /// | ||||
Status ModelManager::HandleCommand(const Command &command) { | Status ModelManager::HandleCommand(const Command &command) { | ||||
static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = { | static const std::map<std::string, std::function<uint32_t(const Command &)>> cmds = { | ||||
{kCmdTypeProfile, HandleProfileCommand}, {kCmdTypeDump, HandleDumpCommand}, | |||||
{kCmdTypeProfiling, HandleAclProfilingCommand}, {kCmdTypeProfInit, HandleProfInitCommand}, | |||||
{kCmdTypeProfFinalize, HandleProfFinalizeCommand}, {kCmdTypeProfStart, HandleProfStartCommand}, | |||||
{kCmdTypeProfStop, HandleProfStopCommand}}; | |||||
{kCmdTypeProfile, HandleProfileCommand}, | |||||
{kCmdTypeDump, HandleDumpCommand}, | |||||
{kCmdTypeProfiling, HandleAclProfilingCommand}, | |||||
{kCmdTypeProfInit, HandleProfInitCommand}, | |||||
{kCmdTypeProfFinalize, HandleProfFinalizeCommand}, | |||||
{kCmdTypeProfStart, HandleProfStartCommand}, | |||||
{kCmdTypeProfStop, HandleProfStopCommand}, | |||||
{kCmdTypeProfModelSubscribe, HandleProfModelSubscribeCommand}, | |||||
{kCmdTypeProfModelUnsubscribe, HandleProfModelUnsubscribeCommand}}; | |||||
auto iter = cmds.find(command.cmd_type); | auto iter = cmds.find(command.cmd_type); | ||||
if (iter == cmds.end()) { | if (iter == cmds.end()) { | ||||
@@ -591,6 +596,76 @@ Status ModelManager::HandleAclProfilingCommand(const Command &command) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status ModelManager::GetModelByCmd(const Command &command, std::shared_ptr<DavinciModel> &davinci_model) { | |||||
if (command.cmd_params.size() < kCmdParSize) { | |||||
GELOGE(PARAM_INVALID, "When the cmd_type is '%s', the size of cmd_params must larger than 2.", | |||||
command.cmd_type.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
std::string map_key = command.cmd_params[0]; | |||||
std::string value = command.cmd_params[1]; | |||||
if (map_key == PROFILE_MODEL_ID) { | |||||
int32_t model_id = 0; | |||||
try { | |||||
model_id = std::stoi(value); | |||||
} catch (std::invalid_argument &) { | |||||
GELOGE(PARAM_INVALID, "Model id: %s is invalid.", value.c_str()); | |||||
return PARAM_INVALID; | |||||
} catch (std::out_of_range &) { | |||||
GELOGE(PARAM_INVALID, "Model id: %s is out of range.", value.c_str()); | |||||
return PARAM_INVALID; | |||||
} catch (...) { | |||||
GELOGE(FAILED, "Model id: %s cannot change to int.", value.c_str()); | |||||
return FAILED; | |||||
} | |||||
auto model_manager = ModelManager::GetInstance(); | |||||
GE_CHECK_NOTNULL(model_manager); | |||||
davinci_model = model_manager->GetModel(static_cast<uint32_t>(model_id)); | |||||
if (davinci_model == nullptr) { | |||||
GELOGE(FAILED, "Model id: %d is invaild or model is not loaded.", model_id); | |||||
return FAILED; | |||||
} | |||||
} else { | |||||
GELOGE(FAILED, "The model_id parameter is not found in the command."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ModelManager::HandleProfModelSubscribeCommand(const Command &command) { | |||||
std::shared_ptr<DavinciModel> davinci_model = nullptr; | |||||
Status ret = GetModelByCmd(command, davinci_model); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
if (ProfilingManager::Instance().ProfModelSubscribe(command.module_index, static_cast<void *>(davinci_model.get())) != | |||||
SUCCESS) { | |||||
GELOGE(FAILED, "Handle prof model subscribe failed."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ModelManager::HandleProfModelUnsubscribeCommand(const Command &command) { | |||||
std::shared_ptr<DavinciModel> davinci_model = nullptr; | |||||
Status ret = GetModelByCmd(command, davinci_model); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
if (ProfilingManager::Instance().ProfModelUnsubscribe(static_cast<void *>(davinci_model.get())) != SUCCESS) { | |||||
GELOGE(FAILED, "Handle prof model unsubscribe failed."); | |||||
return FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status ModelManager::HandleProfInitCommand(const Command &command) { | Status ModelManager::HandleProfInitCommand(const Command &command) { | ||||
uint64_t module_index = command.module_index; | uint64_t module_index = command.module_index; | ||||
if (ProfilingManager::Instance().ProfInit(module_index) != SUCCESS) { | if (ProfilingManager::Instance().ProfInit(module_index) != SUCCESS) { | ||||
@@ -973,11 +1048,9 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | |||||
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
} | |||||
davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + | |||||
timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); | GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -158,10 +158,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
static ge::Status HandleAclProfilingCommand(const Command &command); | static ge::Status HandleAclProfilingCommand(const Command &command); | ||||
static ge::Status HandleProfileCommand(const Command &command); | static ge::Status HandleProfileCommand(const Command &command); | ||||
static ge::Status HandleDumpCommand(const Command &command); | static ge::Status HandleDumpCommand(const Command &command); | ||||
static ge::Status HandleProfModelSubscribeCommand(const Command &command); | |||||
static ge::Status HandleProfModelUnsubscribeCommand(const Command &command); | |||||
static ge::Status HandleProfInitCommand(const Command &command); | static ge::Status HandleProfInitCommand(const Command &command); | ||||
static ge::Status HandleProfFinalizeCommand(const Command &command); | static ge::Status HandleProfFinalizeCommand(const Command &command); | ||||
static ge::Status HandleProfStartCommand(const Command &command); | static ge::Status HandleProfStartCommand(const Command &command); | ||||
static ge::Status HandleProfStopCommand(const Command &command); | static ge::Status HandleProfStopCommand(const Command &command); | ||||
static ge::Status GetModelByCmd(const Command &command, std::shared_ptr<DavinciModel> &davinci_model); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief get model memory usage | /// @brief get model memory usage | ||||
@@ -45,7 +45,7 @@ Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { | |||||
if (it == task_addr_offset_.end()) { | if (it == task_addr_offset_.end()) { | ||||
task_addr_offset_[addr] = {offset}; | task_addr_offset_[addr] = {offset}; | ||||
} else { | } else { | ||||
it->second.push_back(offset); | |||||
it->second.insert(offset); | |||||
} | } | ||||
GELOGI("[ZCPY] %s set task, virtual_addr: 0x%lx, args_addr: %p, size: %zu, offset: %zu", name_.c_str(), addr, | GELOGI("[ZCPY] %s set task, virtual_addr: 0x%lx, args_addr: %p, size: %zu, offset: %zu", name_.c_str(), addr, | ||||
@@ -99,7 +99,7 @@ class ZeroCopyTask { | |||||
bool is_updated_; | bool is_updated_; | ||||
string batch_label_; | string batch_label_; | ||||
// <address from Op, {offset in args}> | // <address from Op, {offset in args}> | ||||
map<uintptr_t, vector<size_t>> task_addr_offset_; | |||||
map<uintptr_t, set<size_t>> task_addr_offset_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_TASK_H_ |
@@ -100,6 +100,8 @@ | |||||
#include "graph/passes/subgraph_const_migration_pass.h" | #include "graph/passes/subgraph_const_migration_pass.h" | ||||
#include "graph/passes/unused_args_clean_pass.h" | #include "graph/passes/unused_args_clean_pass.h" | ||||
#include "graph/passes/global_step_insert_pass.h" | #include "graph/passes/global_step_insert_pass.h" | ||||
#include "graph/passes/memcpy_addr_async_pass.h" | |||||
#include "graph/build/label_allocator.h" | |||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/graph_util.h" | #include "graph/graph_util.h" | ||||
@@ -131,6 +133,22 @@ bool IsTailingOptimization() { | |||||
GELOGW("OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION not set, use BFSTopologicalSorting by default."); | GELOGW("OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION not set, use BFSTopologicalSorting by default."); | ||||
return false; | return false; | ||||
} | } | ||||
ge::Status CheckFpCeilingMode() { | |||||
static const std::unordered_set<std::string> kValidFpCeilingMode = {"0", "1", "2"}; | |||||
string mode; | |||||
auto ret = ge::GetContext().GetOption("ge.fpCeilingMode", mode); | |||||
if (ret == ge::GRAPH_SUCCESS) { | |||||
if (kValidFpCeilingMode.count(mode) == 0) { | |||||
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "The fp_ceiling_mode %s is invalid, options are 0, 1, and 2.", mode.c_str()); | |||||
return ge::GE_GRAPH_OPTIONS_INVALID; | |||||
} | |||||
GELOGI("The parameter fp_ceiling_mode is set to %s.", mode.c_str()); | |||||
return ge::SUCCESS; | |||||
} | |||||
GELOGW("The parameter fp_ceiling_mode is not set."); | |||||
return ge::SUCCESS; | |||||
} | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -162,6 +180,12 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||||
return ret; | return ret; | ||||
} | } | ||||
ret = CheckFpCeilingMode(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[Initialize] Check fp-ceiling-mode options failed."); | |||||
return ret; | |||||
} | |||||
ret = graph_context_->Initialize(options); | ret = graph_context_->Initialize(options); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[Initialize] GraphContext initialize failed."); | GELOGE(ret, "[Initialize] GraphContext initialize failed."); | ||||
@@ -325,6 +349,78 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, | |||||
const std::map<std::string, std::string> &options, | |||||
const OmgContext &omg_context) { | |||||
if (HasGraphNode(graph_id)) { | |||||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); | |||||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||||
} | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
if (compute_graph != nullptr) { | |||||
compute_graph->SetGraphID(graph_id); | |||||
bool graph_has_been_added = false; | |||||
if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) && | |||||
graph_has_been_added) { | |||||
GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] same graph object can not be added again, graph_id = %u.", | |||||
graph_id); | |||||
return GE_GRAPH_GRAPH_ALREADY_EXIST; | |||||
} | |||||
} else { | |||||
GELOGE(FAILED, "compute graph is null"); | |||||
return FAILED; | |||||
} | |||||
std::vector<NodePtr> input_nodes; | |||||
std::vector<NodePtr> output_nodes; | |||||
auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); | |||||
std::string session_graph_id; | |||||
if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || | |||||
session_graph_id.empty()) { | |||||
session_graph_id = "-1_" + to_string(graph_id); | |||||
if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { | |||||
GELOGW("Set attribute of compute graph failed."); | |||||
} | |||||
for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) { | |||||
(void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); | |||||
} | |||||
GELOGW("Get graph session_graph_id attr failed, set session id to default value: [0]"); | |||||
} | |||||
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id); | |||||
if (graph_node == nullptr) { | |||||
GELOGE(FAILED, "GraphNode make shared failed"); | |||||
return FAILED; | |||||
} | |||||
std::shared_ptr<Graph> graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph); | |||||
if (graph_ptr == nullptr) { | |||||
GELOGE(FAILED, "GraphPtr make shared failed"); | |||||
return FAILED; | |||||
} | |||||
graph_node->SetGraph(graph_ptr); | |||||
graph_node->SetOptions(options); | |||||
AddGraphNode(graph_id, graph_node); | |||||
AddLocalOmgContext(graph_id, omg_context); | |||||
if (!options_.output_datatype.empty()) { | |||||
GetLocalOmgContext().output_type = options_.output_datatype; | |||||
} | |||||
CompilerStages &stages = GetCompilerStages(graph_id); | |||||
stages.preparer.SetOptions(options_); | |||||
Status status = stages.optimizer.SetOptions(options_); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Graph optimizer set options failed."); | |||||
return status; | |||||
} | |||||
stages.builder.SetOptions(options_); | |||||
var_acc_ctrl_.AddGraph(graph_id, new_compute_graph); | |||||
GELOGI("[GraphManager] add graph success, graph_id = %u.", graph_id); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph, | Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph, | ||||
GraphId root_graph_id) { | GraphId root_graph_id) { | ||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | ||||
@@ -625,6 +721,13 @@ Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, | |||||
GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", | ||||
GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | GetCompilerStages(graph_node->GetGraphId()).optimizer.OptimizeGraphBeforeBuildForRts, | ||||
compute_graph); | compute_graph); | ||||
Status ret = compute_graph->TopologicalSorting(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); | |||||
return ret; | |||||
} | |||||
GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | ||||
GELOGI("PreRun:PreRunAfterOptimizeSubGraph success."); | GELOGI("PreRun:PreRunAfterOptimizeSubGraph success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -2170,6 +2273,18 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
return ret; | return ret; | ||||
} | } | ||||
// Assign functional op labels. | |||||
GE_TIMESTAMP_START(AssignFunctionalLabels); | |||||
LabelAllocator label_allocator(compute_graph); | |||||
GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(), "Assign label failed."); | |||||
GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); | |||||
// Add memcpy addr asynchronous node. | |||||
GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode); | |||||
MemcpyAddrAsyncPass memcpy_addr; | |||||
GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph), "Add memcpy_addr_async node failed."); | |||||
GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); | |||||
// After while sub graph handle, mark all node rw type | // After while sub graph handle, mark all node rw type | ||||
auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | auto result = GetCompilerStages(compute_graph->GetGraphID()).optimizer.HandleMemoryRWConflict(compute_graph); | ||||
if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
@@ -2180,11 +2295,6 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { | |||||
ChangeConstTypeWhenTraining(compute_graph); | ChangeConstTypeWhenTraining(compute_graph); | ||||
ret = compute_graph->TopologicalSorting(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Graph topological sort failed, ret:%d.", ret); | |||||
return ret; | |||||
} | |||||
GELOGI("End optimize after merge sub graph."); | GELOGI("End optimize after merge sub graph."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -75,6 +75,16 @@ class GraphManager { | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief add a copy graph | |||||
/// @param [in] graph_id graph id | |||||
/// @param [out] Graph output graph | |||||
/// @return Status result of function | |||||
/// | |||||
Status AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, | |||||
const std::map<std::string, std::string> &options, const OmgContext &omg_context); | |||||
/// | |||||
/// @ingroup ge_graph | |||||
/// @brief remove specific graph | /// @brief remove specific graph | ||||
/// @param [in] graph_id graph id | /// @param [in] graph_id graph id | ||||
/// @return Status result of function | /// @return Status result of function | ||||
@@ -202,7 +202,7 @@ Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | |||||
GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | |||||
mem_size = rdma_mem_size_; | mem_size = rdma_mem_size_; | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -681,6 +681,11 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | |||||
} | } | ||||
// 2.loop all node, including node in subgraph and handle memory rw conflict | // 2.loop all node, including node in subgraph and handle memory rw conflict | ||||
for (auto &node : compute_graph->GetAllNodes()) { | for (auto &node : compute_graph->GetAllNodes()) { | ||||
// ignore while subgraph node | |||||
const auto parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||||
if ((parent_node != nullptr) && (kWhileOpTypes.count(parent_node->GetType()) > 0)) { | |||||
continue; | |||||
} | |||||
// ignore data / netoutput of subgraph | // ignore data / netoutput of subgraph | ||||
if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | ||||
continue; | continue; | ||||
@@ -534,6 +534,7 @@ Status ge::GraphPartitioner::Initialize(ge::ComputeGraphPtr compute_graph) { | |||||
} | } | ||||
const NodeEngineMap *node_engine_map = graph_info_.engine_placer_.GetNodeEngineMap(); | const NodeEngineMap *node_engine_map = graph_info_.engine_placer_.GetNodeEngineMap(); | ||||
size_t temp_index = 0; | size_t temp_index = 0; | ||||
// travese nodes by topo order one by one | |||||
for (const auto &node : compute_graph->GetDirectNode()) { | for (const auto &node : compute_graph->GetDirectNode()) { | ||||
std::string temp_stream; | std::string temp_stream; | ||||
// node opdesc has been checked before | // node opdesc has been checked before | ||||
@@ -558,9 +559,21 @@ Status ge::GraphPartitioner::Initialize(ge::ComputeGraphPtr compute_graph) { | |||||
} | } | ||||
new_cluster->nodes_.push_back(node); | new_cluster->nodes_.push_back(node); | ||||
if (!HasNoInput(node)) { | if (!HasNoInput(node)) { | ||||
auto node_id = node->GetOpDesc()->GetId(); | |||||
for (const auto &parent : node->GetInAllNodes()) { | for (const auto &parent : node->GetInAllNodes()) { | ||||
new_cluster->in_clu_.insert(graph_info_.node_2_cluster_.at(parent)->index_); | |||||
graph_info_.node_2_cluster_.at(parent)->out_clu_.insert(temp_index); | |||||
auto parent_id = parent->GetOpDesc()->GetId(); | |||||
if (parent_id < node_id) { | |||||
auto iter = graph_info_.node_2_cluster_.find(parent); | |||||
if (iter == graph_info_.node_2_cluster_.end()) { | |||||
GELOGE(FAILED, | |||||
"[GraphPartitioner]: node[%s]id[%ld]'s parent_node[%s]id[%ld]" | |||||
"should make cluster in advance", | |||||
node->GetOpDesc()->GetName().c_str(), node_id, parent->GetOpDesc()->GetName().c_str(), parent_id); | |||||
return FAILED; | |||||
} | |||||
new_cluster->in_clu_.insert(iter->second->index_); | |||||
iter->second->out_clu_.insert(temp_index); | |||||
} | |||||
} | } | ||||
} | } | ||||
graph_info_.node_2_cluster_[node] = new_cluster; | graph_info_.node_2_cluster_[node] = new_cluster; | ||||
@@ -588,7 +601,7 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
auto &engine_name = graph_info_.partitions_.at(sub_graph); | auto &engine_name = graph_info_.partitions_.at(sub_graph); | ||||
GE_DUMP(sub_graph, sub_graph->GetName()); | |||||
GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | |||||
if (!session_graph_id.empty()) { | if (!session_graph_id.empty()) { | ||||
GE_IF_BOOL_EXEC(!AttrUtils::SetStr(sub_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), | GE_IF_BOOL_EXEC(!AttrUtils::SetStr(sub_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), | ||||
GELOGW("SetStr ATTR_NAME_SESSION_GRAPH_ID failed");) | GELOGW("SetStr ATTR_NAME_SESSION_GRAPH_ID failed");) | ||||
@@ -176,6 +176,8 @@ class GraphPartitioner { | |||||
Graph2InputNodesSubGraphInfo graph_2_input_subgraph_; | Graph2InputNodesSubGraphInfo graph_2_input_subgraph_; | ||||
GraphPartitionInfo graph_info_; | GraphPartitionInfo graph_info_; | ||||
uint32_t partition_times_; // times of call partition | uint32_t partition_times_; // times of call partition | ||||
std::map<Mode, std::string> mode_2_str_ = { | |||||
{kPartitioning, "Partitioning"}, {kSecondPartitioning, "SecondPartitioning"}, {kMerging, "Merging"}}; | |||||
friend class GraphManager; | friend class GraphManager; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -22,7 +22,7 @@ namespace ge { | |||||
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | |||||
if (node_type == SWITCH || node_type == SWITCHN) { | |||||
GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | ||||
const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | ||||
@@ -38,10 +38,15 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
if (node_type == IDENTITY) { | if (node_type == IDENTITY) { | ||||
GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | ||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
continue; | |||||
} | |||||
if (node_type == REFMERGE || node_type == REFSWITCH) { | |||||
GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
} | } | ||||
if (node_type == MERGE || node_type == REFMERGE) { | |||||
if (node_type == MERGE) { | |||||
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | ||||
const OpDescPtr op_desc = node->GetOpDesc(); | const OpDescPtr op_desc = node->GetOpDesc(); | ||||
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | ||||
@@ -25,6 +25,18 @@ | |||||
namespace ge { | namespace ge { | ||||
Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | ||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
if (graph->GetGraphUnknownFlag()) { | |||||
GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
int64_t value = 0; | |||||
rtError_t rt_ret = rtGetRtCapability(FEATURE_TYPE_MEMCPY, MEMCPY_INFO_SUPPORT_ZEROCOPY, &value); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "rtGetRtCapability failed, error=0x%x.", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
for (auto &node : graph->GetAllNodes()) { | for (auto &node : graph->GetAllNodes()) { | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | ||||
@@ -193,9 +205,10 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr | |||||
const OutDataAnchorPtr &out_data_anchor, | const OutDataAnchorPtr &out_data_anchor, | ||||
const NodePtr &out_of_user_data) { | const NodePtr &out_of_user_data) { | ||||
GELOGD("Start CreateMemcpyAddrAsyncNode."); | GELOGD("Start CreateMemcpyAddrAsyncNode."); | ||||
static uint32_t new_node_index = 0; | |||||
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | ||||
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); | GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); | ||||
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC; | |||||
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC + "_" + std::to_string(new_node_index++); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC); | OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYADDRASYNC); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
@@ -210,9 +223,18 @@ NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &gr | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
int64_t stream_id = out_of_user_data->GetOpDesc()->GetStreamId(); | |||||
op_desc->SetStreamId(stream_id); | |||||
GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); | |||||
string stream_label; | |||||
if (AttrUtils::GetStr(out_of_user_data->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { | |||||
(void)AttrUtils::SetStr(op_desc, ATTR_NAME_STREAM_LABEL, stream_label); | |||||
GELOGD("Node %s set stream label: %s", op_desc->GetName().c_str(), stream_label.c_str()); | |||||
} | |||||
bool rts_label_node = false; | |||||
if (AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_RTS_LABEL_NODE, rts_label_node)) { | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_RTS_LABEL_NODE, rts_label_node); | |||||
GELOGD("Node %s set rts label node attribute", op_desc->GetName().c_str()); | |||||
} | |||||
bool labeled_input = false; | bool labeled_input = false; | ||||
(void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input); | (void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input); | ||||
if (labeled_input) { | if (labeled_input) { | ||||
@@ -79,6 +79,13 @@ Status MergePass::Run(NodePtr &node) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
auto in_node = in_data_nodes.at(0); | |||||
if (IsMergeInputNeedOptimized(in_node)) { | |||||
if (IsolateAndDeleteNode(in_node, {0}) != SUCCESS) { | |||||
GELOGE(FAILED, "Isolate and delete node %s failed.", in_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return IsolateAndDeleteNode(node, merge_io_map); | return IsolateAndDeleteNode(node, merge_io_map); | ||||
} | } | ||||
default: { | default: { | ||||
@@ -173,4 +180,27 @@ Status MergePass::CreateConstByValue(NodePtr &node, int value_index, OpDescPtr & | |||||
GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); | GE_CHK_STATUS_RET(op_desc->AddOutputDesc(original_out_tensor_desc), "add out put desc failed"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { | |||||
if (node == nullptr) { | |||||
return false; | |||||
} | |||||
// node is not inserted by MergeInputMemcpyPass | |||||
if ((node->GetType() != MEMCPYASYNC) && (node->GetType() != MEMCPYADDRASYNC)) { | |||||
return false; | |||||
} | |||||
if (node->GetInDataNodes().size() != 1) { | |||||
return false; | |||||
} | |||||
auto in_node = node->GetInDataNodes().at(0); | |||||
if (in_node == nullptr) { | |||||
return false; | |||||
} | |||||
// in_node may be global_step var | |||||
if ((in_node->GetType() == VARIABLE) || (in_node->GetType() == VARIABLEV2)) { | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -28,6 +28,7 @@ class MergePass : public BaseNodePass { | |||||
bool IsNeedChangeIndexToConstant(NodePtr &node) const; | bool IsNeedChangeIndexToConstant(NodePtr &node) const; | ||||
Status ChangeIndexToConstant(NodePtr &node, int &value_index); | Status ChangeIndexToConstant(NodePtr &node, int &value_index); | ||||
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); | ||||
bool IsMergeInputNeedOptimized(NodePtr &node) const; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_ | #endif // GE_GRAPH_PASSES_MERGE_PASS_H_ |
@@ -103,6 +103,12 @@ Status NetOutputPass::GetOutputNode(const ge::ComputeGraphPtr &graph, std::vecto | |||||
GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str()); | GELOGI("user set out node [%s] is found in user def targets, out node is prio!", ele.first->GetName().c_str()); | ||||
targets_.erase(iter); | targets_.erase(iter); | ||||
} | } | ||||
auto op_desc = ele.first->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->HasAttr(ATTR_ATC_USER_DEFINE_OUTPUT_NODES)) { | |||||
is_user_define_ouput_nodes = true; | |||||
} | |||||
output_nodes_info.push_back({ele.first, ele.second, -1}); | output_nodes_info.push_back({ele.first, ele.second, -1}); | ||||
} | } | ||||
GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size()); | GELOGI("Output node set by user or leaf node, size:%zu.", output_nodes_info.size()); | ||||
@@ -414,7 +420,7 @@ Status NetOutputPass::ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, con | |||||
Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, | Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, | ||||
const ge::NodePtr &net_out_node) { | const ge::NodePtr &net_out_node) { | ||||
GE_CHECK_NOTNULL(net_out_node); | GE_CHECK_NOTNULL(net_out_node); | ||||
if (!GetLocalOmgContext().user_out_nodes.empty()) { | |||||
if (!GetLocalOmgContext().user_out_nodes.empty() || is_user_define_ouput_nodes) { | |||||
GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set."); | GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -220,6 +220,7 @@ class NetOutputPass : public GraphPass { | |||||
bool is_include_special_node_ = false; | bool is_include_special_node_ = false; | ||||
std::set<NodePtr> targets_; | std::set<NodePtr> targets_; | ||||
friend class ReUpdateNetOutputPass; | friend class ReUpdateNetOutputPass; | ||||
bool is_user_define_ouput_nodes = false; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_ | #endif // GE_GRAPH_PASSES_NET_OUTPUT_PASS_H_ |
@@ -173,14 +173,17 @@ Status NextIterationPass::FindWhileGroups() { | |||||
NodePtr next_node = nullptr; | NodePtr next_node = nullptr; | ||||
if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | if (FindTargetNode(out_node, NEXTITERATION, true, batch_label, next_node) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Get NextIteration node failed."); | |||||
GELOGE(INTERNAL_ERROR, | |||||
"Get NextIteration node failed: inputs of Merge should be Enter/NextIteration, current_Merge=%s", | |||||
out_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | batch_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); | ||||
NodePtr switch_node = nullptr; | NodePtr switch_node = nullptr; | ||||
if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | if (FindTargetNode(out_node, SWITCH, false, batch_label, switch_node) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Get Switch node failed."); | |||||
GELOGE(INTERNAL_ERROR, "Get Switch node failed: output of Merge should be Switch, current_Merge=%s", | |||||
out_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (switch_node == nullptr) { | if (switch_node == nullptr) { | ||||
@@ -189,7 +192,9 @@ Status NextIterationPass::FindWhileGroups() { | |||||
NodePtr loop_cond = nullptr; | NodePtr loop_cond = nullptr; | ||||
if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | if (FindTargetNode(switch_node, LOOPCOND, true, batch_label, loop_cond) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Get LoopCond node failed."); | |||||
GELOGE(INTERNAL_ERROR, | |||||
"Get LoopCond node failed: pred input of Switch should be LoopCond, current_Switch=%s", | |||||
switch_node->GetName().c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (batch_iter.second->loop_cond == nullptr) { | if (batch_iter.second->loop_cond == nullptr) { | ||||
@@ -217,6 +217,9 @@ NodePtr CreateTransNode(const std::string &name, const std::string &node_type, c | |||||
auto index = TransOpUtil::GetTransOpDataIndex(node_type); | auto index = TransOpUtil::GetTransOpDataIndex(node_type); | ||||
if (index < 0) { | if (index < 0) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19025", {"situation", "reason"}, | |||||
{"The trans node type[" + node_type + "]", "it must be " + TransOpUtil::TransopMapToString()}); | |||||
GELOGE(INTERNAL_ERROR, "The trans node type %s does not exists", node_type.c_str()); | GELOGE(INTERNAL_ERROR, "The trans node type %s does not exists", node_type.c_str()); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -385,6 +388,8 @@ Status RecoverTransRoadForVar(const NodePtr &var, const VarTransRoad &road) { | |||||
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); | auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); | ||||
auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node); | auto ret = RecoverOneTransNodeForVar(trans_name, *iter, last_node, last_node); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E15001", {"variable", "index", "type"}, | |||||
{var->GetName(), std::to_string(index), iter->node_type}); | |||||
GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", var->GetName().c_str(), | GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", var->GetName().c_str(), | ||||
index, iter->node_type.c_str()); | index, iter->node_type.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -417,6 +422,8 @@ Status RecoverTransRoadForVarRef(const std::set<NodePtr> &nodes, const VarTransR | |||||
auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); | auto trans_name = var->GetName() + "_trans_" + std::to_string(index++); | ||||
auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node); | auto ret = RecoverOneTransNodeForVarRef(trans_name, *iter, last_node, last_node); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E15001", {"variable", "index", "type"}, | |||||
{var->GetName(), std::to_string(index), iter->node_type}); | |||||
GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", | GELOGE(INTERNAL_ERROR, "Failed to recover trans node for variable %s, index %d, type %s", | ||||
var->GetName().c_str(), index, iter->node_type.c_str()); | var->GetName().c_str(), index, iter->node_type.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -570,6 +577,8 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node | |||||
std::string related_node_name; | std::string related_node_name; | ||||
if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { | if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { | ||||
if (related_node_name.empty()) { | if (related_node_name.empty()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E15002", {"opname", "value", "reason"}, | |||||
{data_node->GetName(), "flag", "but the value is empty"}); | |||||
GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty", | GELOGE(INTERNAL_ERROR, "The data node %s has switchn node flag, but the value is empty", | ||||
data_node->GetName().c_str()); | data_node->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -581,6 +590,9 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node | |||||
} | } | ||||
} | } | ||||
if (switchn_node == nullptr) { | if (switchn_node == nullptr) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E15002", {"opname", "value", "reason"}, | |||||
{data_node->GetName(), related_node_name, "but can not find it on the graph"}); | |||||
GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph", | GELOGE(INTERNAL_ERROR, "The data node %s has switchn node %s, but can not find it on the graph", | ||||
data_node->GetName().c_str(), related_node_name.c_str()); | data_node->GetName().c_str(), related_node_name.c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -681,6 +693,10 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No | |||||
ge::GeShape old_shape = input->GetShape(); | ge::GeShape old_shape = input->GetShape(); | ||||
bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC)); | bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC)); | ||||
if (!support) { | if (!support) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19014", {"opname", "value", "reason"}, | |||||
{op_desc->GetName(), "format[" + TypeUtils::FormatToSerialString(old_format) + "]", | |||||
"only support FORMAT_NC1HWC0,FORMAT_NCHW,FORMAT_NHWC"}); | |||||
GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); | GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -761,6 +777,8 @@ Status GetStorageFormatAndShape(OpDescPtr &op_desc, const GeTensorDescPtr &tenso | |||||
op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), | op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), | ||||
formats::JoinToString(storage_shape).c_str()); | formats::JoinToString(storage_shape).c_str()); | ||||
} else { | } else { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"15003", {"opname", "format"}, {op_desc->GetName(), TypeUtils::FormatToSerialString(storage_format)}); | |||||
GELOGE(PARAM_INVALID, | GELOGE(PARAM_INVALID, | ||||
"Update node by storage format failed, storage_shape not set. " | "Update node by storage format failed, storage_shape not set. " | ||||
"node: [%s], storage_format [%s]", | "node: [%s], storage_format [%s]", | ||||
@@ -900,9 +918,14 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node) { | |||||
// check if is_output_adjust_hw_layout is set | // check if is_output_adjust_hw_layout is set | ||||
if (NeedUpdateFormatByOutputTypeParm(op_desc, index)) { | if (NeedUpdateFormatByOutputTypeParm(op_desc, index)) { | ||||
if ((old_format != FORMAT_NCHW) && (old_format != FORMAT_NHWC) && (old_format != FORMAT_NC1HWC0)) { | if ((old_format != FORMAT_NCHW) && (old_format != FORMAT_NHWC) && (old_format != FORMAT_NC1HWC0)) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19014", {"opname", "value", "reason"}, | |||||
{op_desc->GetName(), "format[" + TypeUtils::FormatToSerialString(old_format) + "]", | |||||
"only support FORMAT_NC1HWC0,FORMAT_NCHW,FORMAT_NHWC"}); | |||||
GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); | GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GeTensorDesc old_desc(old_shape, old_format, old_dtype); | GeTensorDesc old_desc(old_shape, old_format, old_dtype); | ||||
if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(old_desc, net_output_input_desc, src_node) != SUCCESS) { | if (ProcessNetoutputNodeFp16Nc1hwc0DynShape(old_desc, net_output_input_desc, src_node) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); | GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); | ||||
@@ -1035,6 +1058,9 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i | |||||
} | } | ||||
bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); | bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); | ||||
if (!is_acceptable) { | if (!is_acceptable) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E15005", {"opname", "optype", "opname1", "optype1"}, | |||||
{op_desc->GetName(), node->GetType(), input_op_desc->GetName(), input_op_desc->GetType()}); | |||||
GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", | GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", | ||||
node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), | node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), | ||||
input_op_desc->GetType().c_str()); | input_op_desc->GetType().c_str()); | ||||
@@ -1127,6 +1153,9 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input) { | |||||
} | } | ||||
if ((index < 0) || (static_cast<size_t>(index) >= user_input.size())) { | if ((index < 0) || (static_cast<size_t>(index) >= user_input.size())) { | ||||
std::string situation = "data op index[" + std::to_string(index) + "]"; | |||||
std::string reason = "it must less than user_input size[" + std::to_string(user_input.size()) + "]"; | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | |||||
GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", user_input.size(), index); | GELOGE(PARAM_INVALID, "user_input size = %zu, graph data op index = %ld.", user_input.size(), index); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1139,6 +1168,11 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input) { | |||||
if (need_check_internal_format) { | if (need_check_internal_format) { | ||||
bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); | ||||
if (is_internal) { | if (is_internal) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19025", {"situation", "reason"}, | |||||
{"Input format[" + TypeUtils::FormatToSerialString(format) + "] or origin_format[" + | |||||
TypeUtils::FormatToSerialString(origin_format) + "]", | |||||
"it is not support"}); | |||||
GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", | GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", | ||||
TypeUtils::FormatToSerialString(format).c_str(), | TypeUtils::FormatToSerialString(format).c_str(), | ||||
TypeUtils::FormatToSerialString(origin_format).c_str()); | TypeUtils::FormatToSerialString(origin_format).c_str()); | ||||
@@ -1150,6 +1184,9 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input) { | |||||
uint32_t length = 1; | uint32_t length = 1; | ||||
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | ||||
if (!type_ret) { | if (!type_ret) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19025", {"situation", "reason"}, | |||||
{"Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "]", "it is not support"}); | |||||
GELOGE(PARAM_INVALID, "Input datatype %s is not support.", | GELOGE(PARAM_INVALID, "Input datatype %s is not support.", | ||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -1164,6 +1201,10 @@ Status GraphPrepare::UpdateInput(const std::vector<GeTensor> &user_input) { | |||||
return FAILED); | return FAILED); | ||||
bool size_check = (size != 0 && shape_size != size); | bool size_check = (size != 0 && shape_size != size); | ||||
if (size_check) { | if (size_check) { | ||||
std::string situation = | |||||
"input data size[" + std::to_string(size) + "] and shape_size[" + std::to_string(size) + "]"; | |||||
std::string reason = "because size != 0 and shape_size != size"; | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | |||||
GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); | GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1503,6 +1544,9 @@ Status GraphPrepare::VerifyConstOp(const NodePtr &node) { | |||||
uint32_t length = 1; | uint32_t length = 1; | ||||
bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | bool type_ret = TypeUtils::GetDataTypeLength(data_type, length); | ||||
if (!type_ret) { | if (!type_ret) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19025", {"situation", "reason"}, | |||||
{"Input datatype[" + TypeUtils::DataTypeToSerialString(data_type) + "]", "it is not support"}); | |||||
GELOGE(PARAM_INVALID, "Input datatype %s is not support.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | GELOGE(PARAM_INVALID, "Input datatype %s is not support.", TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -1512,14 +1556,20 @@ Status GraphPrepare::VerifyConstOp(const NodePtr &node) { | |||||
if (shape_size == 0) { | if (shape_size == 0) { | ||||
if (ge_tensor_desc.GetShape().GetDims().size() == 0) { | if (ge_tensor_desc.GetShape().GetDims().size() == 0) { | ||||
// shape = [], means it's a sclar tensor. | // shape = [], means it's a sclar tensor. | ||||
GE_CHK_BOOL_EXEC(data_size / length == 1, return PARAM_INVALID, "Const is invalid scalar tensor."); | |||||
GE_CHK_BOOL_EXEC(data_size / length == 1, ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10043", {"reason"}, {"Const is invalid scalar tensor."}); | |||||
return PARAM_INVALID, "Const is invalid scalar tensor."); | |||||
} else { | } else { | ||||
// shape = [x, y, 0,...], means it's a vector tensor that value is []. | // shape = [x, y, 0,...], means it's a vector tensor that value is []. | ||||
GE_CHK_BOOL_EXEC(data_size == 0, return PARAM_INVALID, "Const is invalid vector scalar."); | |||||
GE_CHK_BOOL_EXEC(data_size == 0, ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10043", {"reason"}, {"Const is invalid vector scalar."}); | |||||
return PARAM_INVALID, "Const is invalid vector scalar."); | |||||
} | } | ||||
} else { | } else { | ||||
GE_CHK_BOOL_EXEC(data_size == static_cast<size_t>(shape_size * length) && data_size != 0, return PARAM_INVALID, | |||||
"Const input data size is not equal with tensor desc shape"); | |||||
GE_CHK_BOOL_EXEC(data_size == static_cast<size_t>(shape_size * length) && data_size != 0, | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10043", {"reason"}, {"Const input data size is not equal with tensor desc shape"}); | |||||
return PARAM_INVALID, "Const input data size is not equal with tensor desc shape"); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1543,6 +1593,9 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
return GE_GRAPH_INIT_FAILED; | return GE_GRAPH_INIT_FAILED; | ||||
} | } | ||||
if ((index < 0) || (static_cast<size_t>(index) >= user_input.size())) { | if ((index < 0) || (static_cast<size_t>(index) >= user_input.size())) { | ||||
std::string situation = "data op index[" + std::to_string(index) + "]"; | |||||
std::string reason = "it must less than user_input size[" + std::to_string(user_input.size()) + "]"; | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | |||||
GELOGE(GE_GRAPH_INIT_FAILED, "user_input size:%zu, data op index:%ld.", user_input.size(), index); | GELOGE(GE_GRAPH_INIT_FAILED, "user_input size:%zu, data op index:%ld.", user_input.size(), index); | ||||
return GE_GRAPH_INIT_FAILED; | return GE_GRAPH_INIT_FAILED; | ||||
} | } | ||||
@@ -1550,6 +1603,10 @@ Status GraphPrepare::CheckUserInput(const std::vector<GeTensor> &user_input) { | |||||
for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | for (size_t i = 0; i < desc.GetShape().GetDimNum(); ++i) { | ||||
if (desc.GetShape().GetDim(i) < 0) { | if (desc.GetShape().GetDim(i) < 0) { | ||||
std::string situation = | |||||
"data dim[" + std::to_string(i) + "][" + std::to_string(desc.GetShape().GetDim(i)) + "]"; | |||||
std::string reason = "it need >= 0"; | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19025", {"situation", "reason"}, {situation, reason}); | |||||
GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, | GELOGE(GE_GRAPH_INIT_FAILED, "data dim %zu is not supported, need >= 0, real:%ld.", i, | ||||
desc.GetShape().GetDim(i)); | desc.GetShape().GetDim(i)); | ||||
return GE_GRAPH_INIT_FAILED; | return GE_GRAPH_INIT_FAILED; | ||||
@@ -53,16 +53,6 @@ | |||||
} \ | } \ | ||||
} while (0) | } while (0) | ||||
#define AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(expr, _status, errormsg) \ | |||||
do { \ | |||||
bool b = (expr); \ | |||||
if (!b) { \ | |||||
GELOGE(_status, errormsg); \ | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ | |||||
return _status; \ | |||||
} \ | |||||
} while (0) | |||||
namespace { | namespace { | ||||
const int32_t DEFAULT_MATRIX_R0C0_YUV2RGB = 298; | const int32_t DEFAULT_MATRIX_R0C0_YUV2RGB = 298; | ||||
const int32_t DEFAULT_MATRIX_R0C1_YUV2RGB = 0; | const int32_t DEFAULT_MATRIX_R0C1_YUV2RGB = 0; | ||||
@@ -317,9 +307,8 @@ NodePtr AippOp::FindDataByIndex(const ComputeGraphPtr &graph, int rank) { | |||||
} | } | ||||
return node; | return node; | ||||
} | } | ||||
GELOGE(PARAM_INVALID, "Can not find the data node by index %d", rank); | |||||
string errormsg = "Can not find the data node by aipp parameter related_input_rank " + to_string(rank); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); | |||||
string error_msg = "Can not find the data node by aipp parameter related_input_rank " + to_string(rank); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr &target, | Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr &target, | ||||
@@ -364,10 +353,10 @@ Status AippOp::GetAndCheckTarget(const ComputeGraphPtr &graph, int rank, NodePtr | |||||
} | } | ||||
if (!edge_indexes.empty() && (*edge_indexes.rbegin() >= data_node->GetOutDataNodes().size())) { | if (!edge_indexes.empty() && (*edge_indexes.rbegin() >= data_node->GetOutDataNodes().size())) { | ||||
GELOGE(PARAM_INVALID, "input_edge_idx %u should smaller than out edge size of target input %zu", | |||||
*edge_indexes.rbegin(), data_node->GetOutDataNodes().size()); | |||||
string errormsg = "The aipp parameter input_edge_idx should be smaller than the target input's outnodes."; | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); | |||||
string error_msg = "The aipp parameter input_edge_idx[" + std::to_string(*edge_indexes.rbegin()) + | |||||
"] should be smaller than the target input[" + | |||||
std::to_string(data_node->GetOutDataNodes().size()) + "]'s outnodes."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
target = data_node; | target = data_node; | ||||
@@ -442,8 +431,7 @@ Status AippOp::ConvertRelatedInputNameToRank() { | |||||
string error_msg = "Top name " + related_input_name + | string error_msg = "Top name " + related_input_name + | ||||
"convert rank failed, Please" | "convert rank failed, Please" | ||||
" ensure top name in aipp config is the top name of data node."; | " ensure top name in aipp config is the top name of data node."; | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg}); | |||||
GELOGE(PARAM_INVALID, "Top name[%s] converts rank failed.", related_input_name.c_str()); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -539,87 +527,87 @@ Status AippOp::SetDefaultParams() { | |||||
Status AippOp::ValidateParams() { | Status AippOp::ValidateParams() { | ||||
GE_CHECK_NOTNULL(aipp_params_); | GE_CHECK_NOTNULL(aipp_params_); | ||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->aipp_mode() != domi::AippOpParams::undefined, PARAM_INVALID, | |||||
"When insert AIPP op, aipp_mode must be configured as static or dynamic "); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_3_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_3 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_0 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_1 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_2 can not be configed repeatedly"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_edge_idx_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_edge_idx can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->aipp_mode() != domi::AippOpParams::undefined, PARAM_INVALID, | |||||
"When insert AIPP op, aipp_mode must be configured as static or dynamic "); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->var_reci_chn_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->var_reci_chn_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->var_reci_chn_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->var_reci_chn_3_size() <= 1, PARAM_INVALID, | |||||
"The parameter var_reci_chn_3 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r0c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r0c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r0c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r0c2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r1c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r1c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r1c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r1c2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r2c0_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r2c1_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->matrix_r2c2_size() <= 1, PARAM_INVALID, | |||||
"The parameter matrix_r2c2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->output_bias_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->output_bias_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->output_bias_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter output_bias_2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->input_bias_0_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_0 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->input_bias_1_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_1 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->input_bias_2_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_bias_2 can not be configed repeatedly"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->input_edge_idx_size() <= 1, PARAM_INVALID, | |||||
"The parameter input_edge_idx can not be configed repeatedly"); | |||||
const domi::AippOpParams::AippMode aipp_mode = aipp_params_->aipp_mode(); | const domi::AippOpParams::AippMode aipp_mode = aipp_params_->aipp_mode(); | ||||
if (aipp_mode == domi::AippOpParams::dynamic) { | if (aipp_mode == domi::AippOpParams::dynamic) { | ||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG( | |||||
GE_CHK_LOG_AND_ERRORMSG( | |||||
aipp_params_->max_src_image_size() > 0, PARAM_INVALID, | aipp_params_->max_src_image_size() > 0, PARAM_INVALID, | ||||
"For dynamic AIPP params, max_src_image_size must be set which number should be greater than 0"); | "For dynamic AIPP params, max_src_image_size must be set which number should be greater than 0"); | ||||
} else { | } else { | ||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_format() != domi::AippOpParams::UNDEFINED, PARAM_INVALID, | |||||
"Input format of AIPP conf is undefined"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->src_image_size_w() >= 0, PARAM_INVALID, | |||||
"Src_image_size_w must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->src_image_size_h() >= 0, PARAM_INVALID, | |||||
"Src_image_size_h must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->load_start_pos_w() >= 0, PARAM_INVALID, | |||||
"Load_start_pos_w must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->load_start_pos_h() >= 0, PARAM_INVALID, | |||||
"Load_start_pos_h must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->crop_size_w() >= 0, PARAM_INVALID, | |||||
"Crop_size_w must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->resize_output_w() >= 0, PARAM_INVALID, | |||||
"Resize_output_w must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->resize_output_h() >= 0, PARAM_INVALID, | |||||
"Resize_output_h must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->left_padding_size() >= 0, PARAM_INVALID, | |||||
"Left_padding_size must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->right_padding_size() >= 0, PARAM_INVALID, | |||||
"Right_padding_size must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->top_padding_size() >= 0, PARAM_INVALID, | |||||
"Top_padding_size must not be configed smaller than 0"); | |||||
AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->bottom_padding_size() >= 0, PARAM_INVALID, | |||||
"Bottom_padding_size must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->input_format() != domi::AippOpParams::UNDEFINED, PARAM_INVALID, | |||||
"Input format of AIPP conf is undefined"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->src_image_size_w() >= 0, PARAM_INVALID, | |||||
"Src_image_size_w must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->src_image_size_h() >= 0, PARAM_INVALID, | |||||
"Src_image_size_h must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->load_start_pos_w() >= 0, PARAM_INVALID, | |||||
"Load_start_pos_w must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->load_start_pos_h() >= 0, PARAM_INVALID, | |||||
"Load_start_pos_h must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->crop_size_w() >= 0, PARAM_INVALID, | |||||
"Crop_size_w must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->resize_output_w() >= 0, PARAM_INVALID, | |||||
"Resize_output_w must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->resize_output_h() >= 0, PARAM_INVALID, | |||||
"Resize_output_h must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->left_padding_size() >= 0, PARAM_INVALID, | |||||
"Left_padding_size must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->right_padding_size() >= 0, PARAM_INVALID, | |||||
"Right_padding_size must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->top_padding_size() >= 0, PARAM_INVALID, | |||||
"Top_padding_size must not be configed smaller than 0"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aipp_params_->bottom_padding_size() >= 0, PARAM_INVALID, | |||||
"Bottom_padding_size must not be configed smaller than 0"); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -792,17 +780,20 @@ Status AippOp::CreateAippData(const NodePtr &aipp_node) { | |||||
int64_t batch_count = -1; | int64_t batch_count = -1; | ||||
if (GetDataDimN(data_node, ori_data_format, batch_count) != ge::SUCCESS) { | if (GetDataDimN(data_node, ori_data_format, batch_count) != ge::SUCCESS) { | ||||
GELOGE(PARAM_INVALID, "Get data_node dims and transfer to nchw_dims failed!"); | |||||
string error_msg = "Get data_node dims and transfer to nchw_dims failed!"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (batch_count <= 0) { | if (batch_count <= 0) { | ||||
GELOGE(PARAM_INVALID, "Batch count %ld is invalid", batch_count); | |||||
string error_msg = "Batch count[" + std::to_string(batch_count) + "] is invalid, it must positive."; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
int64_t max_dynamic_aipp_size = CalcMaxSize(batch_count); | int64_t max_dynamic_aipp_size = CalcMaxSize(batch_count); | ||||
if (max_dynamic_aipp_size < 0) { | if (max_dynamic_aipp_size < 0) { | ||||
GELOGE(PARAM_INVALID, "The dynamic aipp size is not positive."); | |||||
string error_msg = "The dynamic aipp size is not positive"; | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -40,8 +40,6 @@ using domi::AippOpParams; | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const char *const kMbatchSwitchnName = "mbatch-switch-name"; | const char *const kMbatchSwitchnName = "mbatch-switch-name"; | ||||
const int64_t kFormatAgnosticSwitch = 1; | |||||
const int64_t kFormatDependInputIndex = 1; | |||||
} // namespace | } // namespace | ||||
static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | static void ConvertShape2Nhwc(Format &format, vector<int64_t> &shape_vec) { | ||||
if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | if ((format == FORMAT_NHWC) || (shape_vec.size() != static_cast<size_t>(NORMAL_TENSOR_SIZE))) { | ||||
@@ -127,20 +125,14 @@ Status InsertNewOpUtil::CheckInputNamePositionNotRepeat() { | |||||
string error_msg = | string error_msg = | ||||
"Can not both set related_input_name and related_input_rank!" | "Can not both set related_input_name and related_input_rank!" | ||||
" Please ensure param is the same with the first aipp config(related_input_name)."; | " Please ensure param is the same with the first aipp config(related_input_name)."; | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Can not both set related_input_rank and related_input_name!" | |||||
" Please ensure param is the same with the first aipp config(related_input_name)."); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (item->related_input_name() == another_item->related_input_name()) { | if (item->related_input_name() == another_item->related_input_name()) { | ||||
string error_msg = | string error_msg = | ||||
"Can not insert aipp to the same postion! Please ensure related_input_name" | "Can not insert aipp to the same postion! Please ensure related_input_name" | ||||
" param is different in different aipp config."; | " param is different in different aipp config."; | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Can not insert aipp op to the same postion! Please ensure related_input_rank param " | |||||
"is different in different aipp config."); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
} | } | ||||
@@ -161,20 +153,14 @@ Status InsertNewOpUtil::CheckInputRankPositionNoRepeat() { | |||||
string error_msg = | string error_msg = | ||||
"Can not both set related_input_rank and related_input_name!" | "Can not both set related_input_rank and related_input_name!" | ||||
" Please ensure param is the same with the first aipp config(related_input_rank)."; | " Please ensure param is the same with the first aipp config(related_input_rank)."; | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Can not both set related_input_rank and related_input_name!" | |||||
" Please ensure param is the same with the first aipp config(related_input_rank)."); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (item->related_input_rank() == another_item->related_input_rank()) { | if (item->related_input_rank() == another_item->related_input_rank()) { | ||||
string error_msg = | string error_msg = | ||||
"Can not insert aipp to the same postion! Please ensure related_input_rank" | "Can not insert aipp to the same postion! Please ensure related_input_rank" | ||||
" param is different in different aipp config."; | " param is different in different aipp config."; | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Can not insert aipp op to the same postion! Please ensure related_input_rank param " | |||||
"is different in different aipp config."); | |||||
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error_msg.c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
} | } | ||||
@@ -229,9 +215,9 @@ Status InsertNewOpUtil::CheckGraph(const ComputeGraphPtr &graph) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_CHK_BOOL_RET_STATUS((aippNodes.size() == 0) || (aippNodes.size() == next_nodes_cnt), PARAM_INVALID, | |||||
"Can not config part of outputs of Data node to support AIPP, config all " | |||||
"of the outputs of Data to support AIPP, or config none of them"); | |||||
GE_CHK_LOG_AND_ERRORMSG((aippNodes.size() == 0) || (aippNodes.size() == next_nodes_cnt), PARAM_INVALID, | |||||
"Can not config part of outputs of Data node to support AIPP, config all " | |||||
"of the outputs of Data to support AIPP, or config none of them"); | |||||
std::unique_ptr<domi::AippOpParams> aippParams(new (std::nothrow) domi::AippOpParams()); | std::unique_ptr<domi::AippOpParams> aippParams(new (std::nothrow) domi::AippOpParams()); | ||||
GE_CHECK_NOTNULL(aippParams); | GE_CHECK_NOTNULL(aippParams); | ||||
@@ -243,15 +229,16 @@ Status InsertNewOpUtil::CheckGraph(const ComputeGraphPtr &graph) { | |||||
GE_CHK_STATUS(GetAippParams(currAippParam, aippNodes[i])); | GE_CHK_STATUS(GetAippParams(currAippParam, aippNodes[i])); | ||||
if (aippMode == domi::AippOpParams::static_) { | if (aippMode == domi::AippOpParams::static_) { | ||||
GE_CHK_BOOL_RET_STATUS(aippParams->input_format() == currAippParam->input_format(), PARAM_INVALID, | |||||
"The input_format of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_BOOL_RET_STATUS(aippParams->src_image_size_w() == currAippParam->src_image_size_w(), PARAM_INVALID, | |||||
"The src_image_size_w of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_BOOL_RET_STATUS(aippParams->src_image_size_h() == currAippParam->src_image_size_h(), PARAM_INVALID, | |||||
"The src_image_size_h of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aippParams->input_format() == currAippParam->input_format(), PARAM_INVALID, | |||||
"The input_format of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aippParams->src_image_size_w() == currAippParam->src_image_size_w(), PARAM_INVALID, | |||||
"The src_image_size_w of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aippParams->src_image_size_h() == currAippParam->src_image_size_h(), PARAM_INVALID, | |||||
"The src_image_size_h of all aipp_ops after one Data should be the same"); | |||||
} else { | } else { | ||||
GE_CHK_BOOL_RET_STATUS(aippParams->max_src_image_size() == currAippParam->max_src_image_size(), PARAM_INVALID, | |||||
"The max_src_image_size of all aipp_ops after one Data should be the same"); | |||||
GE_CHK_LOG_AND_ERRORMSG(aippParams->max_src_image_size() == currAippParam->max_src_image_size(), | |||||
PARAM_INVALID, | |||||
"The max_src_image_size of all aipp_ops after one Data should be the same"); | |||||
} | } | ||||
}); | }); | ||||
} | } | ||||
@@ -271,23 +258,6 @@ Status InsertNewOpUtil::GetAippParams(const std::unique_ptr<domi::AippOpParams> | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status InsertNewOpUtil::AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node) { | |||||
GE_CHECK_NOTNULL(aipp_node); | |||||
auto next_nodes = aipp_node->GetOutDataNodes(); | |||||
for (const auto next_node : next_nodes) { | |||||
GE_CHECK_NOTNULL(next_node); | |||||
auto op_desc = next_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetType() == SWITCHN) { | |||||
GELOGI("Find switchn node [%s] after aipp [%s]", op_desc->GetName().c_str(), aipp_node->GetName().c_str()); | |||||
(void)AttrUtils::SetInt(op_desc, "_format_agnostic", kFormatAgnosticSwitch); | |||||
(void)AttrUtils::SetListInt(op_desc, "_format_agnostic_except_input", | |||||
std::vector<int64_t>({kFormatDependInputIndex})); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | ||||
std::map<std::string, NodePtr> switchn_names_to_data; | std::map<std::string, NodePtr> switchn_names_to_data; | ||||
std::set<NodePtr> updated_switchn; | std::set<NodePtr> updated_switchn; | ||||
@@ -302,9 +272,6 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | |||||
} | } | ||||
if (node->GetType() == AIPP) { | if (node->GetType() == AIPP) { | ||||
GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | GE_RETURN_IF_ERROR(UpdatePrevNodeByAipp(node, updated_switchn)); | ||||
// In dynamic batch/HW and dynamic aipp scend, switchn should be set format agnostic, otherwise transdata maybe | |||||
// inserted between aipp and switchn which introduce performance and memory increase problem. | |||||
GE_RETURN_IF_ERROR(AddFormatAgnosticAttrToSwitchn(node)); | |||||
} | } | ||||
if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | if (node->GetType() == CASE && node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | ||||
multbatch_case = node; | multbatch_case = node; | ||||
@@ -314,7 +281,8 @@ Status InsertNewOpUtil::UpdateDataNodeByAipp(const ComputeGraphPtr &graph) { | |||||
for (auto &switchn : updated_switchn) { | for (auto &switchn : updated_switchn) { | ||||
auto data_iter = switchn_names_to_data.find(switchn->GetName()); | auto data_iter = switchn_names_to_data.find(switchn->GetName()); | ||||
if (data_iter == switchn_names_to_data.end()) { | if (data_iter == switchn_names_to_data.end()) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to find relative data node by switchn %s", switchn->GetName().c_str()); | |||||
string error_msg = "Failed to find relative data node by switchn[" + switchn->GetName() + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error_msg.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
GE_RETURN_IF_ERROR(UpdateDataBySwitchN(switchn, data_iter->second)); | GE_RETURN_IF_ERROR(UpdateDataBySwitchN(switchn, data_iter->second)); | ||||
@@ -501,7 +469,8 @@ Status InsertNewOpUtil::UpdateDataBySwitchN(const NodePtr &switchn, const NodePt | |||||
} | } | ||||
} | } | ||||
if (max_index >= switchn->GetOpDesc()->GetOutputsSize()) { | if (max_index >= switchn->GetOpDesc()->GetOutputsSize()) { | ||||
GELOGE(INTERNAL_ERROR, "No max size found from switchn node %s", switchn->GetName().c_str()); | |||||
string error_msg = "No max size found from switchn node[" + switchn->GetName() + "]"; | |||||
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error_msg.c_str()); | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
auto output_desc = switchn->GetOpDesc()->MutableOutputDesc(max_index); | auto output_desc = switchn->GetOpDesc()->MutableOutputDesc(max_index); | ||||
@@ -68,7 +68,6 @@ class InsertNewOpUtil { | |||||
void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | void UpdateMultiBatchInputDims(const OpDescPtr &data_opdesc, Format &old_format); | ||||
Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | Status UpdatePrevNodeByAipp(NodePtr &node, std::set<NodePtr> &switchns); | ||||
Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | Status UpdateDataBySwitchN(const NodePtr &switchn, const NodePtr &data); | ||||
Status AddFormatAgnosticAttrToSwitchn(const NodePtr &aipp_node); | |||||
Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | Status GetDataRelatedNode(NodePtr &node, std::map<NodePtr, std::set<NodePtr>> &data_next_node_map); | ||||
Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | Status GetAllAipps(const NodePtr &data_node, const NodePtr &node, std::vector<NodePtr> &aipps); | ||||
Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | Status GetInputOutputInfo(NodePtr &data_node, NodePtr &aipp_node, std::string &input, std::string &output); | ||||
@@ -593,6 +593,8 @@ Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_ | |||||
} | } | ||||
auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims(); | ||||
if (!IsAllDimsPositive(dims)) { | if (!IsAllDimsPositive(dims)) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E15004", {"opname", "shape"}, | |||||
{node->GetName(), formats::ShapeToString(dims)}); | |||||
GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s", | GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s", | ||||
node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | node->GetName().c_str(), formats::ShapeToString(dims).c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -1023,6 +1025,13 @@ Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() { | |||||
} | } | ||||
Status ProcessMultiBatch(ComputeGraphPtr &graph) { | Status ProcessMultiBatch(ComputeGraphPtr &graph) { | ||||
const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE"); | |||||
if (multi_batch_with_case != nullptr) { | |||||
PassManager pass_manager; | |||||
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | |||||
return pass_manager.Run(graph); | |||||
} | |||||
std::vector<std::vector<int64_t>> shapes; | std::vector<std::vector<int64_t>> shapes; | ||||
if (!InitDynamicParams(shapes)) { | if (!InitDynamicParams(shapes)) { | ||||
GELOGD("There is no multi-batch options, no need to process multi-batch copy"); | GELOGD("There is no multi-batch options, no need to process multi-batch copy"); | ||||
@@ -124,6 +124,8 @@ Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes, | |||||
auto tmp_index = cur_data_index; | auto tmp_index = cur_data_index; | ||||
for (size_t i = 0; i < static_cast<size_t>(dynamic_dims_num); ++i) { | for (size_t i = 0; i < static_cast<size_t>(dynamic_dims_num); ++i) { | ||||
if (tmp_index >= dynamic_gear_info.size()) { | if (tmp_index >= dynamic_gear_info.size()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10045", {"name", "shape"}, | |||||
{data_name, formats::JoinToString(data_shape)}); | |||||
GELOGE(PARAM_INVALID, "Data: %s shape: %s make dynamic dims overflow", data_name.c_str(), | GELOGE(PARAM_INVALID, "Data: %s shape: %s make dynamic dims overflow", data_name.c_str(), | ||||
formats::JoinToString(data_shape).c_str()); | formats::JoinToString(data_shape).c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -131,6 +133,8 @@ Status ParserDataToDynmaicInfo(const vector<vector<int64_t>> &shapes, | |||||
one_gear.push_back(dynamic_gear_info[tmp_index++]); | one_gear.push_back(dynamic_gear_info[tmp_index++]); | ||||
} | } | ||||
} else { | } else { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10046", {"name", "shape"}, | |||||
{data_name, formats::JoinToString(data_shape)}); | |||||
GELOGE(PARAM_INVALID, "Dynamic dims num of data: %s shape: %s can not be more than one gear dynamic info size", | GELOGE(PARAM_INVALID, "Dynamic dims num of data: %s shape: %s can not be more than one gear dynamic info size", | ||||
data_name.c_str(), formats::JoinToString(data_shape).c_str()); | data_name.c_str(), formats::JoinToString(data_shape).c_str()); | ||||
return FAILED; | return FAILED; | ||||
@@ -9,12 +9,15 @@ local_lib_src_files := engine/host_cpu_engine.cc \ | |||||
local_lib_inc_path := proto/task.proto \ | local_lib_inc_path := proto/task.proto \ | ||||
${LOCAL_PATH} \ | ${LOCAL_PATH} \ | ||||
${TOPDIR}inc \ | ${TOPDIR}inc \ | ||||
${TOPDIR}metadef/inc \ | |||||
${TOPDIR}graphengine/inc \ | |||||
${TOPDIR}inc/external \ | ${TOPDIR}inc/external \ | ||||
${TOPDIR}inc/external/graph \ | |||||
${TOPDIR}metadef/inc/external \ | |||||
${TOPDIR}graphengine/inc/external \ | |||||
${TOPDIR}metadef/inc/external/graph \ | |||||
$(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
${TOPDIR}third_party/protobuf/include \ | ${TOPDIR}third_party/protobuf/include \ | ||||
${TOPDIR}inc/framework \ | |||||
$(TOPDIR)framework/domi \ | |||||
${TOPDIR}graphengine/inc/framework \ | |||||
$(TOPDIR)graphengine/ge \ | $(TOPDIR)graphengine/ge \ | ||||
#compiler for host | #compiler for host | ||||
@@ -100,7 +100,9 @@ Status SliceKernel::Compute(const OpDescPtr attr, const std::vector<ConstGeTenso | |||||
} | } | ||||
// construct tensorDesc | // construct tensorDesc | ||||
ge::GeShape output_shape(output_dims); | ge::GeShape output_shape(output_dims); | ||||
GeTensorDesc output_tensor_desc(output_shape, FORMAT_NCHW, data_type); | |||||
auto attr_output_tensor_desc = attr->GetOutputDesc(0); | |||||
GeTensorDesc output_tensor_desc(attr_output_tensor_desc); | |||||
output_tensor_desc.SetShape(output_shape); | |||||
GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc); | GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc); | ||||
if (output_ptr == nullptr) { | if (output_ptr == nullptr) { | ||||
GELOGW("make_shared ge::GeTensor failed, node name %s.", attr->GetName().c_str()); | GELOGW("make_shared ge::GeTensor failed, node name %s.", attr->GetName().c_str()); | ||||
@@ -45,16 +45,9 @@ NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { | |||||
NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} | NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} | ||||
void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | ||||
void *try_reuse_addr = nullptr; | |||||
size_t allocate_size = size; | size_t allocate_size = size; | ||||
MemStorageType mem_type = HBM; | MemStorageType mem_type = HBM; | ||||
if (attr != nullptr) { | if (attr != nullptr) { | ||||
try_reuse_addr = attr->try_reuse_addr_; | |||||
if (attr->padding_ != 0) { | |||||
// padding up to multiple of attr->padding, and add extra attr->padding_ | |||||
allocate_size = (size + 2 * attr->padding_ - 1) / attr->padding_ * attr->padding_; | |||||
GELOGD("Padding size %ld by %d. final size = %zu.", size, attr->padding_, allocate_size); | |||||
} | |||||
mem_type = attr->mem_type_; | mem_type = attr->mem_type_; | ||||
} | } | ||||
@@ -69,6 +62,17 @@ void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { | |||||
} else if (mem_type == HOST_DDR) { | } else if (mem_type == HOST_DDR) { | ||||
buffer = malloc(allocate_size); | buffer = malloc(allocate_size); | ||||
} else { | } else { | ||||
void *try_reuse_addr = nullptr; | |||||
int padding = kDefaultPadding; | |||||
if (attr != nullptr) { | |||||
try_reuse_addr = attr->try_reuse_addr_; | |||||
if (attr->padding_ > 0) { | |||||
padding = attr->padding_; | |||||
} | |||||
} | |||||
// padding up to multiple of padding, and add extra padding | |||||
allocate_size = (size + 2 * padding - 1) / padding * padding; | |||||
GELOGD("Padding size %ld by %d. final size = %zu.", size, padding, allocate_size); | |||||
buffer = MemManager::Instance() | buffer = MemManager::Instance() | ||||
.CachingInstance(RT_MEMORY_HBM) | .CachingInstance(RT_MEMORY_HBM) | ||||
.Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | .Malloc(allocate_size, reinterpret_cast<uint8_t *>(try_reuse_addr), device_id_); | ||||
@@ -105,8 +105,10 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { | |||||
vector<uint8_t> host_buffer(static_cast<unsigned long>(tensor_size)); | vector<uint8_t> host_buffer(static_cast<unsigned long>(tensor_size)); | ||||
GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx, | GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx, | ||||
output_tensor->GetSize()); | output_tensor->GetSize()); | ||||
GE_CHK_RT_RET( | |||||
rtMemcpy(host_buffer.data(), tensor_size, output_tensor->GetData(), tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
if (tensor_size > 0) { | |||||
GE_CHK_RT_RET( | |||||
rtMemcpy(host_buffer.data(), tensor_size, output_tensor->GetData(), tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
} | |||||
tensor.SetData(std::move(host_buffer)); | tensor.SetData(std::move(host_buffer)); | ||||
string session_id = std::to_string(context_->GetSessionId()); | string session_id = std::to_string(context_->GetSessionId()); | ||||
RuntimeInferenceContext *runtime_infer_ctx = nullptr; | RuntimeInferenceContext *runtime_infer_ctx = nullptr; | ||||
@@ -234,7 +236,9 @@ Status NodeDoneCallback::ProfilingReport() { | |||||
return profiling_ret; | return profiling_ret; | ||||
} | } | ||||
ProfilingManager::Instance().ReportProfilingData(task_desc_info, compute_graph_info); | |||||
auto &profiling_manager = ProfilingManager::Instance(); | |||||
profiling_manager.ReportProfilingData(model->GetModelId(), task_desc_info, compute_graph_info, | |||||
!profiling_manager.IsAclApiMode()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -249,7 +249,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
} | } | ||||
// cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
if (node_item.node_type == IF || node_item.node_type == CASE) { | |||||
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | |||||
const auto &in_anchor = ge_node->GetInDataAnchor(0); | const auto &in_anchor = ge_node->GetInDataAnchor(0); | ||||
GE_CHECK_NOTNULL(in_anchor); | GE_CHECK_NOTNULL(in_anchor); | ||||
const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); | ||||
@@ -653,6 +653,8 @@ Status HybridModelBuilder::LoadGraph() { | |||||
} else { | } else { | ||||
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), "[%s] Failed to identify ref outputs.", | GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), "[%s] Failed to identify ref outputs.", | ||||
parent_node_item->NodeName().c_str()); | parent_node_item->NodeName().c_str()); | ||||
GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), "[%s] Failed to identify same outputs.", | |||||
parent_node_item->NodeName().c_str()); | |||||
// if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
if (parent_node_item->IsControlOp()) { | if (parent_node_item->IsControlOp()) { | ||||
@@ -858,7 +860,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||||
auto parent_node = sub_graph.GetParentNode(); | auto parent_node = sub_graph.GetParentNode(); | ||||
GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
auto op_type = parent_node->GetType(); | auto op_type = parent_node->GetType(); | ||||
if (op_type == IF || op_type == CASE || op_type == WHILE) { | |||||
if (IsControlOp(op_type)) { | |||||
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), | GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), | ||||
ge_model->GetModelTaskDefPtr()->task_size()); | ge_model->GetModelTaskDefPtr()->task_size()); | ||||
subgraph_models_.emplace(sub_graph.GetName(), ge_model); | subgraph_models_.emplace(sub_graph.GetName(), ge_model); | ||||
@@ -1087,6 +1089,43 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||||
GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||||
if (net_output_node == nullptr) { | |||||
GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto net_output_desc = net_output_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(net_output_desc); | |||||
std::map<std::string, int> connected_inputs; | |||||
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (out_data_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto src_node = out_data_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
auto op_desc = src_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
auto it = connected_inputs.find(input_key); | |||||
if (it == connected_inputs.end()) { | |||||
connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||||
} else { | |||||
GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
in_data_anchor->GetIdx(), it->second, src_node->GetName().c_str(), out_data_anchor->GetIdx()); | |||||
node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | ||||
GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | ||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | ||||
@@ -57,6 +57,7 @@ class HybridModelBuilder { | |||||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
Status LoadTasks(); | Status LoadTasks(); | ||||
Status IdentifyVariableOutputs(NodeItem &node_item); | Status IdentifyVariableOutputs(NodeItem &node_item); | ||||
Status IdentifySameInputs(NodeItem &node_item); | |||||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
@@ -28,6 +28,7 @@ namespace hybrid { | |||||
namespace { | namespace { | ||||
const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | const char *const kAttrNameOriginalFusionGraph = "_original_fusion_graph"; | ||||
const char *const kNodeTypeRetVal = "_RetVal"; | const char *const kNodeTypeRetVal = "_RetVal"; | ||||
std::set<std::string> kControlOpTypes{IF, STATELESSIF, CASE, WHILE, STATELESSWHILE}; | |||||
Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) { | ||||
uint32_t parent_index = 0; | uint32_t parent_index = 0; | ||||
@@ -96,6 +97,9 @@ Status ParseFusedSubgraph(NodeItem &node_item) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | } // namespace | ||||
bool IsControlOp(const std::string &op_type) { return kControlOpTypes.count(op_type) > 0; } | |||||
NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { | NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { | ||||
this->op_desc = this->node->GetOpDesc().get(); | this->op_desc = this->node->GetOpDesc().get(); | ||||
this->node_id = this->op_desc->GetId(); | this->node_id = this->op_desc->GetId(); | ||||
@@ -145,10 +149,7 @@ Status NodeItem::Init() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool NodeItem::IsControlOp() const { | |||||
auto op_type = op_desc->GetType(); | |||||
return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; | |||||
} | |||||
bool NodeItem::IsControlOp() const { return ge::hybrid::IsControlOp(op_desc->GetType()); } | |||||
std::string NodeItem::DebugString() const { | std::string NodeItem::DebugString() const { | ||||
std::stringstream ss; | std::stringstream ss; | ||||
@@ -36,6 +36,8 @@ struct FusedSubgraph { | |||||
ComputeGraphPtr graph; | ComputeGraphPtr graph; | ||||
}; | }; | ||||
bool IsControlOp(const std::string &op_type); | |||||
// for caching static information across execution | // for caching static information across execution | ||||
struct NodeItem { | struct NodeItem { | ||||
explicit NodeItem(NodePtr node); | explicit NodeItem(NodePtr node); | ||||
@@ -79,6 +81,7 @@ struct NodeItem { | |||||
const NodeExecutor *node_executor = nullptr; | const NodeExecutor *node_executor = nullptr; | ||||
std::map<int, ge::NodePtr> ref_outputs; | std::map<int, ge::NodePtr> ref_outputs; | ||||
std::map<int, int> reuse_inputs; | std::map<int, int> reuse_inputs; | ||||
std::map<int, int> reuse_outputs; | |||||
std::vector<bool> is_input_shape_static; | std::vector<bool> is_input_shape_static; | ||||
bool is_output_shape_static = true; | bool is_output_shape_static = true; | ||||
@@ -17,8 +17,6 @@ | |||||
#include "aicore_node_executor.h" | #include "aicore_node_executor.h" | ||||
#include "cce/taskdown_common.hpp" | #include "cce/taskdown_common.hpp" | ||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
#include "init/gelib.h" | |||||
#include "hybrid/executor/hybrid_execution_context.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -27,19 +25,10 @@ REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, AiCore | |||||
AiCoreNodeTask::AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&tasks) : tasks_(std::move(tasks)) {} | AiCoreNodeTask::AiCoreNodeTask(std::vector<std::unique_ptr<AiCoreOpTask>> &&tasks) : tasks_(std::move(tasks)) {} | ||||
Status AiCoreNodeExecutor::Initialize() { | Status AiCoreNodeExecutor::Initialize() { | ||||
auto ge_lib = GELib::GetInstance(); | |||||
GE_CHECK_NOTNULL(ge_lib); | |||||
if (!ge_lib->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||||
return GE_CLI_GE_NOT_INITIALIZED; | |||||
compiler_ = TaskCompilerFactory::GetInstance().GetTaskCompiler(); | |||||
if (compiler_ != nullptr) { | |||||
GE_CHK_STATUS_RET(compiler_->Initialize(), "Failed to init aicore task compiler."); | |||||
} | } | ||||
auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||||
auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||||
GE_CHECK_NOTNULL(aic_ops_store); | |||||
compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); | |||||
GE_CHECK_NOTNULL(compiler_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -119,6 +108,12 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr & | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); | GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); | ||||
auto ori_node_name = node->GetName(); | |||||
if (compiler_ == nullptr) { | |||||
GELOGE(FAILED, "[%s] Can not find any valid aicore task compiler.", ori_node_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); | AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); | ||||
std::string shape_key; | std::string shape_key; | ||||
GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); | GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); | ||||
@@ -132,7 +127,6 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr & | |||||
} | } | ||||
std::vector<domi::TaskDef> task_defs; | std::vector<domi::TaskDef> task_defs; | ||||
auto ori_node_name = node->GetName(); | |||||
op_desc->SetName(ori_node_name + "_" + shape_key); | op_desc->SetName(ori_node_name + "_" + shape_key); | ||||
GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); | GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); | ||||
op_desc->SetName(ori_node_name); | op_desc->SetName(ori_node_name); | ||||
@@ -155,6 +149,13 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr & | |||||
Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start"); | ||||
if (IsNoOp(context)) { | |||||
GELOGD("[%s] Skipping execution for op with empty outputs", context.GetNodeName()); | |||||
auto ret = context.TryExecuteCallback(done_callback); | |||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] End"); | |||||
return ret; | |||||
} | |||||
auto op_desc = context.GetNodeItem().op_desc; | auto op_desc = context.GetNodeItem().op_desc; | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); | GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); | ||||
@@ -218,5 +219,32 @@ bool AiCoreNodeTask::IsSupportDynamicShape() { | |||||
return true; | return true; | ||||
} | } | ||||
bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) { | |||||
for (int i = 0; i < task_context.NumOutputs(); ++i) { | |||||
const auto &tensor_desc = task_context.MutableOutputDesc(i); | |||||
GE_CHECK_NOTNULL(tensor_desc); | |||||
const auto &shape = tensor_desc->MutableShape(); | |||||
if (shape.IsScalar() || shape.GetShapeSize() > 0) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
} | |||||
TaskCompilerFactory &TaskCompilerFactory::GetInstance() { | |||||
static TaskCompilerFactory instance; | |||||
return instance; | |||||
} | |||||
void TaskCompilerFactory::Register(CreateFn fn) { compiler_func_ = fn; } | |||||
std::unique_ptr<TaskCompiler> TaskCompilerFactory::GetTaskCompiler() { | |||||
auto compiler_instance = std::unique_ptr<TaskCompiler>(compiler_func_()); | |||||
return compiler_instance; | |||||
} | |||||
CompilerFunctionRegistrar::CompilerFunctionRegistrar(CreateFn fn) { TaskCompilerFactory::GetInstance().Register(fn); } | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
@@ -18,13 +18,21 @@ | |||||
#define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | #define GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | ||||
#include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
#include "hybrid/node_executor/aicore/aicore_task_compiler.h" | |||||
#include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
#include <map> | #include <map> | ||||
#include <mutex> | #include <mutex> | ||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
class TaskCompiler { | |||||
public: | |||||
TaskCompiler() = default; | |||||
virtual ~TaskCompiler() = default; | |||||
virtual Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) = 0; | |||||
virtual Status Initialize() = 0; | |||||
}; | |||||
class AiCoreNodeTaskRegistry { | class AiCoreNodeTaskRegistry { | ||||
public: | public: | ||||
~AiCoreNodeTaskRegistry() = default; | ~AiCoreNodeTaskRegistry() = default; | ||||
@@ -54,6 +62,7 @@ class AiCoreNodeTask : public NodeTask { | |||||
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
private: | private: | ||||
static bool IsNoOp(TaskContext &task_context); | |||||
std::vector<std::unique_ptr<AiCoreOpTask>> tasks_; | std::vector<std::unique_ptr<AiCoreOpTask>> tasks_; | ||||
}; | }; | ||||
@@ -65,8 +74,31 @@ class AiCoreNodeExecutor : public NodeExecutor { | |||||
private: | private: | ||||
static Status GenNodeKey(const NodePtr &node, std::string &node_key); | static Status GenNodeKey(const NodePtr &node, std::string &node_key); | ||||
std::unique_ptr<AiCoreTaskCompiler> compiler_; | |||||
std::unique_ptr<TaskCompiler> compiler_; | |||||
}; | |||||
using CreateFn = TaskCompiler *(*)(); | |||||
class TaskCompilerFactory { | |||||
public: | |||||
static TaskCompilerFactory &GetInstance(); | |||||
void Register(CreateFn fn); | |||||
std::unique_ptr<TaskCompiler> GetTaskCompiler(); | |||||
private: | |||||
CreateFn compiler_func_; | |||||
}; | |||||
class CompilerFunctionRegistrar { | |||||
public: | |||||
CompilerFunctionRegistrar(CreateFn fn); | |||||
~CompilerFunctionRegistrar() = default; | |||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge | ||||
#define REGISTER_TASK_COMPILER(compiler) \ | |||||
static ::ge::hybrid::CompilerFunctionRegistrar register_compiler_function __attribute__((unused)) = \ | |||||
::ge::hybrid::CompilerFunctionRegistrar( \ | |||||
[]() -> ::ge::hybrid::TaskCompiler * { return new (std::nothrow) compiler(); }) | |||||
#endif // GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ | #endif // GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ |
@@ -18,6 +18,7 @@ | |||||
#include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "opskernel_manager/ops_kernel_builder_manager.h" | #include "opskernel_manager/ops_kernel_builder_manager.h" | ||||
#include "init/gelib.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
@@ -25,11 +26,22 @@ namespace { | |||||
uintptr_t kWeightBase = 0x10000000; | uintptr_t kWeightBase = 0x10000000; | ||||
uintptr_t kMemBase = 0x20000000; | uintptr_t kMemBase = 0x20000000; | ||||
uint64_t kFakeSize = 0x10000000UL; | uint64_t kFakeSize = 0x10000000UL; | ||||
REGISTER_TASK_COMPILER(AiCoreTaskCompiler); | |||||
} // namespace | } // namespace | ||||
std::mutex AiCoreTaskCompiler::mu_; | std::mutex AiCoreTaskCompiler::mu_; | ||||
AiCoreTaskCompiler::AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store) | |||||
: aic_kernel_store_(std::move(aic_kernel_store)) {} | |||||
Status AiCoreTaskCompiler::Initialize() { | |||||
auto ge_lib = GELib::GetInstance(); | |||||
GE_CHECK_NOTNULL(ge_lib); | |||||
if (!ge_lib->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); | |||||
return GE_CLI_GE_NOT_INITIALIZED; | |||||
} | |||||
auto &kernel_manager = ge_lib->OpsKernelManagerObj(); | |||||
aic_kernel_store_ = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); | |||||
GE_CHECK_NOTNULL(aic_kernel_store_); | |||||
return SUCCESS; | |||||
} | |||||
Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { | Status AiCoreTaskCompiler::DoCompileOp(const NodePtr &node) const { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
@@ -19,15 +19,17 @@ | |||||
#include <mutex> | #include <mutex> | ||||
#include "opskernel_manager/ops_kernel_manager.h" | #include "opskernel_manager/ops_kernel_manager.h" | ||||
#include "aicore_node_executor.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
class AiCoreTaskCompiler { | |||||
class AiCoreTaskCompiler : public TaskCompiler { | |||||
public: | public: | ||||
explicit AiCoreTaskCompiler(OpsKernelInfoStorePtr aic_kernel_store); | |||||
AiCoreTaskCompiler() = default; | |||||
~AiCoreTaskCompiler() = default; | ~AiCoreTaskCompiler() = default; | ||||
Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks); | |||||
Status CompileOp(const NodePtr &node, std::vector<domi::TaskDef> &tasks) override; | |||||
Status Initialize() override; | |||||
private: | private: | ||||
Status DoCompileOp(const NodePtr &node) const; | Status DoCompileOp(const NodePtr &node) const; | ||||