Merge pull request !223 from yanghaoran/r1.0.1pull/223/MERGE
@@ -222,6 +222,18 @@ const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||||
// Its value should be "0" or "1", default value is "1" | // Its value should be "0" or "1", default value is "1" | ||||
const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | ||||
// Configure operator compilation path | |||||
// Its value should be file path, default value is "./" | |||||
const char *const DEBUG_DIR = "ge.debugDir"; | |||||
// Configure operator compiler cache path | |||||
// Its value should be file path, default value is "./" | |||||
const char *const OP_COMPILER_CACHE_DIR = "ge.op_compiler_cache_dir"; | |||||
// Configure operator compiler cache mode | |||||
// Its value should be "disable", "enable" or "force", default value is "disable" | |||||
const char *const OP_COMPILER_CACHE_MODE = "ge.op_compiler_cache_mode"; | |||||
// Configure whether to use single stream. | // Configure whether to use single stream. | ||||
// Its value should be "true" or "false", default value is "false" | // Its value should be "true" or "false", default value is "false" | ||||
const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | ||||
@@ -295,7 +307,9 @@ static const char *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str(); | |||||
static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); | static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); | ||||
static const char *const LOG_LEVEL = "log"; | static const char *const LOG_LEVEL = "log"; | ||||
static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c_str(); | static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c_str(); | ||||
static const char *const DEBUG_DIR = ge::DEBUG_DIR; | |||||
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; | |||||
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; | |||||
// for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
const std::set<std::string> ir_builder_suppported_options = { | const std::set<std::string> ir_builder_suppported_options = { | ||||
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, | INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, | ||||
@@ -317,7 +331,10 @@ const std::set<std::string> global_options = {CORE_TYPE, | |||||
FUSION_SWITCH_FILE, | FUSION_SWITCH_FILE, | ||||
ENABLE_SMALL_CHANNEL, | ENABLE_SMALL_CHANNEL, | ||||
OP_SELECT_IMPL_MODE, | OP_SELECT_IMPL_MODE, | ||||
OPTYPELIST_FOR_IMPLMODE}; | |||||
OPTYPELIST_FOR_IMPLMODE, | |||||
DEBUG_DIR, | |||||
OP_COMPILER_CACHE_DIR, | |||||
OP_COMPILER_CACHE_MODE}; | |||||
} // namespace ir_option | } // namespace ir_option | ||||
} // namespace ge | } // namespace ge | ||||
@@ -116,9 +116,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode { | |||||
bool HasAttr(const ge::AscendString &name); | bool HasAttr(const ge::AscendString &name); | ||||
graphStatus GetSubgraph(uint32_t index, GraphPtr graph) const; | |||||
graphStatus GetSubgraph(uint32_t index, GraphPtr &graph) const; | |||||
graphStatus GetALLSubgraphs(std::vector<GraphPtr> graph_list) const; | |||||
graphStatus GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const; | |||||
private: | private: | ||||
std::shared_ptr<NodeImpl> impl_; | std::shared_ptr<NodeImpl> impl_; | ||||
@@ -1,101 +1,101 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/** | |||||
* @file hccl_types.h | |||||
* @brief HCCL data type definition | |||||
* | |||||
*/ | |||||
#ifndef HCCL_TYPES_H_ | |||||
#define HCCL_TYPES_H_ | |||||
#include <stdint.h> | |||||
#ifdef __cplusplus | |||||
extern "C" { | |||||
#endif // __cplusplus | |||||
/**
 * @brief HCCL functions return value definition
 *
 * HCCL_SUCCESS through HCCL_E_NETWORK carry the explicit values 0-19.
 * HCCL_E_RESERVED has no initializer, so it takes the next value (20).
 */ | |||||
typedef enum { | |||||
HCCL_SUCCESS = 0, /**< success */ | |||||
HCCL_E_PARA = 1, /**< parameter error */ | |||||
HCCL_E_PTR = 2, /**< empty pointer */ | |||||
HCCL_E_MEMORY = 3, /**< memory error */ | |||||
HCCL_E_INTERNAL = 4, /**< internal error */ | |||||
HCCL_E_NOT_SUPPORT = 5, /**< not support feature */ | |||||
HCCL_E_NOT_FOUND = 6, /**< not found specific resource */ | |||||
HCCL_E_UNAVAIL = 7, /**< resource unavailable */ | |||||
HCCL_E_SYSCALL = 8, /**< call system interface error */ | |||||
HCCL_E_TIMEOUT = 9, /**< timeout */ | |||||
HCCL_E_OPEN_FILE_FAILURE = 10, /**< open file fail */ | |||||
HCCL_E_TCP_CONNECT = 11, /**< tcp connect fail */ | |||||
HCCL_E_ROCE_CONNECT = 12, /**< roce connect fail */ | |||||
HCCL_E_TCP_TRANSFER = 13, /**< tcp transfer fail */ | |||||
HCCL_E_ROCE_TRANSFER = 14, /**< roce transfer fail */ | |||||
HCCL_E_RUNTIME = 15, /**< call runtime api fail */ | |||||
HCCL_E_DRV = 16, /**< call driver api fail */ | |||||
HCCL_E_PROFILING = 17, /**< call profiling api fail */ | |||||
HCCL_E_CCE = 18, /**< call cce api fail */ | |||||
HCCL_E_NETWORK = 19, /**< call network api fail */ | |||||
HCCL_E_RESERVED /**< reserved */ | |||||
} HcclResult; | |||||
/**
 * @brief handle to HCCL communicator
 *
 * Declared as an untyped pointer; this header does not expose the pointee type.
 */ | |||||
typedef void *HcclComm; | |||||
/**
 * @brief HCCL reduction operation
 *
 * The four operations carry the explicit values 0-3.
 * HCCL_REDUCE_RESERVED has no initializer, so it takes the next value (4).
 */ | |||||
typedef enum { | |||||
HCCL_REDUCE_SUM = 0, /**< sum */ | |||||
HCCL_REDUCE_PROD = 1, /**< prod */ | |||||
HCCL_REDUCE_MAX = 2, /**< max */ | |||||
HCCL_REDUCE_MIN = 3, /**< min */ | |||||
HCCL_REDUCE_RESERVED /**< reserved */ | |||||
} HcclReduceOp; | |||||
/**
 * @brief HCCL data type
 *
 * The element types carry the explicit values 0-6.
 * HCCL_DATA_TYPE_RESERVED has no initializer, so it takes the next value (7).
 */ | |||||
typedef enum { | |||||
HCCL_DATA_TYPE_INT8 = 0, /**< int8 */ | |||||
HCCL_DATA_TYPE_INT16 = 1, /**< int16 */ | |||||
HCCL_DATA_TYPE_INT32 = 2, /**< int32 */ | |||||
HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */ | |||||
HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */ | |||||
HCCL_DATA_TYPE_INT64 = 5, /**< int64 */ | |||||
HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */ | |||||
HCCL_DATA_TYPE_RESERVED /**< reserved */ | |||||
} HcclDataType; | |||||
// Size in bytes of the HcclRootInfo::internal buffer declared below.
// NOTE(review): the array bound is a `const uint32_t` object. That is a constant
// expression in C++ but not in C, so despite the extern "C" guards this header
// presumably only compiles as C++ - confirm whether C callers are expected.
const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length
/**
 * @brief HCCL root info
 *
 * Holds a fixed-size character buffer of HCCL_ROOT_INFO_BYTES bytes; this header
 * does not describe the layout of its contents.
 */ | |||||
typedef struct HcclRootInfoDef { | |||||
char internal[HCCL_ROOT_INFO_BYTES]; | |||||
} HcclRootInfo; | |||||
#ifdef __cplusplus | |||||
} | |||||
#endif // __cplusplus | |||||
#endif // HCCL_TYPES_H_ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/** | |||||
* @file hccl_types.h | |||||
* @brief HCCL data type definition | |||||
* | |||||
*/ | |||||
#ifndef HCCL_TYPES_H_ | |||||
#define HCCL_TYPES_H_ | |||||
#include <stdint.h> | |||||
#ifdef __cplusplus | |||||
extern "C" { | |||||
#endif // __cplusplus | |||||
/**
 * @brief HCCL functions return value definition
 *
 * HCCL_SUCCESS through HCCL_E_NETWORK carry the explicit values 0-19.
 * HCCL_E_RESERVED has no initializer, so it takes the next value (20).
 */ | |||||
typedef enum { | |||||
HCCL_SUCCESS = 0, /**< success */ | |||||
HCCL_E_PARA = 1, /**< parameter error */ | |||||
HCCL_E_PTR = 2, /**< empty pointer */ | |||||
HCCL_E_MEMORY = 3, /**< memory error */ | |||||
HCCL_E_INTERNAL = 4, /**< internal error */ | |||||
HCCL_E_NOT_SUPPORT = 5, /**< not support feature */ | |||||
HCCL_E_NOT_FOUND = 6, /**< not found specific resource */ | |||||
HCCL_E_UNAVAIL = 7, /**< resource unavailable */ | |||||
HCCL_E_SYSCALL = 8, /**< call system interface error */ | |||||
HCCL_E_TIMEOUT = 9, /**< timeout */ | |||||
HCCL_E_OPEN_FILE_FAILURE = 10, /**< open file fail */ | |||||
HCCL_E_TCP_CONNECT = 11, /**< tcp connect fail */ | |||||
HCCL_E_ROCE_CONNECT = 12, /**< roce connect fail */ | |||||
HCCL_E_TCP_TRANSFER = 13, /**< tcp transfer fail */ | |||||
HCCL_E_ROCE_TRANSFER = 14, /**< roce transfer fail */ | |||||
HCCL_E_RUNTIME = 15, /**< call runtime api fail */ | |||||
HCCL_E_DRV = 16, /**< call driver api fail */ | |||||
HCCL_E_PROFILING = 17, /**< call profiling api fail */ | |||||
HCCL_E_CCE = 18, /**< call cce api fail */ | |||||
HCCL_E_NETWORK = 19, /**< call network api fail */ | |||||
HCCL_E_RESERVED /**< reserved */ | |||||
} HcclResult; | |||||
/**
 * @brief handle to HCCL communicator
 *
 * Declared as an untyped pointer; this header does not expose the pointee type.
 */ | |||||
typedef void *HcclComm; | |||||
/**
 * @brief HCCL reduction operation
 *
 * The four operations carry the explicit values 0-3.
 * HCCL_REDUCE_RESERVED has no initializer, so it takes the next value (4).
 */ | |||||
typedef enum { | |||||
HCCL_REDUCE_SUM = 0, /**< sum */ | |||||
HCCL_REDUCE_PROD = 1, /**< prod */ | |||||
HCCL_REDUCE_MAX = 2, /**< max */ | |||||
HCCL_REDUCE_MIN = 3, /**< min */ | |||||
HCCL_REDUCE_RESERVED /**< reserved */ | |||||
} HcclReduceOp; | |||||
/**
 * @brief HCCL data type
 *
 * The element types carry the explicit values 0-6.
 * HCCL_DATA_TYPE_RESERVED has no initializer, so it takes the next value (7).
 */ | |||||
typedef enum { | |||||
HCCL_DATA_TYPE_INT8 = 0, /**< int8 */ | |||||
HCCL_DATA_TYPE_INT16 = 1, /**< int16 */ | |||||
HCCL_DATA_TYPE_INT32 = 2, /**< int32 */ | |||||
HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */ | |||||
HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */ | |||||
HCCL_DATA_TYPE_INT64 = 5, /**< int64 */ | |||||
HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */ | |||||
HCCL_DATA_TYPE_RESERVED /**< reserved */ | |||||
} HcclDataType; | |||||
// Size in bytes of the HcclRootInfo::internal buffer declared below.
// NOTE(review): the array bound is a `const uint32_t` object. That is a constant
// expression in C++ but not in C, so despite the extern "C" guards this header
// presumably only compiles as C++ - confirm whether C callers are expected.
const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length
/**
 * @brief HCCL root info
 *
 * Holds a fixed-size character buffer of HCCL_ROOT_INFO_BYTES bytes; this header
 * does not describe the layout of its contents.
 */ | |||||
typedef struct HcclRootInfoDef { | |||||
char internal[HCCL_ROOT_INFO_BYTES]; | |||||
} HcclRootInfo; | |||||
#ifdef __cplusplus | |||||
} | |||||
#endif // __cplusplus | |||||
#endif // HCCL_TYPES_H_ |
@@ -449,6 +449,7 @@ REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | |||||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(MODELEXIT, "ModelExit"); | |||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | ||||
REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); | REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); | ||||
@@ -100,6 +100,8 @@ struct OmgContext { | |||||
std::vector<std::string> net_out_nodes; | std::vector<std::string> net_out_nodes; | ||||
// net out nodes top names(only caffe has top) | // net out nodes top names(only caffe has top) | ||||
std::vector<std::string> out_top_names; | std::vector<std::string> out_top_names; | ||||
// net data nodes top names(only caffe has top) | |||||
std::vector<std::string> data_top_names; | |||||
// preferential format used by the entire network | // preferential format used by the entire network | ||||
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | ||||
@@ -187,6 +187,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTS_LABEL_NODE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; | ||||
@@ -778,8 +779,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | ||||
@@ -95,6 +95,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
ComputeGraphPtr GetOwnerComputeGraph() const; | ComputeGraphPtr GetOwnerComputeGraph() const; | ||||
graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); | graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); | ||||
graphStatus SetAnyOwnerComputeGraph(const ComputeGraphPtr &graph); | |||||
Vistor<InDataAnchorPtr> GetAllInDataAnchors() const; | Vistor<InDataAnchorPtr> GetAllInDataAnchors() const; | ||||
Vistor<OutDataAnchorPtr> GetAllOutDataAnchors() const; | Vistor<OutDataAnchorPtr> GetAllOutDataAnchors() const; | ||||
@@ -141,6 +141,8 @@ class GraphUtils { | |||||
static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | ||||
static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); | |||||
static graphStatus RecoverGraphOperators(const Graph &graph); | static graphStatus RecoverGraphOperators(const Graph &graph); | ||||
static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | ||||
@@ -157,6 +157,7 @@ const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | |||||
const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | ||||
const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | ||||
const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | ||||
const std::string ATTR_NAME_RTS_LABEL_NODE = "_rts_label_node"; | |||||
const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; | const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; | ||||
const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | ||||
const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | ||||
@@ -25,6 +25,7 @@ | |||||
#include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
#include <graph/utils/graph_utils.h> | #include <graph/utils/graph_utils.h> | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/debug/ge_op_types.h" | |||||
#include "utils/node_utils.h" | #include "utils/node_utils.h" | ||||
#include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
@@ -264,20 +265,34 @@ graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const { | |||||
} | } | ||||
NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index); | NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index); | ||||
bool is_const = NodeUtils::IsConst(*input_data_node); | |||||
if (!is_const) { | |||||
GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str()); | |||||
return GRAPH_NODE_WITHOUT_CONST_INPUT; | |||||
} | |||||
Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node); | |||||
if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", input_data_node->GetName().c_str(), | |||||
node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
GE_CHECK_NOTNULL(input_data_node); | |||||
string op_type = input_data_node->GetType(); | |||||
if (op_type == CONSTANT || op_type == CONSTANTOP) { | |||||
Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node); | |||||
if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", input_data_node->GetName().c_str(), | |||||
node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} else if (op_type == DATA) { | |||||
auto parent_node = NodeUtils::GetParentInput(input_data_node); | |||||
while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { | |||||
parent_node = NodeUtils::GetParentInput(parent_node); | |||||
} | |||||
if ((parent_node != nullptr) && ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { | |||||
Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node); | |||||
if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", parent_node->GetName().c_str(), | |||||
node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} | } | ||||
return GRAPH_SUCCESS; | |||||
GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str()); | |||||
return GRAPH_NODE_WITHOUT_CONST_INPUT; | |||||
} | } | ||||
graphStatus GNode::GetInputIndexByName(const ge::AscendString &name, int32_t &index) { | graphStatus GNode::GetInputIndexByName(const ge::AscendString &name, int32_t &index) { | ||||
@@ -793,7 +808,7 @@ bool GNode::HasAttr(const ge::AscendString &name) { | |||||
return true; | return true; | ||||
} | } | ||||
graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr graph) const { | |||||
graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const { | |||||
if (impl_ == nullptr) { | if (impl_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr."); | GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -807,20 +822,20 @@ graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr graph) const { | |||||
ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index); | ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index); | ||||
if (compute_graph_ptr == nullptr) { | if (compute_graph_ptr == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed form node[%s].", index, node_ptr->GetName().c_str()); | |||||
GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | |||||
graph = std::make_shared<Graph>(create_graph); | |||||
graph = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph_ptr); | |||||
if (graph == nullptr) { | if (graph == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "GetSubgraph: graph make shared failed form node[%s].", node_ptr->GetName().c_str()); | |||||
GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> graph_list) const { | |||||
graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const { | |||||
if (impl_ == nullptr) { | if (impl_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr."); | GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
@@ -834,24 +849,27 @@ graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> graph_list) const { | |||||
std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr); | std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr); | ||||
if (sub_graphs.empty()) { | if (sub_graphs.empty()) { | ||||
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed form node[%s].", node_ptr->GetName().c_str()); | |||||
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed from node[%s].", node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
for (auto &sub_graph : sub_graphs) { | for (auto &sub_graph : sub_graphs) { | ||||
if (sub_graph == nullptr) { | if (sub_graph == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Get subgraph failed form node[%s].", node_ptr->GetName().c_str()); | |||||
GELOGE(GRAPH_FAILED, "Get subgraph failed from node[%s].", node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(sub_graph); | |||||
GraphPtr graph = std::make_shared<Graph>(create_graph); | |||||
GraphPtr graph = GraphUtils::CreateGraphPtrFromComputeGraph(sub_graph); | |||||
if (graph == nullptr) { | if (graph == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "Subgraph make shared failed form node[%s].", node_ptr->GetName().c_str()); | |||||
GELOGE(GRAPH_FAILED, "Subgraph create compute graph failed from node[%s].", node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
graph_list.emplace_back(graph); | graph_list.emplace_back(graph); | ||||
} | } | ||||
if (graph_list.empty()) { | |||||
GELOGW("Node[%s] has no subgraph.", node_ptr->GetName().c_str()); | |||||
} | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -24,6 +24,7 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/node_adapter.h" | #include "graph/utils/node_adapter.h" | ||||
#include "graph/utils/node_utils.h" | |||||
using std::map; | using std::map; | ||||
using std::pair; | using std::pair; | ||||
@@ -246,6 +247,53 @@ class GraphImpl { | |||||
ComputeGraphPtr GetComputeGraph() const { return compute_graph_; } | ComputeGraphPtr GetComputeGraph() const { return compute_graph_; } | ||||
// Removes a single edge running from src_node_ptr to dst_node_ptr.
//
// The port indexes select which anchors are disconnected:
//   - both indexes are -1: out control anchor -> in control anchor;
//   - only dst_port_index is -1: out data anchor[src_port_index] -> in control anchor;
//   - otherwise: out data anchor[src_port_index] -> in data anchor[dst_port_index].
//
// Returns GRAPH_SUCCESS when the edge was removed and GRAPH_FAILED when the
// required source anchor is null or GraphUtils::RemoveEdge reports a failure.
// Null node pointers are rejected up front by GE_CHECK_NOTNULL.
graphStatus RemoveEdge(NodePtr &src_node_ptr, const int32_t src_port_index, NodePtr &dst_node_ptr,
                       const int32_t dst_port_index) {
  GE_CHECK_NOTNULL(src_node_ptr);
  GE_CHECK_NOTNULL(dst_node_ptr);

  const bool src_uses_control = (src_port_index == -1);
  const bool dst_uses_control = (dst_port_index == -1);

  // Control -> control edge.
  if (src_uses_control && dst_uses_control) {
    const auto out_ctrl_anchor = src_node_ptr->GetOutControlAnchor();
    if (out_ctrl_anchor == nullptr) {
      GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out control anchor is null.", src_node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    if (GraphUtils::RemoveEdge(out_ctrl_anchor, dst_node_ptr->GetInControlAnchor()) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge between [%s] and [%s]failed.",
             src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  }

  // Every remaining case starts at a data anchor of the source node.
  const auto out_data_anchor = src_node_ptr->GetOutDataAnchor(src_port_index);
  if (out_data_anchor == nullptr) {
    GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] out data anchor[%d] is null.", src_node_ptr->GetName().c_str(),
           src_port_index);
    return GRAPH_FAILED;
  }

  // Data -> control edge.
  if (!src_uses_control && dst_uses_control) {
    if (GraphUtils::RemoveEdge(out_data_anchor, dst_node_ptr->GetInControlAnchor()) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge between [%s] and [%s]failed.",
             src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
      return GRAPH_FAILED;
    }
    return GRAPH_SUCCESS;
  }

  // Data -> data edge.
  if (GraphUtils::RemoveEdge(out_data_anchor, dst_node_ptr->GetInDataAnchor(dst_port_index)) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge between [%s] and [%s] failed.",
           src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
private: | private: | ||||
std::string name_; | std::string name_; | ||||
std::string output_name_; | std::string output_name_; | ||||
@@ -392,17 +440,25 @@ graphStatus Graph::RemoveNode(GNode &node) { | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "RemoveNode: node[%s] is invalid.", node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | ||||
if (compute_graph_ptr == nullptr) { | if (compute_graph_ptr == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr."); | GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr."); | ||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (compute_graph_ptr->RemoveNode(node_ptr) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveNde: remove node failed."); | |||||
ge::NodeUtils::UnlinkAll(*node_ptr); | |||||
if (GraphUtils::RemoveNodeWithoutRelink(compute_graph_ptr, node_ptr) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveNode: remove node[%s] failed.", node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
node_ptr->SetAnyOwnerComputeGraph(nullptr); | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -430,31 +486,21 @@ graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index, GNo | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
graphStatus res = GRAPH_FAILED; | |||||
if ((src_port_index == -1) && (dst_port_index == -1)) { | |||||
res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); | |||||
if (res != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | } | ||||
if (src_port_index != -1 && dst_port_index == -1) { | |||||
res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor()); | |||||
if (res != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | } | ||||
res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), | |||||
dst_node_ptr->GetInDataAnchor(dst_port_index)); | |||||
if (res != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge failed."); | |||||
if (impl_->RemoveEdge(src_node_ptr, src_port_index, dst_node_ptr, dst_port_index) != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "RemoveEdge: remove edge failed."); | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -501,6 +547,16 @@ graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index, GN | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AddDataEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AddDataEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
graphStatus res = | graphStatus res = | ||||
GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInDataAnchor(dst_port_index)); | GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInDataAnchor(dst_port_index)); | ||||
if (res != GRAPH_SUCCESS) { | if (res != GRAPH_SUCCESS) { | ||||
@@ -529,6 +585,16 @@ graphStatus Graph::AddControlEdge(GNode &src_node, GNode &dst_node) { | |||||
return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
} | } | ||||
if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AddControlEdge: src node[%s] is invalid.", src_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { | |||||
GELOGE(GRAPH_FAILED, "AddControlEdge: dst node[%s] is invalid.", dst_node_ptr->GetName().c_str()); | |||||
return GRAPH_FAILED; | |||||
} | |||||
graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); | graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); | ||||
if (res != GRAPH_SUCCESS) { | if (res != GRAPH_SUCCESS) { | ||||
GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed."); | GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed."); | ||||
@@ -558,10 +624,9 @@ GraphPtr Graph::ConstructFromInputs(const std::vector<Operator> &inputs, const g | |||||
} | } | ||||
compute_graph->SetInputSize(static_cast<uint32_t>(inputs.size())); | compute_graph->SetInputSize(static_cast<uint32_t>(inputs.size())); | ||||
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
GraphPtr graph_ptr = std::make_shared<Graph>(graph); | |||||
GraphPtr graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(compute_graph); | |||||
if (graph_ptr == nullptr) { | if (graph_ptr == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "ConstructFromInputs: graph make shared failed."); | |||||
GELOGE(GRAPH_FAILED, "ConstructFromInputs: create graph from compute graph failed."); | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
@@ -604,6 +669,20 @@ GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) | |||||
return graph; | return graph; | ||||
} | } | ||||
// Wraps an existing compute graph in a shared Graph object.
// A Graph named after the compute graph is created through ComGraphMakeShared and its impl_
// is pointed at compute_graph.
// Returns nullptr when compute_graph is null, when the Graph could not be created, or when the
// created Graph has no impl_.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphPtr
GraphUtils::CreateGraphPtrFromComputeGraph(const ge::ComputeGraphPtr compute_graph) {
  if (compute_graph == nullptr) {
    return nullptr;
  }

  auto graph_name = compute_graph->GetName();
  auto wrapped_graph = ComGraphMakeShared<Graph>(graph_name);
  if (wrapped_graph == nullptr) {
    return nullptr;
  }
  if (wrapped_graph->impl_ == nullptr) {
    return nullptr;
  }

  wrapped_graph->impl_->compute_graph_ = compute_graph;
  return wrapped_graph;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | ||||
GE_CHECK_NOTNULL(graph.impl_); | GE_CHECK_NOTNULL(graph.impl_); | ||||
GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | ||||
@@ -393,6 +393,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerCompute | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
// Stores `graph` as this node's owner graph.
// The pointer is assigned as-is; no validation is done here, so a null graph is accepted too.
// Always returns GRAPH_SUCCESS.
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetAnyOwnerComputeGraph(const ComputeGraphPtr &graph) {
  this->owner_graph_ = graph;
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<InDataAnchorPtr> Node::GetAllInDataAnchors() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<InDataAnchorPtr> Node::GetAllInDataAnchors() const { | ||||
return Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_); | return Vistor<InDataAnchorPtr>(shared_from_this(), in_data_anchors_); | ||||
} | } | ||||
@@ -292,6 +292,8 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||||
graph_status = CalcElementCntByDims(dims, element_cnt); | graph_status = CalcElementCntByDims(dims, element_cnt); | ||||
break; | break; | ||||
default: | default: | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19012", {"function", "reason"}, {"CalcTensorElementCnt", "format[" + format_str + "] is not support"}); | |||||
GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); | GELOGE(GRAPH_FAILED, "unsupported format, format=%d(%s).", format, format_str.c_str()); | ||||
graph_status = GRAPH_FAILED; | graph_status = GRAPH_FAILED; | ||||
break; | break; | ||||
@@ -16,6 +16,7 @@ | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
#include "common/util/error_manager/error_manager.h" | |||||
using domi::domiTensorFormat_t; | using domi::domiTensorFormat_t; | ||||
@@ -431,6 +432,9 @@ bool TypeUtils::GetDataTypeLength(ge::DataType data_type, uint32_t &length) { | |||||
length = it->second; | length = it->second; | ||||
return true; | return true; | ||||
} else { | } else { | ||||
ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E19012", {"function", "reason"}, | |||||
{"GetDataTypeLength", "data_type[" + std::to_string(data_type) + "] is not support"}); | |||||
GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); | GELOGE(GRAPH_FAILED, "data_type not support %d", data_type); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -96,6 +96,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
@@ -277,6 +278,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/model_exit_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
@@ -398,6 +398,7 @@ REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | |||||
REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | ||||
REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DEFINE(MODELEXIT, "ModelExit"); | |||||
REGISTER_OPTYPE_DEFINE(SEND, "Send"); | REGISTER_OPTYPE_DEFINE(SEND, "Send"); | ||||
REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | ||||
REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); | REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); | ||||
@@ -1056,6 +1056,7 @@ ge::Status GeExecutor::ExecuteAsync(DynamicSingleOp *executor, const vector<GeTe | |||||
} | } | ||||
// Releases the single-op resources associated with `stream`.
// ModelManager::ClearAicpuSo() is invoked first; its status is intentionally discarded, and the
// value returned to the caller is the result of SingleOpManager::ReleaseResource(stream).
Status GeExecutor::ReleaseSingleOpResource(void *stream) {
  auto model_manager = ModelManager::GetInstance();
  (void)model_manager->ClearAicpuSo();

  auto &single_op_manager = SingleOpManager::GetInstance();
  return single_op_manager.ReleaseResource(stream);
}
@@ -48,6 +48,7 @@ local_ge_executor_src_files := \ | |||||
../graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | ../graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | ||||
../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | ../graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | ||||
../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ||||
../graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||||
../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | ||||
../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | ||||
../opskernel_manager/ops_kernel_builder_manager.cc \ | ../opskernel_manager/ops_kernel_builder_manager.cc \ | ||||
@@ -109,6 +109,7 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
graph/passes/mark_agnostic_pass.cc \ | |||||
graph/common/omg_util.cc \ | graph/common/omg_util.cc \ | ||||
graph/common/bcast.cc \ | graph/common/bcast.cc \ | ||||
graph/common/local_context.cc \ | graph/common/local_context.cc \ | ||||
@@ -176,6 +177,7 @@ OMG_HOST_SRC_FILES := \ | |||||
graph/passes/cast_translate_pass.cc \ | graph/passes/cast_translate_pass.cc \ | ||||
graph/passes/prune_pass.cc \ | graph/passes/prune_pass.cc \ | ||||
graph/passes/merge_to_stream_merge_pass.cc \ | graph/passes/merge_to_stream_merge_pass.cc \ | ||||
graph/passes/merge_input_memcpy_pass.cc \ | |||||
graph/passes/switch_to_stream_switch_pass.cc \ | graph/passes/switch_to_stream_switch_pass.cc \ | ||||
graph/passes/attach_stream_label_pass.cc \ | graph/passes/attach_stream_label_pass.cc \ | ||||
graph/passes/multi_batch_pass.cc \ | graph/passes/multi_batch_pass.cc \ | ||||
@@ -247,6 +249,7 @@ OME_HOST_SRC_FILES := \ | |||||
graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||||
graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | ||||
graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | ||||
single_op/task/op_task.cc \ | single_op/task/op_task.cc \ | ||||
@@ -61,6 +61,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/load/new_model_manager/model_utils.cc \ | graph/load/new_model_manager/model_utils.cc \ | ||||
graph/load/new_model_manager/aipp_utils.cc \ | graph/load/new_model_manager/aipp_utils.cc \ | ||||
graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/model_exit_task_info.cc \ | |||||
graph/load/new_model_manager/task_info/event_record_task_info.cc \ | graph/load/new_model_manager/task_info/event_record_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/event_wait_task_info.cc \ | graph/load/new_model_manager/task_info/event_wait_task_info.cc \ | ||||
graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ | graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ | ||||
@@ -110,6 +111,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
graph/passes/mark_agnostic_pass.cc \ | |||||
graph/partition/dynamic_shape_partition.cc \ | graph/partition/dynamic_shape_partition.cc \ | ||||
graph/partition/stage_partition.cc \ | graph/partition/stage_partition.cc \ | ||||
graph/passes/base_pass.cc \ | graph/passes/base_pass.cc \ | ||||
@@ -210,6 +212,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
graph/passes/switch_data_edges_bypass.cc \ | graph/passes/switch_data_edges_bypass.cc \ | ||||
graph/passes/switch_logic_remove_pass.cc \ | graph/passes/switch_logic_remove_pass.cc \ | ||||
graph/passes/merge_to_stream_merge_pass.cc \ | graph/passes/merge_to_stream_merge_pass.cc \ | ||||
graph/passes/merge_input_memcpy_pass.cc \ | |||||
graph/passes/switch_to_stream_switch_pass.cc \ | graph/passes/switch_to_stream_switch_pass.cc \ | ||||
graph/passes/attach_stream_label_pass.cc \ | graph/passes/attach_stream_label_pass.cc \ | ||||
graph/passes/switch_dead_branch_elimination.cc \ | graph/passes/switch_dead_branch_elimination.cc \ | ||||
@@ -462,8 +462,7 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt | |||||
set<NodePtr> all_reduce_succs; | set<NodePtr> all_reduce_succs; | ||||
for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE) || | |||||
node->GetInDataNodes().size() <= 1) { | |||||
if (!IsHcomNode(node->GetType()) || node->GetInDataNodes().size() <= 1) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -507,14 +506,20 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt | |||||
old_stream_to_new.emplace(old_stream, new_stream); | old_stream_to_new.emplace(old_stream, new_stream); | ||||
} | } | ||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
if (!IsHcomNode(node->GetType())) { | |||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
} | |||||
} | } | ||||
} | } | ||||
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
bool AllReduceParallelPass::IsHcomNode(const std::string &node_type) { | |||||
return (node_type == HCOMALLREDUCE || node_type == HVDCALLBACKALLREDUCE); | |||||
} | |||||
LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | ||||
const map<string, int> &max_parallel_num) | const map<string, int> &max_parallel_num) | ||||
: scheduler_confs_(scheduler_confs), max_parallel_num_(max_parallel_num) {} | : scheduler_confs_(scheduler_confs), max_parallel_num_(max_parallel_num) {} | ||||
@@ -166,6 +166,9 @@ class AllReduceParallelPass : public LogicalStreamPass { | |||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | ||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | ||||
private: | |||||
bool IsHcomNode(const std::string &node_type); | |||||
}; | }; | ||||
// Assign logical streams which is not limited by the number of tasks. | // Assign logical streams which is not limited by the number of tasks. | ||||
@@ -870,9 +870,11 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
string ge_disable_reuse_mem_env = "0"; | string ge_disable_reuse_mem_env = "0"; | ||||
(void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); | (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); | ||||
if (ge_disable_reuse_mem_env != "1") { | if (ge_disable_reuse_mem_env != "1") { | ||||
bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | |||||
bool reuse_mem_flag = (mem_type == kOutput) | |||||
? IsPreReuse(n, out_index) | |||||
: !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | |||||
is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) && | is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) && | ||||
reuse_mem_flag && is_op_reuse_mem && (IsPreReuse(n, out_index)); | |||||
reuse_mem_flag && is_op_reuse_mem; | |||||
auto stream_id = node_op_desc->GetStreamId(); | auto stream_id = node_op_desc->GetStreamId(); | ||||
if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) { | if (is_reuse_memory && !continuous && !reusable_blocks_[memory_type].empty()) { | ||||
for (auto it = reusable_blocks_[memory_type][stream_id].begin(); | for (auto it = reusable_blocks_[memory_type][stream_id].begin(); | ||||
@@ -464,6 +464,8 @@ Status DavinciModel::DoTaskSink() { | |||||
GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); | GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); | ||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); | |||||
GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); | GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); | ||||
GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); | GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); | ||||
@@ -2051,6 +2053,7 @@ Status DavinciModel::SinkModelProfile() { | |||||
std::set<uint32_t> task_id_set; | std::set<uint32_t> task_id_set; | ||||
for (int32_t i = 0; i < task_num; i++) { | for (int32_t i = 0; i < task_num; i++) { | ||||
auto task = task_list_[i]; | auto task = task_list_[i]; | ||||
GE_CHECK_NOTNULL(task); | |||||
auto fusion_op_info = task->GetFusionOpInfo(); | auto fusion_op_info = task->GetFusionOpInfo(); | ||||
// when type is RT_MODEL_TASK_KERNEL, ctx is not null | // when type is RT_MODEL_TASK_KERNEL, ctx is not null | ||||
if (fusion_op_info != nullptr) { | if (fusion_op_info != nullptr) { | ||||
@@ -2077,6 +2080,7 @@ Status DavinciModel::SinkModelProfile() { | |||||
using Range = std::pair<CIT, CIT>; | using Range = std::pair<CIT, CIT>; | ||||
for (int32_t i = 0; i < task_num; i++) { | for (int32_t i = 0; i < task_num; i++) { | ||||
auto task = task_list_[i]; | auto task = task_list_[i]; | ||||
GE_CHECK_NOTNULL(task); | |||||
auto fusion_op_info = task->GetFusionOpInfo(); | auto fusion_op_info = task->GetFusionOpInfo(); | ||||
if (fusion_op_info != nullptr && fusion_op_info->original_op_names.size() > 0) { | if (fusion_op_info != nullptr && fusion_op_info->original_op_names.size() > 0) { | ||||
uint32_t task_id = task->GetTaskID(); | uint32_t task_id = task->GetTaskID(); | ||||
@@ -43,13 +43,18 @@ const std::string kCmdTypeProfInit = "prof_init"; | |||||
const std::string kCmdTypeProfFinalize = "prof_finalize"; | const std::string kCmdTypeProfFinalize = "prof_finalize"; | ||||
const std::string kCmdTypeProfStart = "prof_start"; | const std::string kCmdTypeProfStart = "prof_start"; | ||||
const std::string kCmdTypeProfStop = "prof_stop"; | const std::string kCmdTypeProfStop = "prof_stop"; | ||||
const char *const kLoadOpFromBuf = "loadOpFromBuf"; | |||||
const char *const kBatchLoadBuf = "batchLoadsoFrombuf"; | |||||
const char *const kDeleteCustOp = "deleteCustOp"; | |||||
// Describes one custom AICPU shared object as handed to the AICPU kernel.
// Packed: the fields are laid out back to back with no padding, so field order and widths
// must not change.
struct CustAicpuSoBuf {
  uint64_t kernelSoBuf;      // address of the .so binary (stored as an integer)
  uint32_t kernelSoBufLen;   // size of the .so binary in bytes
  uint64_t kernelSoName;     // address of the .so name characters (stored as an integer)
  uint32_t kernelSoNameLen;  // length of the .so name in bytes
} __attribute__((packed));
// Argument block for a batched custom AICPU kernel launch: a count plus the address of an
// array of CustAicpuSoBuf descriptors.
// Packed: no padding is inserted between the 32-bit count and the 64-bit address.
struct BatchLoadOpFromBufArgs {
  uint32_t soNum;  // number of CustAicpuSoBuf entries in the array
  uint64_t args;   // address of the CustAicpuSoBuf array (stored as an integer)
} __attribute__((packed));
} // namespace | } // namespace | ||||
DumpProperties ModelManager::dump_properties_; | DumpProperties ModelManager::dump_properties_; | ||||
@@ -236,6 +241,7 @@ ModelManager::~ModelManager() { | |||||
std::lock_guard<std::mutex> lock(map_mutex_); | std::lock_guard<std::mutex> lock(map_mutex_); | ||||
model_map_.clear(); | model_map_.clear(); | ||||
model_aicpu_kernel_.clear(); | model_aicpu_kernel_.clear(); | ||||
cust_aicpu_so_.clear(); | |||||
GE_IF_BOOL_EXEC(device_count > 0, GE_CHK_RT(rtDeviceReset(0))); | GE_IF_BOOL_EXEC(device_count > 0, GE_CHK_RT(rtDeviceReset(0))); | ||||
} | } | ||||
@@ -399,7 +405,6 @@ Status ModelManager::Unload(uint32_t model_id) { | |||||
} | } | ||||
std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | ||||
exception_infos_.clear(); | exception_infos_.clear(); | ||||
cust_aicpu_so_.clear(); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1096,64 +1101,149 @@ Status ModelManager::CreateAicpuSession(uint64_t session_id) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Records the custom AICPU kernel attached to `op_desc` under `so_name` for the current runtime
// context. The kernel is only stored in cust_aicpu_so_ (keyed by the current rtContext_t value);
// this function itself performs no kernel launch.
//
// op_desc: operator whose OP_EXTATTR_CUSTAICPU_KERNEL extended attribute holds the kernel.
// so_name: name under which the kernel is recorded; an already recorded name is left untouched.
//
// Returns SUCCESS when the kernel was recorded or was already present, INTERNAL_ERROR when the
// operator carries no kernel, RT_FAILED when the current context cannot be queried, and the
// error produced by GE_CHECK_NOTNULL when op_desc is null.
Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name) {
  // op_desc is dereferenced by the very first log statement, so reject null up front.
  GE_CHECK_NOTNULL(op_desc);
  GELOGI("LoadCustAicpuSo in, op name %s, so name %s", op_desc->GetName().c_str(), so_name.c_str());
  std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
  CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr());
  if (aicpu_kernel == nullptr) {
    GELOGE(INTERNAL_ERROR, "cust aicpu op %s can't find kernel!", op_desc->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // get current context
  rtContext_t rt_cur_ctx = nullptr;
  auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
  if (rt_error != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "get current context failed, runtime result is %d", static_cast<int>(rt_error));
    return RT_FAILED;
  }

  // use current context as resource key
  uintptr_t resource_id = reinterpret_cast<uintptr_t>(rt_cur_ctx);
  auto it = cust_aicpu_so_.find(resource_id);
  if (it == cust_aicpu_so_.end()) {
    std::map<string, CustAICPUKernelPtr> new_so_name;
    new_so_name.emplace(so_name, aicpu_kernel);
    cust_aicpu_so_[resource_id] = new_so_name;
    GELOGI("LoadCustAicpuSo new aicpu so resource id %lu", resource_id);
    return SUCCESS;
  }

  // emplace() leaves an existing entry alone and reports whether an insertion happened.
  if (it->second.emplace(so_name, aicpu_kernel).second) {
    GELOGI("LoadCustAicpuSo add aicpu so resource id %lu", resource_id);
  }
  return SUCCESS;
}
// Launches the AICPU kernel `kernel_name` over every custom AICPU shared object recorded for the
// current runtime context.
//
// For each recorded .so, the binary and its name are copied into device buffers and described
// by a CustAicpuSoBuf. The descriptor array is copied to the device, wrapped in a
// BatchLoadOpFromBufArgs, and passed to rtCpuKernelLaunch on a temporary stream that is then
// synchronized. When `kernel_name` equals kDeleteCustOp, the context's entry is removed from
// cust_aicpu_so_ before the launch.
//
// All device buffers and the temporary stream are released on every exit path.
//
// kernel_name: name of the AICPU kernel to launch (this file passes kBatchLoadBuf and
//              kDeleteCustOp).
//
// Returns SUCCESS (also when nothing is recorded for the current context), RT_FAILED when the
// current context cannot be queried, or the converted runtime status when a device allocation
// or the stream synchronization fails.
Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
  GELOGI("LaunchCustAucpuSo in, kernel name %s", kernel_name.c_str());
  std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
  if (cust_aicpu_so_.empty()) {
    return SUCCESS;
  }

  // get current context
  rtContext_t rt_cur_ctx = nullptr;
  auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
  if (rt_error != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "get current context failed, runtime result is %d", static_cast<int>(rt_error));
    return RT_FAILED;
  }
  uintptr_t resource_id = reinterpret_cast<uintptr_t>(rt_cur_ctx);
  auto it = cust_aicpu_so_.find(resource_id);
  if (it == cust_aicpu_so_.end()) {
    GELOGI("Cust aicpu so map is empty, context id %lu", resource_id);
    return SUCCESS;
  }

  vector<void *> allocated_mem;
  rtStream_t stream = nullptr;
  // Register the cleanup before the first allocation: the early error returns below would
  // otherwise leak every device buffer allocated so far (and the stream, once created).
  std::function<void()> callback = [&]() {
    for (auto mem : allocated_mem) {
      GE_CHK_RT(rtFree(mem));
    }
    if (stream != nullptr) {
      GE_CHK_RT(rtStreamDestroy(stream));
    }
  };
  GE_MAKE_GUARD(release, callback);

  rtError_t status;
  vector<CustAicpuSoBuf> v_cust_so;
  for (const auto &it_so : it->second) {
    const void *aicpu_data = it_so.second->GetBinData();
    uint32_t aicpu_data_length = it_so.second->GetBinDataSize();
    const string &so_name = it_so.first;
    void *d_aicpu_data = nullptr;
    void *d_so_name = nullptr;

    status = rtMalloc(&d_aicpu_data, aicpu_data_length, RT_MEMORY_HBM);
    if (status != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt failed, status: 0x%x", status);
      return RT_ERROR_TO_GE_STATUS(status);
    }
    allocated_mem.push_back(d_aicpu_data);
    status = rtMalloc(&d_so_name, so_name.size(), RT_MEMORY_HBM);
    if (status != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Call rt failed, status: 0x%x", status);
      return RT_ERROR_TO_GE_STATUS(status);
    }
    allocated_mem.push_back(d_so_name);
    GE_CHK_RT(rtMemcpy(d_aicpu_data, aicpu_data_length, aicpu_data, aicpu_data_length, RT_MEMCPY_HOST_TO_DEVICE));
    GE_CHK_RT(rtMemcpy(d_so_name, so_name.size(), reinterpret_cast<const void *>(so_name.c_str()), so_name.size(),
                       RT_MEMCPY_HOST_TO_DEVICE));

    CustAicpuSoBuf cust_aicpu_so_buf;
    cust_aicpu_so_buf.kernelSoBuf = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_aicpu_data));
    cust_aicpu_so_buf.kernelSoBufLen = aicpu_data_length;
    cust_aicpu_so_buf.kernelSoName = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_so_name));
    cust_aicpu_so_buf.kernelSoNameLen = so_name.size();
    v_cust_so.push_back(cust_aicpu_so_buf);
  }

  // `it` is not used past this point, so erasing the entry here is safe.
  if (kernel_name == kDeleteCustOp) {
    (void)cust_aicpu_so_.erase(it);
  }

  void *args = nullptr;
  uint32_t args_size = sizeof(CustAicpuSoBuf) * v_cust_so.size();
  status = rtMalloc(&args, args_size, RT_MEMORY_HBM);
  if (status != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt failed, status: 0x%x", status);
    return RT_ERROR_TO_GE_STATUS(status);
  }
  allocated_mem.push_back(args);
  GE_CHK_RT(rtMemcpy(args, args_size, v_cust_so.data(), args_size, RT_MEMCPY_HOST_TO_DEVICE));

  BatchLoadOpFromBufArgs batch_cust_so;
  batch_cust_so.soNum = v_cust_so.size();
  batch_cust_so.args = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(args));

  void *batch_args = nullptr;
  uint32_t batch_args_size = sizeof(BatchLoadOpFromBufArgs);
  status = rtMalloc(&batch_args, batch_args_size, RT_MEMORY_HBM);
  if (status != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt failed, status: 0x%x", status);
    return RT_ERROR_TO_GE_STATUS(status);
  }
  allocated_mem.push_back(batch_args);
  GE_CHK_RT(rtMemcpy(batch_args, batch_args_size, static_cast<void *>(&batch_cust_so), batch_args_size,
                     RT_MEMCPY_HOST_TO_DEVICE));

  GE_CHK_RT(rtStreamCreate(&stream, 0));
  GE_CHK_RT(rtCpuKernelLaunch(nullptr, kernel_name.c_str(), 1, batch_args, batch_args_size, nullptr, stream));

  status = rtStreamSynchronize(stream);
  if (status != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Call rt stream sync failed, status: 0x%x", status);
    return RT_ERROR_TO_GE_STATUS(status);
  }
  GELOGI("Cpu kernel launch task success.");
  return SUCCESS;
}
// Runs LaunchKernelCustAicpuSo with the kDeleteCustOp kernel name.
// Returns SUCCESS when that call succeeds; a failure is handled by GE_CHK_STATUS_RET.
Status ModelManager::ClearAicpuSo() {
  const string delete_kernel_name(kDeleteCustOp);
  GE_CHK_STATUS_RET(LaunchKernelCustAicpuSo(delete_kernel_name), "delete cust op so failed.");
  return SUCCESS;
}
// Runs LaunchKernelCustAicpuSo with the kBatchLoadBuf kernel name.
// Returns SUCCESS when that call succeeds; a failure is handled by GE_CHK_STATUS_RET.
Status ModelManager::LaunchCustAicpuSo() {
  const string batch_load_kernel_name(kBatchLoadBuf);
  GE_CHK_STATUS_RET(LaunchKernelCustAicpuSo(batch_load_kernel_name), "launch cust op so failed.");
  return SUCCESS;
}
@@ -270,9 +270,13 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ge::Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
ge::Status LoadCustAicpuSo(const OpDescPtr op_desc, string so_name); | |||||
ge::Status LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_name); | |||||
ge::Status LaunchCustAicpuSo(const OpDescPtr op_desc, string so_name); | |||||
ge::Status LaunchCustAicpuSo(); | |||||
ge::Status ClearAicpuSo(); | |||||
ge::Status LaunchKernelCustAicpuSo(const string &kernel_name); | |||||
ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); | ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); | ||||
@@ -340,7 +344,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
std::set<uint64_t> sess_ids_; | std::set<uint64_t> sess_ids_; | ||||
std::vector<rtExceptionInfo> exception_infos_; | std::vector<rtExceptionInfo> exception_infos_; | ||||
std::mutex cust_aicpu_mutex_; | std::mutex cust_aicpu_mutex_; | ||||
std::set<std::string> cust_aicpu_so_; | |||||
std::map<uintptr_t, std::map<std::string, CustAICPUKernelPtr>> cust_aicpu_so_; | |||||
static DumpProperties dump_properties_; | static DumpProperties dump_properties_; | ||||
}; | }; | ||||
@@ -479,13 +479,15 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_TYPE_LIST, workspace_memory_type); | ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_TYPE_LIST, workspace_memory_type); | ||||
for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { | for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { | ||||
// Temporary solution, the aicpu workspace of multiple images cannot be shared. | // Temporary solution, the aicpu workspace of multiple images cannot be shared. | ||||
if (has_workspace_reuse && i < workspace_reuse_flag.size() && !workspace_reuse_flag[i]) { | |||||
if (has_workspace_reuse && i < workspace_reuse_flag.size() && !workspace_reuse_flag[i] && | |||||
!model_param.is_single_op) { | |||||
void *mem_addr = model_param.aicpu_mem_mall->Acquire(v_workspace_offset[i], v_workspace_bytes[i]); | void *mem_addr = model_param.aicpu_mem_mall->Acquire(v_workspace_offset[i], v_workspace_bytes[i]); | ||||
v_workspace_data_addr.push_back(mem_addr); | v_workspace_data_addr.push_back(mem_addr); | ||||
GELOGI( | GELOGI( | ||||
"[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] aicpu workspace[%zu] offset[%ld] bytes[%ld] " | "[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] aicpu workspace[%zu] offset[%ld] bytes[%ld] " | ||||
"memaddr[%p]", | "memaddr[%p]", | ||||
model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i], mem_addr); | model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i], mem_addr); | ||||
continue; | |||||
} else if (has_mem_type_workspace && workspace_memory_type[i] == RT_MEMORY_P2P_DDR) { | } else if (has_mem_type_workspace && workspace_memory_type[i] == RT_MEMORY_P2P_DDR) { | ||||
int64_t p2p_workspace_offset = v_workspace_offset[i]; | int64_t p2p_workspace_offset = v_workspace_offset[i]; | ||||
int64_t p2p_workspace_bytes = v_workspace_bytes[i]; | int64_t p2p_workspace_bytes = v_workspace_bytes[i]; | ||||
@@ -0,0 +1,54 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/load/new_model_manager/task_info/model_exit_task_info.h" | |||||
#include "common/properties_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/load/new_model_manager/davinci_model.h" | |||||
namespace ge { | |||||
/// @brief Bind this task to its stream and remember the runtime model handle.
/// @param [in] task_def       task definition; only its stream_id() is read here
/// @param [in] davinci_model  model that owns the stream list and the runtime handle
/// @return SUCCESS on success, PARAM_INVALID when davinci_model is null,
///         otherwise the status returned by SetStream
Status ModelExitTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
  GELOGI("InitModelExitTaskInfo Init Start.");
  if (davinci_model == nullptr) {
    GELOGE(PARAM_INVALID, "davinci_model is null!");
    return PARAM_INVALID;
  }

  // Resolve the stream id into stream_ using the model's stream list.
  const auto stream_id = task_def.stream_id();
  const Status set_stream_status = SetStream(stream_id, davinci_model->GetStreamList());
  if (set_stream_status != SUCCESS) {
    GELOGE(set_stream_status, "SetStream fail, stream_id:%u", stream_id);
    return set_stream_status;
  }

  // The handle is consumed later by Distribute().
  model_ = davinci_model->GetRtModelHandle();
  GELOGI("InitModelExitTaskInfo Init Success, model:%p, stream:%p", model_, stream_);
  return SUCCESS;
}
/// @brief Issue rtModelExit for the model/stream pair captured by Init().
/// @return SUCCESS when the runtime call returns RT_ERROR_NONE, otherwise the
///         runtime error converted with RT_ERROR_TO_GE_STATUS
Status ModelExitTaskInfo::Distribute() {
  GELOGI("ModelExitTaskInfo Distribute Start.");

  const rtError_t exit_ret = rtModelExit(model_, stream_);
  if (exit_ret == RT_ERROR_NONE) {
    GELOGI("ModelExitTaskInfo Distribute Success.");
    return SUCCESS;
  }

  GELOGE(RT_FAILED, "Call rtModelExit failed, ret: 0x%x", exit_ret);
  return RT_ERROR_TO_GE_STATUS(exit_ret);
}
REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_EXIT, ModelExitTaskInfo); | |||||
} // namespace ge |
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ | |||||
#include "graph/load/new_model_manager/task_info/task_info.h" | |||||
namespace ge { | |||||
class ModelExitTaskInfo : public TaskInfo { | |||||
public: | |||||
ModelExitTaskInfo() {} | |||||
~ModelExitTaskInfo() override { model_ = nullptr; } | |||||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
Status Distribute() override; | |||||
private: | |||||
rtModel_t model_{nullptr}; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MODEL_EXIT_TASK_INFO_H_ |
@@ -56,6 +56,7 @@ struct RuntimeParam { | |||||
uint32_t label_num = 0; | uint32_t label_num = 0; | ||||
uint64_t session_id = 0; | uint64_t session_id = 0; | ||||
uint32_t graph_id = 0; | uint32_t graph_id = 0; | ||||
bool is_single_op = false; | |||||
std::unique_ptr<TsMemMall> ts_mem_mall; | std::unique_ptr<TsMemMall> ts_mem_mall; | ||||
std::unique_ptr<TsMemMall> aicpu_mem_mall; | std::unique_ptr<TsMemMall> aicpu_mem_mall; | ||||
@@ -69,6 +69,7 @@ | |||||
#include "graph/passes/link_gen_mask_nodes_pass.h" | #include "graph/passes/link_gen_mask_nodes_pass.h" | ||||
#include "graph/passes/mark_graph_unknown_status_pass.h" | #include "graph/passes/mark_graph_unknown_status_pass.h" | ||||
#include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
#include "graph/passes/merge_input_memcpy_pass.h" | |||||
#include "graph/passes/merge_to_stream_merge_pass.h" | #include "graph/passes/merge_to_stream_merge_pass.h" | ||||
#include "graph/passes/multi_batch_pass.h" | #include "graph/passes/multi_batch_pass.h" | ||||
#include "graph/passes/next_iteration_pass.h" | #include "graph/passes/next_iteration_pass.h" | ||||
@@ -1949,6 +1950,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
} | } | ||||
PassManager after_merge_passes; | PassManager after_merge_passes; | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
after_merge_passes.AddPass("OptimizeStage1_1::MergeInputMemcpyPass", new (std::nothrow) MergeInputMemcpyPass)); | |||||
GE_CHK_STATUS_RET( | |||||
after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | after_merge_passes.AddPass("OptimizeStage1_1::SwitchDataEdgesBypass", new (std::nothrow) SwitchDataEdgesBypass)); | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | after_merge_passes.AddPass("OptimizeStage1_1::ConstantFuseSamePass", new (std::nothrow) ConstantFuseSamePass)); | ||||
@@ -26,7 +26,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
constexpr int kMaxRePassTimes = 1000; | |||||
constexpr int kMaxRePassTimes = 10000; | |||||
constexpr size_t kMaxOneInNodes = 1000; | constexpr size_t kMaxOneInNodes = 1000; | ||||
// Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later | ||||
constexpr int kMaxRecursiveDepth = 20; | constexpr int kMaxRecursiveDepth = 20; | ||||
@@ -84,6 +84,22 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
return graph_change ? SUCCESS : NOT_CHANGED; | return graph_change ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
/// @brief Count the direct nodes whose framework original type is "IteratorV2".
/// @param [in] compute_graph graph whose direct nodes are scanned
/// @return true when more than one such node exists, false otherwise
bool FlowCtrlPass::CheckMultiDataSet(ComputeGraphPtr &compute_graph) {
  int iterator_count = 0;
  for (auto &node : compute_graph->GetDirectNode()) {
    if (node != nullptr) {
      string original_type;
      // Nodes without the original-type attribute are simply not counted.
      if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) &&
          original_type == "IteratorV2") {
        ++iterator_count;
      }
    }
  }
  GELOGI("The ComputeGraph contain %d dataSet.", iterator_count);
  return iterator_count > 1;
}
NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, | NodePtr FlowCtrlPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, const string &node_name, | ||||
const std::vector<GeTensorDesc> &input_list, | const std::vector<GeTensorDesc> &input_list, | ||||
const std::vector<GeTensorDesc> &output_list) { | const std::vector<GeTensorDesc> &output_list) { | ||||
@@ -312,12 +328,12 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||||
* loopCond | * loopCond | ||||
* | | * | | ||||
* v | * v | ||||
* switch --> Assign | |||||
* switch --> Assign --> active --> ModelExit | |||||
* ^ | * ^ | ||||
* | | * | | ||||
* loopReset | * loopReset | ||||
*/ | */ | ||||
// Insert Assign node | |||||
// Insert Assign node and ctrl edge | |||||
NodePtr assign_node = | NodePtr assign_node = | ||||
InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node); | InsertAssignOp(compute_graph, ASSIGN, NODE_NAME_FLOWCTRL_LOOP_ASSIGN, loop_cond_node, loop_reset_node); | ||||
if (assign_node == nullptr || switch_node == nullptr) { | if (assign_node == nullptr || switch_node == nullptr) { | ||||
@@ -327,13 +343,50 @@ Status FlowCtrlPass::CreateIterCtrlFalseBranch(ComputeGraphPtr &compute_graph, c | |||||
GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed"); | GE_CHK_STATUS_RET(SetStreamLabel(assign_node, switch_node->GetName()), "set stream label failed"); | ||||
// 3. Insert ctrl edges | |||||
graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor()); | graphStatus add_ret = GraphUtils::AddEdge(switch_node->GetOutControlAnchor(), assign_node->GetInControlAnchor()); | ||||
if (add_ret != GRAPH_SUCCESS) { | if (add_ret != GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret); | GELOGE(FAILED, "Add switch_node to assign_node ctrl edge failed, add_ret=%u.", add_ret); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
if (CheckMultiDataSet(compute_graph)) { | |||||
GELOGI("Multi dataSae exist, model_exit node is need."); | |||||
// 2. Insert active node and add ctrl edge | |||||
string active_name = switch_node->GetName() + "_StreamExitActive"; | |||||
NodePtr active_node = InsertOp(compute_graph, STREAMACTIVE, active_name, {}, {}); | |||||
if (active_node == nullptr) { | |||||
GELOGE(FAILED, "Insert stream active node:%s for IterCtrlTrueStream failed.", active_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET(SetStreamLabel(active_node, switch_node->GetName()), "set stream label failed"); | |||||
GE_IF_BOOL_EXEC(!AttrUtils::SetBool(active_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, true), | |||||
DOMI_LOGE("set ATTR_NAME_IS_LOOP_ACTIVE failed"); | |||||
return FAILED); | |||||
string model_exit_name = switch_node->GetName() + "_ModelExit"; | |||||
GE_CHK_STATUS_RET(SetActiveLabelList(active_node, {model_exit_name}), "set active label list failed"); | |||||
add_ret = GraphUtils::AddEdge(assign_node->GetOutControlAnchor(), active_node->GetInControlAnchor()); | |||||
if (add_ret != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Add assign_node to active_node ctrl edge failed, add_ret=%u.", add_ret); | |||||
return FAILED; | |||||
} | |||||
// 3. Insert model exit node and add ctrl edge | |||||
NodePtr model_exit_node = InsertOp(compute_graph, MODELEXIT, model_exit_name, {}, {}); | |||||
if (model_exit_node == nullptr) { | |||||
GELOGE(FAILED, "Insert model_exit node:%s for IterCtrlTrueStream failed.", model_exit_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET(SetStreamLabel(model_exit_node, model_exit_name), "set stream label failed"); | |||||
add_ret = GraphUtils::AddEdge(active_node->GetOutControlAnchor(), model_exit_node->GetInControlAnchor()); | |||||
if (add_ret != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Add active_node to model_exit_node ctrl edge failed, add_ret=%u.", add_ret); | |||||
return FAILED; | |||||
} | |||||
} | |||||
GELOGI("CreateIterCtrlFalseBranch success."); | GELOGI("CreateIterCtrlFalseBranch success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -134,6 +134,14 @@ class FlowCtrlPass : public GraphPass { | |||||
/// Other: failed | /// Other: failed | ||||
/// | /// | ||||
Status AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node); | Status AddSpecialNodeIteratorCtrl(ComputeGraphPtr &compute_graph, NodePtr &loop_after_node); | ||||
/// | |||||
/// Check whether the graph contains more than one dataSet (IteratorV2) node.
/// @param compute_graph graph
/// @return true: two or more dataSet exist
///         false: fewer than two dataSet exist
/// | |||||
bool CheckMultiDataSet(ComputeGraphPtr &compute_graph); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -16,20 +16,40 @@ | |||||
#include "graph/passes/mark_agnostic_pass.h" | #include "graph/passes/mark_agnostic_pass.h" | ||||
#include "utils/node_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
namespace ge { | namespace ge { | ||||
Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
auto node_type = NodeUtils::GetNodeType(*node); | auto node_type = NodeUtils::GetNodeType(*node); | ||||
if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { | ||||
GELOGD("Mark format agnostic for switch ndoe %s", node->GetName().c_str()); | |||||
GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | |||||
const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); | |||||
if (op_tensor == nullptr) { | |||||
GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | |||||
continue; | |||||
} | |||||
if (node_type == IDENTITY) { | |||||
GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
} | } | ||||
if (node_type == MERGE || node_type == REFMERGE) { | if (node_type == MERGE || node_type == REFMERGE) { | ||||
GELOGD("Mark format agnostic for merge node %s", node->GetName().c_str()); | |||||
GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | |||||
const OpDescPtr op_desc = node->GetOpDesc(); | |||||
const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | |||||
if (op_tensor == nullptr) { | |||||
GELOGD("Op: %s, Index:0,has no output", node->GetName().c_str()); | |||||
continue; | |||||
} | |||||
AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | ||||
AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1})); | AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1})); | ||||
continue; | continue; | ||||
@@ -0,0 +1,97 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/passes/merge_input_memcpy_pass.h" | |||||
#include "common/ge/ge_util.h" | |||||
#include "ge/ge_api_types.h" | |||||
#include "graph/common/omg_util.h" | |||||
namespace ge { | |||||
/// @brief Walk the direct nodes of the graph and add memcpy nodes in front of every
///        Merge/RefMerge node.
/// @param [in] graph graph to process
/// @return SUCCESS, or the failure produced by the null check / AddMemcpyAsyncNodes
Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) {
  GELOGD("MergeInputMemcpyPass Enter");
  for (const auto &node : graph->GetDirectNode()) {
    const std::string node_type = node->GetType();
    if ((node_type == MERGE) || (node_type == REFMERGE)) {
      GE_CHECK_NOTNULL(node->GetOpDesc());
      // Merge nodes carrying ATTR_INSERT_BY_MBATCH get the multi-batch flavour of memcpy.
      const bool inserted_by_multi_batch = node->GetOpDesc()->HasAttr(ATTR_INSERT_BY_MBATCH);
      GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, inserted_by_multi_batch),
                        "Merge add memcpy node failed.");
    }
  }
  GELOGD("MergeInputMemcpyPass Leave");
  return SUCCESS;
}
///
/// @brief Insert a memcpy node on every connected data input of a Merge node
/// @param [in] graph             graph that receives the new nodes
/// @param [in] node              Merge/RefMerge node whose data inputs are rewired
/// @param [in] multi_batch_flag  forwarded to CreateMemcpyAsyncNode to select the memcpy op type
/// @return SUCCESS, or FAILED when a producer node is missing, a memcpy node cannot be
///         created, or an edge cannot be removed/added
///
Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node,
                                                 bool multi_batch_flag) {
  for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
    OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
    // Unconnected inputs need no memcpy.
    GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue);

    NodePtr in_node = peer_out_anchor->GetOwnerNode();
    GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Owner node of peer out anchor is null.");
    const std::string &type = in_node->GetType();
    // For WhileLoop no need memcpy for merge.
    GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION),
                    continue);

    const std::string &memcpy_name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx());
    NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, memcpy_name, peer_out_anchor, multi_batch_flag);
    GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create MemcpyAsync node failed.");

    // Rewire producer -> merge into producer -> memcpy -> merge. Every step must succeed:
    // continuing after a failed edge operation would leave the merge input disconnected
    // while the pass still reported SUCCESS.
    GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) == GRAPH_SUCCESS, return FAILED,
                     "MemcpyAsync node remove edge failed.");
    GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(0)) == GRAPH_SUCCESS,
                     return FAILED, "MemcpyAsync node add edge failed.");
    GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), in_data_anchor) == GRAPH_SUCCESS,
                     return FAILED, "MemcpyAsync node add edge failed.");
  }

  return SUCCESS;
}
///
/// @brief Create a MemcpyAsync / MemcpyAddrAsync node and add it to the graph
/// @param [in] graph             graph that receives the new node
/// @param [in] name              name prefix; the op type is appended as "_<type>"
/// @param [in] out_data_anchor   producer output whose tensor desc is used for both the
///                               input and the output of the new node
/// @param [in] multi_batch_flag  true selects MEMCPYADDRASYNC, false selects MEMCPYASYNC
/// @return the node returned by graph->AddNode, or nullptr on any failure
///
NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name,
                                                    const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) {
  // Guard the whole dereference chain anchor -> owner node -> op desc.
  GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null.");
  const NodePtr pre_node = out_data_anchor->GetOwnerNode();
  GE_CHK_BOOL_EXEC(pre_node != nullptr, return nullptr, "Owner node of out data anchor is null.");
  OpDescPtr pre_op_desc = pre_node->GetOpDesc();
  GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid.");

  const std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC;
  const std::string node_name = name + "_" + memcpy_type;
  GELOGI("Create MemcpyAsync op:%s.", node_name.c_str());
  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, memcpy_type);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc failed, MemcpyAsync:%s.", node_name.c_str());
    return nullptr;
  }

  // The new node gets the producer's output desc on both its input and its output.
  GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS,
                   return nullptr, "Create MemcpyAsync op: add input desc failed.");
  GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS,
                   return nullptr, "Create MemcpyAsync op: add output desc failed.");

  return graph->AddNode(op_desc);
}
} // namespace ge |
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ | |||||
#define GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ | |||||
#include "inc/graph_pass.h" | |||||
namespace ge { | |||||
class MergeInputMemcpyPass : public GraphPass { | |||||
public: | |||||
Status Run(ComputeGraphPtr graph); | |||||
private: | |||||
/// | |||||
/// @brief Add MemcpyAsync Op as Merge in_node | |||||
/// @param [in] graph | |||||
/// @param [in] node | |||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | |||||
/// | |||||
Status AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); | |||||
/// | |||||
/// @brief Add MemcpyAsync Node | |||||
/// @param [in] graph | |||||
/// @param [in] name | |||||
/// @param [in] out_data_anchor | |||||
/// @param [in] multi_batch_flag | |||||
/// @return ge::NodePtr | |||||
/// | |||||
NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | |||||
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_PASSES_MERGE_ADD_INPUT_MEMCPY_PASS_H_ |
@@ -32,7 +32,7 @@ Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { | |||||
OpDescPtr merge_op_desc = node->GetOpDesc(); | OpDescPtr merge_op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(merge_op_desc); | GE_CHECK_NOTNULL(merge_op_desc); | ||||
if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | ||||
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, true), "Merge add memcpy node failed."); | |||||
GE_CHK_STATUS_RET(AddActiveNodes(graph, node), "Merge add active node failed."); | |||||
GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | ||||
} else { | } else { | ||||
GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | ||||
@@ -99,38 +99,26 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||||
} | } | ||||
} | } | ||||
return AddMemcpyAsyncNodes(graph, stream_merge, false); | |||||
return AddActiveNodes(graph, stream_merge); | |||||
} | } | ||||
/// | /// | ||||
/// @brief Add MemcpyAsync Op as StreamMerge in_node | |||||
/// @brief Add StreamActive Op before StreamMerge/Merge | |||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status MergeToStreamMergePass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
bool multi_batch_flag) { | |||||
Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | ||||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
NodePtr in_node = peer_out_anchor->GetOwnerNode(); | NodePtr in_node = peer_out_anchor->GetOwnerNode(); | ||||
const std::string &type = in_node->GetType(); | const std::string &type = in_node->GetType(); | ||||
// For WhileLoop no need memcpy & active for merge. | |||||
// For WhileLoop, no need to add active nodes here, since which have been added in NextIterationPass. | |||||
GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), | GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), | ||||
continue); | continue); | ||||
const std::string &memcpy_name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()); | |||||
NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, memcpy_name, peer_out_anchor, multi_batch_flag); | |||||
GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create MemcpyAsync node failed."); | |||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge failed."); | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(0)), | |||||
"MemcpyAsync node add edge failed."); | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), in_data_anchor), | |||||
"MemcpyAsync node add edge failed."); | |||||
NodePtr active_node = CreateActiveNode(graph, memcpy_node); | |||||
NodePtr active_node = CreateActiveNode(graph, in_node); | |||||
GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); | GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); | ||||
GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), | GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), | ||||
"StreamActive add ctrl edge failed."); | "StreamActive add ctrl edge failed."); | ||||
@@ -144,37 +132,6 @@ Status MergeToStreamMergePass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, | |||||
} | } | ||||
/// | /// | ||||
/// @brief Add MemcpyAsync Node | |||||
/// @param [in] graph | |||||
/// @param [in] name | |||||
/// @param [in] out_data_anchor | |||||
/// @param [in] multi_batch_flag | |||||
/// @return ge::NodePtr | |||||
/// | |||||
NodePtr MergeToStreamMergePass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | |||||
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { | |||||
GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); | |||||
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | |||||
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | |||||
const std::string &memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; | |||||
const std::string &node_name = name + "_" + memcpy_type; | |||||
GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, memcpy_type); | |||||
if (op_desc == nullptr) { | |||||
GELOGE(FAILED, "Create op_desc failed, MemcpyAsync:%s.", node_name.c_str()); | |||||
return nullptr; | |||||
} | |||||
GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, | |||||
return nullptr, "Create MemcpyAsync op: add input desc failed."); | |||||
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, | |||||
return nullptr, "Create MemcpyAsync op: add output desc failed."); | |||||
return graph->AddNode(op_desc); | |||||
} | |||||
/// | |||||
/// @brief Create Active Op | /// @brief Create Active Op | ||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
@@ -34,24 +34,12 @@ class MergeToStreamMergePass : public GraphPass { | |||||
Status ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node); | Status ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node); | ||||
/// | /// | ||||
/// @brief Add MemcpyAsync Op as StreamMerge in_node | |||||
/// @brief Add StreamActive Op as StreamMerge in_node | |||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); | |||||
/// | |||||
/// @brief Add MemcpyAsync Node | |||||
/// @param [in] graph | |||||
/// @param [in] name | |||||
/// @param [in] out_data_anchor | |||||
/// @param [in] multi_batch_flag | |||||
/// @return ge::NodePtr | |||||
/// | |||||
NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | |||||
const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); | |||||
Status AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
/// | /// | ||||
/// @brief Create Active Op | /// @brief Create Active Op | ||||
@@ -131,6 +131,14 @@ graphStatus TransOpWithoutReshapeFusionPass::GetSubGraphNodesInfo() { | |||||
sub_graph_has_reshape_node[i] = true; | sub_graph_has_reshape_node[i] = true; | ||||
break; | break; | ||||
} | } | ||||
if (in_node->GetType() == TRANSPOSE || in_node->GetType() == TRANSPOSED) { | |||||
auto input_format = in_node->GetOpDesc()->GetInputDescPtr(0)->GetFormat(); | |||||
auto output_format = in_node->GetOpDesc()->GetOutputDescPtr(0)->GetFormat(); | |||||
if (input_format == output_format) { | |||||
sub_graph_has_reshape_node[i] = true; | |||||
break; | |||||
} | |||||
} | |||||
auto out_anchor = iter->first; | auto out_anchor = iter->first; | ||||
GE_CHECK_NOTNULL(out_anchor); | GE_CHECK_NOTNULL(out_anchor); | ||||
@@ -46,6 +46,14 @@ Status TransposeTransDataPass::Run(NodePtr &node) { | |||||
if (op_desc->GetType() != TRANSPOSED) { | if (op_desc->GetType() != TRANSPOSED) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
auto input_format = op_desc->GetInputDescPtr(0)->GetFormat(); | |||||
auto output_format = op_desc->GetOutputDescPtr(0)->GetFormat(); | |||||
if (input_format == output_format) { | |||||
GELOGW("Node %s input format is %s, output format is %s, should not happend. Ignore pass.", | |||||
op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(input_format).c_str(), | |||||
TypeUtils::FormatToSerialString(output_format).c_str()); | |||||
return SUCCESS; | |||||
} | |||||
if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) { | if (CheckOneInAndOneOutDataAnchor(node) != SUCCESS) { | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -184,6 +184,11 @@ Status AippOp::InsertAippToGraph(ComputeGraphPtr &graph, std::string &aippConfig | |||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
NodePtr target_input = nullptr; | NodePtr target_input = nullptr; | ||||
std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> target_edges; | std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> target_edges; | ||||
if (this->ConvertRelatedInputNameToRank() != SUCCESS) { | |||||
GELOGE(FAILED, "AippOp: convert related input name to rank failed."); | |||||
return FAILED; | |||||
} | |||||
GE_CHK_STATUS_RET(this->GetTargetPosition(graph, target_input, target_edges), "Get data nodes position failed"); | GE_CHK_STATUS_RET(this->GetTargetPosition(graph, target_input, target_edges), "Get data nodes position failed"); | ||||
std::map<OutDataAnchorPtr, NodePtr> out_anchors_to_aipp; | std::map<OutDataAnchorPtr, NodePtr> out_anchors_to_aipp; | ||||
@@ -412,6 +417,38 @@ Status AippOp::GetStaticTargetNode(const ComputeGraphPtr &graph, NodePtr &data_n | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
///
/// @brief Convert the related_input_name of the aipp config into related_input_rank.
///        The rank is the position of the name inside domi::GetContext().data_top_names.
///        Nothing is converted when related_input_name is empty.
/// @return SUCCESS when no name is configured or the name was found,
///         PARAM_INVALID when the name matches no entry of data_top_names
///
Status AippOp::ConvertRelatedInputNameToRank() {
  GE_CHECK_NOTNULL(aipp_params_);

  // Bind by reference: the name is only read, no copy is needed.
  const std::string &related_input_name = aipp_params_->related_input_name();
  if (related_input_name.empty()) {
    return SUCCESS;
  }

  const std::vector<std::string> &data_top_names = domi::GetContext().data_top_names;
  GELOGI("Convert name to rank start: data size[%zu]", data_top_names.size());

  uint32_t index = 0;
  for (const auto &data_top_name : data_top_names) {
    if (related_input_name == data_top_name) {
      aipp_params_->set_related_input_rank(index);
      GELOGI("AippOp: rank: %u, top name: %s.", index, data_top_name.c_str());
      return SUCCESS;
    }
    index++;
  }

  // No entry matched. A space separates the name from the rest of the sentence,
  // otherwise the reported text reads "Top name <name>convert rank failed".
  const std::string error_msg = "Top name " + related_input_name +
                                " convert rank failed, Please"
                                " ensure top name in aipp config is the top name of data node.";
  ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg});
  GELOGE(PARAM_INVALID, "Top name[%s] converts rank failed.", related_input_name.c_str());
  return PARAM_INVALID;
}
Status AippOp::GetTargetPosition(ComputeGraphPtr graph, NodePtr &target_input, | Status AippOp::GetTargetPosition(ComputeGraphPtr graph, NodePtr &target_input, | ||||
std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &target_edges) { | std::vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &target_edges) { | ||||
@@ -79,6 +79,7 @@ class AippOp : public InsertOpBase { | |||||
Status AddNodeToGraph(const NodePtr &aipp_node, int64_t max_dynamic_aipp_size); | Status AddNodeToGraph(const NodePtr &aipp_node, int64_t max_dynamic_aipp_size); | ||||
Status AddAippAttrbutes(const OpDescPtr &op_desc, const std::string &aipp_cfg_path, const uint32_t &index); | Status AddAippAttrbutes(const OpDescPtr &op_desc, const std::string &aipp_cfg_path, const uint32_t &index); | ||||
Status AddAttrToAippData(const OpDescPtr &aipp_data_op_desc); | Status AddAttrToAippData(const OpDescPtr &aipp_data_op_desc); | ||||
Status ConvertRelatedInputNameToRank(); | |||||
domi::AippOpParams *aipp_params_ = nullptr; | domi::AippOpParams *aipp_params_ = nullptr; | ||||
ge::NodePtr aipp_node_ = nullptr; | ge::NodePtr aipp_node_ = nullptr; | ||||
@@ -115,23 +115,97 @@ void InsertNewOpUtil::ClearNewOps() { | |||||
} | } | ||||
} | } | ||||
Status InsertNewOpUtil::CheckPositionNotRepeat() { | |||||
///
/// @brief Pairwise check of the aipp configs when they are addressed by related_input_name.
///        For every pair (i, j) with j > i, config j must set related_input_name and
///        that name must differ from the name of config i.
///        insert_op_conf_ is dereferenced without a null check here.
/// @return SUCCESS when every pair passes, PARAM_INVALID on the first violation
///
Status InsertNewOpUtil::CheckInputNamePositionNotRepeat() {
  const int aipp_op_num = insert_op_conf_->aipp_op_size();
  for (int i = 0; i < aipp_op_num; i++) {
    const domi::AippOpParams *item = insert_op_conf_->mutable_aipp_op(i);
    GE_CHECK_NOTNULL(item);
    for (int j = i + 1; j < aipp_op_num; j++) {
      const domi::AippOpParams *another_item = insert_op_conf_->mutable_aipp_op(j);
      GE_CHECK_NOTNULL(another_item);

      // A later config without a name does not follow the name-based addressing.
      if (another_item->related_input_name().empty()) {
        const std::string error_msg =
          "Can not both set related_input_name and related_input_rank!"
          " Please ensure param is the same with the first aipp config(related_input_name).";
        ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg});
        GELOGE(PARAM_INVALID, "%s", error_msg.c_str());
        return PARAM_INVALID;
      }

      // Two configs naming the same input would insert aipp at the same position.
      if (item->related_input_name() == another_item->related_input_name()) {
        const std::string error_msg =
          "Can not insert aipp to the same position! Please ensure related_input_name"
          " param is different in different aipp config.";
        ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg});
        // Log the reported text itself so the log always names the same parameter.
        GELOGE(PARAM_INVALID, "%s", error_msg.c_str());
        return PARAM_INVALID;
      }
    }
  }
  return SUCCESS;
}
///
/// @brief Pairwise check of the aipp configs when they are addressed by related_input_rank.
///        For every pair (i, j) with j > i, config j must leave related_input_name empty and
///        its related_input_rank must differ from the rank of config i.
///        insert_op_conf_ is dereferenced without a null check here.
/// @return SUCCESS when every pair passes, PARAM_INVALID on the first violation
///
Status InsertNewOpUtil::CheckInputRankPositionNoRepeat() {
  const int aipp_op_num = insert_op_conf_->aipp_op_size();
  for (int i = 0; i < aipp_op_num; i++) {
    const domi::AippOpParams *item = insert_op_conf_->mutable_aipp_op(i);
    GE_CHECK_NOTNULL(item);
    for (int j = i + 1; j < aipp_op_num; j++) {
      const domi::AippOpParams *another_item = insert_op_conf_->mutable_aipp_op(j);
      // Null check comes before any access to another_item.
      GE_CHECK_NOTNULL(another_item);

      // A later config that sets a name does not follow the rank-based addressing.
      if (!another_item->related_input_name().empty()) {
        const std::string error_msg =
          "Can not both set related_input_rank and related_input_name!"
          " Please ensure param is the same with the first aipp config(related_input_rank).";
        ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg});
        GELOGE(PARAM_INVALID, "%s", error_msg.c_str());
        return PARAM_INVALID;
      }

      // Two configs with the same rank would insert aipp at the same position.
      if (item->related_input_rank() == another_item->related_input_rank()) {
        const std::string error_msg =
          "Can not insert aipp to the same position! Please ensure related_input_rank"
          " param is different in different aipp config.";
        ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {error_msg});
        GELOGE(PARAM_INVALID, "%s", error_msg.c_str());
        return PARAM_INVALID;
      }
    }
  }
  return SUCCESS;
}
///
/// @brief Entry point of the aipp position check.
///        With fewer than two aipp configs there is nothing to compare. Otherwise the first
///        config selects the check: an empty related_input_name runs the rank-based check,
///        a non-empty one runs the name-based check.
/// @return SUCCESS when the selected check passes or is not needed, FAILED otherwise
///
Status InsertNewOpUtil::CheckPositionNotRepeat() {
  GE_CHECK_NOTNULL(insert_op_conf_);

  const int aipp_op_num = insert_op_conf_->aipp_op_size();
  if (aipp_op_num < 2) {
    GELOGI("Aipp op size[%d] less than 2, no need to check position repeat.", aipp_op_num);
    return SUCCESS;
  }

  const domi::AippOpParams *first_item = insert_op_conf_->mutable_aipp_op(0);
  GE_CHECK_NOTNULL(first_item);

  // The first config decides which addressing scheme all configs are checked against.
  const bool addressed_by_rank = first_item->related_input_name().empty();
  const Status check_ret =
    addressed_by_rank ? CheckInputRankPositionNoRepeat() : CheckInputNamePositionNotRepeat();
  if (check_ret == SUCCESS) {
    return SUCCESS;
  }

  GELOGE(FAILED, "Check position not repeat failed.");
  return FAILED;
}
@@ -51,6 +51,10 @@ class InsertNewOpUtil { | |||||
Status GetAippParams(const std::unique_ptr<domi::AippOpParams> &aippParams, const ge::NodePtr &aipp_node); | Status GetAippParams(const std::unique_ptr<domi::AippOpParams> &aippParams, const ge::NodePtr &aipp_node); | ||||
Status CheckInputNamePositionNotRepeat(); | |||||
Status CheckInputRankPositionNoRepeat(); | |||||
Status CheckGraph(const ge::ComputeGraphPtr &graph); | Status CheckGraph(const ge::ComputeGraphPtr &graph); | ||||
InsertNewOpUtil() = default; | InsertNewOpUtil() = default; | ||||
@@ -28,7 +28,6 @@ target_include_directories(host_cpu_engine PRIVATE | |||||
${GE_CODE_DIR}/inc | ${GE_CODE_DIR}/inc | ||||
${GE_CODE_DIR}/inc/external | ${GE_CODE_DIR}/inc/external | ||||
${GE_CODE_DIR}/inc/framework | ${GE_CODE_DIR}/inc/framework | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
@@ -36,6 +35,8 @@ target_include_directories(host_cpu_engine PRIVATE | |||||
${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
#### yellow zone #### | #### yellow zone #### | ||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
#### blue zone #### | |||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
) | ) | ||||
target_link_libraries(host_cpu_engine PRIVATE | target_link_libraries(host_cpu_engine PRIVATE | ||||
@@ -67,7 +68,6 @@ target_include_directories(atc_host_cpu_engine PRIVATE | |||||
${GE_CODE_DIR}/inc | ${GE_CODE_DIR}/inc | ||||
${GE_CODE_DIR}/inc/external | ${GE_CODE_DIR}/inc/external | ||||
${GE_CODE_DIR}/inc/framework | ${GE_CODE_DIR}/inc/framework | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
@@ -75,6 +75,8 @@ target_include_directories(atc_host_cpu_engine PRIVATE | |||||
${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
#### yellow zone #### | #### yellow zone #### | ||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
#### blue zone #### | |||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
) | ) | ||||
target_link_libraries(atc_host_cpu_engine PRIVATE | target_link_libraries(atc_host_cpu_engine PRIVATE | ||||
@@ -107,7 +109,6 @@ target_include_directories(host_cpu_opskernel_builder PRIVATE | |||||
${GE_CODE_DIR}/inc | ${GE_CODE_DIR}/inc | ||||
${GE_CODE_DIR}/inc/external | ${GE_CODE_DIR}/inc/external | ||||
${GE_CODE_DIR}/inc/framework | ${GE_CODE_DIR}/inc/framework | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
@@ -115,6 +116,8 @@ target_include_directories(host_cpu_opskernel_builder PRIVATE | |||||
${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
#### yellow zone #### | #### yellow zone #### | ||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
#### blue zone #### | |||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
) | ) | ||||
target_link_libraries(host_cpu_opskernel_builder PRIVATE | target_link_libraries(host_cpu_opskernel_builder PRIVATE | ||||
@@ -141,7 +144,6 @@ target_include_directories(atc_host_cpu_opskernel_builder PRIVATE | |||||
${GE_CODE_DIR}/inc | ${GE_CODE_DIR}/inc | ||||
${GE_CODE_DIR}/inc/external | ${GE_CODE_DIR}/inc/external | ||||
${GE_CODE_DIR}/inc/framework | ${GE_CODE_DIR}/inc/framework | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
@@ -149,6 +151,8 @@ target_include_directories(atc_host_cpu_opskernel_builder PRIVATE | |||||
${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
#### yellow zone #### | #### yellow zone #### | ||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
#### blue zone #### | |||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
) | ) | ||||
target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE | target_link_libraries(atc_host_cpu_opskernel_builder PRIVATE | ||||
@@ -180,7 +184,6 @@ target_include_directories(host_cpu_opskernel_builder_static PRIVATE | |||||
${GE_CODE_DIR}/inc | ${GE_CODE_DIR}/inc | ||||
${GE_CODE_DIR}/inc/external | ${GE_CODE_DIR}/inc/external | ||||
${GE_CODE_DIR}/inc/framework | ${GE_CODE_DIR}/inc/framework | ||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
${METADEF_DIR}/inc | ${METADEF_DIR}/inc | ||||
${METADEF_DIR}/inc/external | ${METADEF_DIR}/inc/external | ||||
${METADEF_DIR}/inc/external/graph | ${METADEF_DIR}/inc/external/graph | ||||
@@ -188,6 +191,8 @@ target_include_directories(host_cpu_opskernel_builder_static PRIVATE | |||||
${CMAKE_BINARY_DIR}/proto/ge | ${CMAKE_BINARY_DIR}/proto/ge | ||||
#### yellow zone #### | #### yellow zone #### | ||||
${GE_CODE_DIR}/../inc | ${GE_CODE_DIR}/../inc | ||||
#### blue zone #### | |||||
${GE_CODE_DIR}/third_party/fwkacllib/inc | |||||
) | ) | ||||
target_link_libraries(host_cpu_opskernel_builder_static PRIVATE | target_link_libraries(host_cpu_opskernel_builder_static PRIVATE | ||||
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "hybrid/node_executor/aicpu/aicpu_node_executor.h" | #include "hybrid/node_executor/aicpu/aicpu_node_executor.h" | ||||
#include "cce/taskdown_common.hpp" | |||||
#include "common/formats/formats.h" | #include "common/formats/formats.h" | ||||
#include "aicpu/common/aicpu_task_struct.h" | #include "aicpu/common/aicpu_task_struct.h" | ||||
#include "graph/load/new_model_manager/model_manager.h" | #include "graph/load/new_model_manager/model_manager.h" | ||||
@@ -593,6 +594,15 @@ Status AicpuNodeTask::Init(const HybridModel &model) { | |||||
auto &args = kernel_def.args(); | auto &args = kernel_def.args(); | ||||
args_size_ = kernel_def.args_size(); | args_size_ = kernel_def.args_size(); | ||||
const std::string &so_name = kernel_def.so_name(); | |||||
const OpDescPtr op_desc = MakeShared<OpDesc>(*(node_item_->op_desc)); | |||||
const auto &context = kernel_def.context(); | |||||
auto kernel_type = static_cast<cce::ccKernelType>(context.kernel_type()); | |||||
if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { | |||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc, so_name), "load cust aicpu so failed."); | |||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "Launch cust aicpu so failed."); | |||||
} | |||||
GE_CHK_BOOL_RET_STATUS(args.size() == args_size_, FAILED, "Node[%s] task def args.size=%zu, but args_size=%u.", | GE_CHK_BOOL_RET_STATUS(args.size() == args_size_, FAILED, "Node[%s] task def args.size=%zu, but args_size=%u.", | ||||
node_name.c_str(), args.size(), args_size_); | node_name.c_str(), args.size(), args_size_); | ||||
@@ -676,7 +686,12 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { | |||||
GELOGI("Node[%s] launch task start. unknown_type=%d.", node_name_.c_str(), unknown_type_); | GELOGI("Node[%s] launch task start. unknown_type=%d.", node_name_.c_str(), unknown_type_); | ||||
const auto &so_name = task_def_.kernel().so_name(); | const auto &so_name = task_def_.kernel().so_name(); | ||||
const auto &kernel_name = task_def_.kernel().kernel_name(); | const auto &kernel_name = task_def_.kernel().kernel_name(); | ||||
const auto &kcontext = task_def_.kernel().context(); | |||||
auto kernel_type = static_cast<cce::ccKernelType>(kcontext.kernel_type()); | |||||
uint32_t flag = RT_KERNEL_DEFAULT; | uint32_t flag = RT_KERNEL_DEFAULT; | ||||
if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { | |||||
flag |= RT_KERNEL_CUSTOM_AICPU; | |||||
} | |||||
auto rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name.c_str()), | auto rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name.c_str()), | ||||
reinterpret_cast<const void *>(kernel_name.c_str()), | reinterpret_cast<const void *>(kernel_name.c_str()), | ||||
1, // default core dim is 1 | 1, // default core dim is 1 | ||||
@@ -438,6 +438,12 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | auto compute_graph = GraphUtils::GetComputeGraph(graph); | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
auto ret = compute_graph->InferOriginFormat(); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(ret, "Acl InferOriginFormat failed."); | |||||
return ret; | |||||
} | |||||
for (auto &node : compute_graph->GetAllNodes()) { | for (auto &node : compute_graph->GetAllNodes()) { | ||||
graphStatus ret = ShapeRefiner::InferShapeAndType(node); | graphStatus ret = ShapeRefiner::InferShapeAndType(node); | ||||
if (ret == GRAPH_PARAM_INVALID) { | if (ret == GRAPH_PARAM_INVALID) { | ||||
@@ -1 +1 @@ | |||||
optimizer:["aicpu_tf_optimizer","AIcoreEngine","VectorEngine","aicpu_ascend_optimizer","hccl_graph_optimizer", "hvd_graph_optimizer", "DNN_VM_RTS_GRAPH_OPTIMIZER_STORE"] | |||||
optimizer:["aicpu_tf_optimizer","aicpu_ascend_optimizer","AIcoreEngine","VectorEngine","hccl_graph_optimizer", "hvd_graph_optimizer", "DNN_VM_RTS_GRAPH_OPTIMIZER_STORE"] |
@@ -995,8 +995,10 @@ FMK_FUNC_HOST_VISIBILITY Status ConvertFwkModelToJson(const domi::FrameworkType | |||||
ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
"E10001", {"parameter", "value", "reason"}, | "E10001", {"parameter", "value", "reason"}, | ||||
{"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow)"}); | |||||
GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); | |||||
{"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow) 5(Onnx)"}); | |||||
GELOGE(PARAM_INVALID, | |||||
"Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow) " | |||||
"or 5(Onnx)."); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -1039,6 +1041,7 @@ void UpdateOmgCtxWithParserCtx() { | |||||
domi::GetContext().out_top_names = GetParserContext().out_top_names; | domi::GetContext().out_top_names = GetParserContext().out_top_names; | ||||
domi::GetContext().user_out_nodes_top_vec = GetParserContext().user_out_nodes_top_vec; | domi::GetContext().user_out_nodes_top_vec = GetParserContext().user_out_nodes_top_vec; | ||||
domi::GetContext().default_out_nodes = GetParserContext().default_out_nodes; | domi::GetContext().default_out_nodes = GetParserContext().default_out_nodes; | ||||
domi::GetContext().data_top_names = GetParserContext().data_top_names; | |||||
} | } | ||||
void UpdateParserCtxWithOmgCtx() { | void UpdateParserCtxWithOmgCtx() { | ||||
@@ -1055,5 +1058,6 @@ void UpdateParserCtxWithOmgCtx() { | |||||
GetParserContext().input_nodes_format_map = domi::GetContext().input_nodes_format_map; | GetParserContext().input_nodes_format_map = domi::GetContext().input_nodes_format_map; | ||||
GetParserContext().out_top_names = domi::GetContext().out_top_names; | GetParserContext().out_top_names = domi::GetContext().out_top_names; | ||||
GetParserContext().user_out_nodes_top_vec = domi::GetContext().user_out_nodes_top_vec; | GetParserContext().user_out_nodes_top_vec = domi::GetContext().user_out_nodes_top_vec; | ||||
GetParserContext().data_top_names = domi::GetContext().data_top_names; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -31,6 +31,7 @@ | |||||
#include "task/aicpu_task_builder.h" | #include "task/aicpu_task_builder.h" | ||||
#include "task/aicpu_kernel_task_builder.h" | #include "task/aicpu_kernel_task_builder.h" | ||||
#include "task/tbe_task_builder.h" | #include "task/tbe_task_builder.h" | ||||
#include "graph/load/new_model_manager/model_manager.h" | |||||
static std::atomic<std::uint64_t> aicpu_sessionid(0); | static std::atomic<std::uint64_t> aicpu_sessionid(0); | ||||
@@ -187,6 +188,7 @@ Status SingleOpModel::LoadAllNodes() { | |||||
} | } | ||||
ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(op_desc); | ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(op_desc); | ||||
ge_model->GetCustAICPUKernelStore().LoadCustAICPUKernelBinToOpDesc(op_desc); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -244,7 +246,7 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { | |||||
single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); | single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); | ||||
ParseArgTable(tbe_task, single_op); | ParseArgTable(tbe_task, single_op); | ||||
single_op.tasks_.emplace_back(tbe_task); | single_op.tasks_.emplace_back(tbe_task); | ||||
} else if (kernel_type == cce::ccKernelType::AI_CPU) { | |||||
} else if (kernel_type == cce::ccKernelType::AI_CPU || kernel_type == cce::ccKernelType::CUST_AI_CPU) { | |||||
GELOGD("Building AICPU_CC task"); | GELOGD("Building AICPU_CC task"); | ||||
OpTask *task = nullptr; | OpTask *task = nullptr; | ||||
auto ret = BuildCpuKernelTask(task_def.kernel(), &task); | auto ret = BuildCpuKernelTask(task_def.kernel(), &task); | ||||
@@ -253,7 +255,7 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { | |||||
} | } | ||||
single_op.tasks_.emplace_back(task); | single_op.tasks_.emplace_back(task); | ||||
} else { | } else { | ||||
GELOGE(UNSUPPORTED, "Only TBE kernel and AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
GELOGE(UNSUPPORTED, "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
} else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { | ||||
@@ -273,6 +275,7 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { | |||||
GELOGD("Skip task type: %d", static_cast<int>(task_type)); | GELOGD("Skip task type: %d", static_cast<int>(task_type)); | ||||
} | } | ||||
} | } | ||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "launch cust aicpu so failed."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -388,13 +391,13 @@ Status SingleOpModel::BuildModelTaskKernel(const TaskDef &task_def, DynamicSingl | |||||
TbeOpTask *tbe_task = nullptr; | TbeOpTask *tbe_task = nullptr; | ||||
GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def.kernel(), &tbe_task)); | GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def.kernel(), &tbe_task)); | ||||
single_op.op_task_.reset(tbe_task); | single_op.op_task_.reset(tbe_task); | ||||
} else if (kernel_type == cce::ccKernelType::AI_CPU) { | |||||
} else if (kernel_type == cce::ccKernelType::AI_CPU || kernel_type == cce::ccKernelType::CUST_AI_CPU) { | |||||
GELOGD("Building AICPU_CC task"); | GELOGD("Building AICPU_CC task"); | ||||
OpTask *task = nullptr; | OpTask *task = nullptr; | ||||
GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task)); | GE_CHK_STATUS_RET_NOLOG(BuildCpuKernelTask(task_def.kernel(), &task)); | ||||
single_op.op_task_.reset(task); | single_op.op_task_.reset(task); | ||||
} else { | } else { | ||||
GELOGE(UNSUPPORTED, "Only TBE kernel and AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
GELOGE(UNSUPPORTED, "Only TBE, AI_CPU, CUST_AI_CPU kernel are supported, but got %u", context.kernel_type()); | |||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -444,6 +447,7 @@ Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { | |||||
GELOGD("Skip task type: %d", static_cast<int>(task_type)); | GELOGD("Skip task type: %d", static_cast<int>(task_type)); | ||||
} | } | ||||
} | } | ||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LaunchCustAicpuSo(), "launch cust aicpu so failed."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -15,6 +15,8 @@ | |||||
*/ | */ | ||||
#include "single_op/task/aicpu_kernel_task_builder.h" | #include "single_op/task/aicpu_kernel_task_builder.h" | ||||
#include "cce/taskdown_common.hpp" | |||||
#include "graph/load/new_model_manager/model_manager.h" | |||||
namespace ge { | namespace ge { | ||||
AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) | AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) | ||||
@@ -55,6 +57,14 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task) { | |||||
task.SetkernelName(kernel_name); | task.SetkernelName(kernel_name); | ||||
task.op_desc_ = op_desc_; | task.op_desc_ = op_desc_; | ||||
const auto &context = kernel_def_.context(); | |||||
auto kernel_type = static_cast<cce::ccKernelType>(context.kernel_type()); | |||||
if (kernel_type == cce::ccKernelType::CUST_AI_CPU) { | |||||
task.is_custom_ = true; | |||||
task.dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | |||||
GE_CHK_STATUS_RET(ModelManager::GetInstance()->LoadCustAicpuSo(op_desc_, so_name), "launch cust aicpu so failed"); | |||||
} | |||||
task.num_inputs_ = op_desc_->GetInputsSize(); | task.num_inputs_ = op_desc_->GetInputsSize(); | ||||
task.num_outputs_ = op_desc_->GetOutputsSize(); | task.num_outputs_ = op_desc_->GetOutputsSize(); | ||||
@@ -45,6 +45,7 @@ std::vector<std::vector<void *>> BuildTaskUtils::GetAddresses(const OpDescPtr &o | |||||
runtime_para.logic_var_base = kLogicVarBase; | runtime_para.logic_var_base = kLogicVarBase; | ||||
runtime_para.var_base = kVarBase; | runtime_para.var_base = kVarBase; | ||||
runtime_para.session_id = kSessionId; | runtime_para.session_id = kSessionId; | ||||
runtime_para.is_single_op = true; | |||||
ret.emplace_back(ModelUtils::GetInputDataAddrs(runtime_para, op_desc)); | ret.emplace_back(ModelUtils::GetInputDataAddrs(runtime_para, op_desc)); | ||||
ret.emplace_back(ModelUtils::GetOutputDataAddrs(runtime_para, op_desc)); | ret.emplace_back(ModelUtils::GetOutputDataAddrs(runtime_para, op_desc)); | ||||
@@ -260,8 +260,8 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | |||||
std::vector<GeTensorDesc> &output_desc) { | |||||
Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, std::vector<GeTensorDesc> &output_desc, | |||||
rtStream_t stream) { | |||||
GELOGI("Update ext info begin, unknown_type=%d.", unknown_type_); | GELOGI("Update ext info begin, unknown_type=%d.", unknown_type_); | ||||
if (num_inputs_ == 0 && num_outputs_ == 0) { | if (num_inputs_ == 0 && num_outputs_ == 0) { | ||||
GELOGI("No input and output, no need update ext info."); | GELOGI("No input and output, no need update ext info."); | ||||
@@ -278,15 +278,13 @@ Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | |||||
for (size_t j = 0; j < num_outputs_; ++j) { | for (size_t j = 0; j < num_outputs_; ++j) { | ||||
GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateOutputShapeAndType(j, output_desc[j]), | GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateOutputShapeAndType(j, output_desc[j]), | ||||
"Output[%zu] UpdateOutputShapeAndType failed.", j); | "Output[%zu] UpdateOutputShapeAndType failed.", j); | ||||
// debug code | |||||
GELOGD("No input and output, no need update ext info."); | |||||
} | } | ||||
} | } | ||||
GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_, | |||||
aicpu_ext_handle_->GetExtInfoLen(), // check size | |||||
aicpu_ext_handle_->GetExtInfo(), aicpu_ext_handle_->GetExtInfoLen(), | |||||
RT_MEMCPY_HOST_TO_DEVICE)); | |||||
GE_CHK_RT_RET(rtMemcpyAsync(ext_info_addr_dev_, | |||||
aicpu_ext_handle_->GetExtInfoLen(), // check size | |||||
aicpu_ext_handle_->GetExtInfo(), aicpu_ext_handle_->GetExtInfoLen(), | |||||
RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); | |||||
GELOGI("Update ext info end."); | GELOGI("Update ext info end."); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -599,7 +597,7 @@ Status AiCpuTask::SetMemCopyTask(const domi::KernelExDef &kernel_def) { | |||||
Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | ||||
const std::vector<DataBuffer> &input_buffers, std::vector<GeTensorDesc> &output_desc, | const std::vector<DataBuffer> &input_buffers, std::vector<GeTensorDesc> &output_desc, | ||||
std::vector<DataBuffer> &output_buffers, rtStream_t stream) { | std::vector<DataBuffer> &output_buffers, rtStream_t stream) { | ||||
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); | |||||
std::vector<void *> inputs; | std::vector<void *> inputs; | ||||
std::vector<void *> outputs; | std::vector<void *> outputs; | ||||
for (auto &buffer : input_buffers) { | for (auto &buffer : input_buffers) { | ||||
@@ -610,11 +608,12 @@ Status AiCpuTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
} | } | ||||
GE_CHK_STATUS_RET_NOLOG(SetIO(inputs, outputs)); | GE_CHK_STATUS_RET_NOLOG(SetIO(inputs, outputs)); | ||||
GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); | GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); | ||||
GE_CHK_RT_RET(rtStreamSynchronize(stream)); | |||||
if (unknown_type_ == DEPEND_SHAPE_RANGE) { | if (unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
GE_CHK_RT_RET(rtStreamSynchronize(stream)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); | GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); | ||||
} else if (unknown_type_ == DEPEND_COMPUTE) { | } else if (unknown_type_ == DEPEND_COMPUTE) { | ||||
GE_CHK_RT_RET(rtStreamSynchronize(stream)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateShapeAndDataByResultSummary(output_desc, output_buffers, stream)); | GE_CHK_STATUS_RET_NOLOG(UpdateShapeAndDataByResultSummary(output_desc, output_buffers, stream)); | ||||
} | } | ||||
@@ -647,9 +646,9 @@ Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { | |||||
kernel_name_.data()); | kernel_name_.data()); | ||||
// sm_desc is nullptr, because l2 buffer does not support | // sm_desc is nullptr, because l2 buffer does not support | ||||
auto *sm_desc = reinterpret_cast<rtSmDesc_t *>(sm_desc_); | auto *sm_desc = reinterpret_cast<rtSmDesc_t *>(sm_desc_); | ||||
auto ret = | |||||
rtCpuKernelLaunch(static_cast<const void *>(so_name_.data()), static_cast<const void *>(kernel_name_.data()), | |||||
block_dim_, args_.get(), static_cast<uint32_t>(arg_size_), sm_desc, stream); | |||||
auto ret = rtCpuKernelLaunchWithFlag(static_cast<const void *>(so_name_.data()), | |||||
static_cast<const void *>(kernel_name_.data()), block_dim_, args_.get(), | |||||
static_cast<uint32_t>(arg_size_), sm_desc, stream, dump_flag_); | |||||
if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. ret = %d", ret); | GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. ret = %d", ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
@@ -665,7 +664,7 @@ Status AiCpuCCTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
GE_CHK_BOOL_RET_STATUS(unknown_type_ != DEPEND_COMPUTE, FAILED, | GE_CHK_BOOL_RET_STATUS(unknown_type_ != DEPEND_COMPUTE, FAILED, | ||||
"AiCpuCCTask unknown type[%d] is depend compute, it's not supported now.", unknown_type_); | "AiCpuCCTask unknown type[%d] is depend compute, it's not supported now.", unknown_type_); | ||||
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateExtInfo(input_desc, output_desc, stream)); | |||||
size_t arg_index = 0; | size_t arg_index = 0; | ||||
auto *task_io_addr = reinterpret_cast<uintptr_t *>(io_addr_); | auto *task_io_addr = reinterpret_cast<uintptr_t *>(io_addr_); | ||||
@@ -678,9 +677,9 @@ Status AiCpuCCTask::LaunchKernel(const std::vector<GeTensorDesc> &input_desc, | |||||
} | } | ||||
GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); | GE_CHK_STATUS_RET_NOLOG(LaunchKernel(stream)); | ||||
GE_CHK_RT_RET(rtStreamSynchronize(stream)); | |||||
if (unknown_type_ == DEPEND_SHAPE_RANGE) { | if (unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
GE_CHK_RT_RET(rtStreamSynchronize(stream)); | |||||
GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); | GE_CHK_STATUS_RET_NOLOG(UpdateOutputShape(output_desc)); | ||||
} | } | ||||
@@ -118,7 +118,8 @@ class AiCpuBaseTask : public OpTask { | |||||
protected: | protected: | ||||
Status SetExtInfoAndType(const std::string &kernel_ext_info); | Status SetExtInfoAndType(const std::string &kernel_ext_info); | ||||
Status UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, std::vector<GeTensorDesc> &output_desc); | |||||
Status UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, std::vector<GeTensorDesc> &output_desc, | |||||
rtStream_t stream); | |||||
Status UpdateOutputShape(vector<GeTensorDesc> &output_desc); | Status UpdateOutputShape(vector<GeTensorDesc> &output_desc); | ||||
Status UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensorDesc &output_desc); | Status UpdateShapeToOutputDesc(const GeShape &shape_new, GeTensorDesc &output_desc); | ||||
@@ -214,6 +215,8 @@ class AiCpuCCTask : public AiCpuBaseTask { | |||||
uint32_t block_dim_ = 1; | uint32_t block_dim_ = 1; | ||||
void *sm_desc_ = nullptr; | void *sm_desc_ = nullptr; | ||||
void *io_addr_ = nullptr; | void *io_addr_ = nullptr; | ||||
bool is_custom_ = false; | |||||
uint32_t dump_flag_ = RT_KERNEL_DEFAULT; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -61,6 +61,9 @@ message AippOpParams { | |||||
// 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 | ||||
uint32 related_input_rank = 2; | uint32 related_input_rank = 2; | ||||
// related_input_name is optional and the top name of data node which inserts aipp | |||||
string related_input_name = 6; | |||||
// input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 | ||||
// 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 | ||||
// 配置值 <= Data算子输出边的个数。 | // 配置值 <= Data算子输出边的个数。 | ||||
@@ -68,8 +68,10 @@ struct MemRegisterAddr { | |||||
u64 addr; | u64 addr; | ||||
u64 length; | u64 length; | ||||
}; | }; | ||||
const u32 HCCL_MAX_MEM_REGISTER_NUM = 1024 * 1024; // The max number of memory register address is 1M (1024 * 1024). | |||||
/* | |||||
* @brief The max number of memory register address for remote access. | |||||
*/ | |||||
const u32 HCCL_MAX_MEM_REGISTER_NUM = 32; | |||||
enum GradSplitForceMode { | enum GradSplitForceMode { | ||||
FORCE_NONE, /**< no force */ | FORCE_NONE, /**< no force */ | ||||
@@ -2240,6 +2240,64 @@ REG_OP(OutfeedEnqueueOp) | |||||
.ATTR(channel_name, String, "") | .ATTR(channel_name, String, "") | ||||
.OP_END_FACTORY_REG(OutfeedEnqueueOp) | .OP_END_FACTORY_REG(OutfeedEnqueueOp) | ||||
/** | |||||
*@brief LruCache, create cache resource. | |||||
*@par Inputs: | |||||
*No input. | |||||
*@par Attributes: | |||||
*cache_size: An optional "int64". The size of the cache. Defaults to "100000". | |||||
*load_factor: An optional "float". The rate which shows whether the cache is full. Defaults to "1". | |||||
*@par Outputs: | |||||
*cache: cache resource. | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(LruCache) | |||||
.OUTPUT(cache, TensorType({DT_RESOURCE})) | |||||
.ATTR(container, String, "") | |||||
.ATTR(shared_name, String, "LruCache") | |||||
.ATTR(cache_size, Int, 100000) | |||||
.ATTR(load_factor, Float, 1) | |||||
.OP_END_FACTORY_REG(LruCache) | |||||
/** | |||||
*@brief CacheAdd, get the ids that newly come into the cache and the ids that get out of the cache. | |||||
*@par Inputs: | |||||
*cache: resource data | |||||
*ids: Tensor storing the ids that need to be inserted into the cache | |||||
*@par Outputs: | |||||
*swap_in_id: id come in cache. | |||||
*swap_in_idx: id in cache which come in cache | |||||
*swap_out_id: id get out of cache | |||||
*swap_out_idx: id in cache which get out of cache | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(CacheAdd) | |||||
.INPUT(cache, TensorType({DT_RESOURCE})) | |||||
.INPUT(ids, TensorType({DT_INT64, DT_INT32, DT_UINT64, DT_UINT32})) | |||||
.OUTPUT(swap_in_id, TensorType({DT_INT64, DT_INT32, DT_UINT64, DT_UINT32})) | |||||
.OUTPUT(swap_in_idx, TensorType({DT_INT64})) | |||||
.OUTPUT(swap_out_id, TensorType({DT_INT64, DT_INT32, DT_UINT64, DT_UINT32})) | |||||
.OUTPUT(swap_out_idx, TensorType({DT_INT64})) | |||||
.OP_END_FACTORY_REG(CacheAdd) | |||||
/** | |||||
*@brief CacheRemoteIndexToLocal, get id in cache from id. | |||||
*@par Inputs: | |||||
*cache: resource data | |||||
*ids: Tensor stored id need to insert cache | |||||
*@par Outputs: | |||||
*local_idx: id in cache. | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(CacheRemoteIndexToLocal) | |||||
.INPUT(cache, TensorType({DT_RESOURCE})) | |||||
.INPUT(ids, TensorType({DT_INT64, DT_INT32, DT_UINT64, DT_UINT32})) | |||||
.OUTPUT(local_idx, TensorType({DT_INT64})) | |||||
.OP_END_FACTORY_REG(CacheRemoteIndexToLocal) | |||||
} // namespace ge | } // namespace ge | ||||
#endif // OPS_BUILT_IN_OP_PROTO_INC_DATA_FLOW_OPS_H_ | #endif // OPS_BUILT_IN_OP_PROTO_INC_DATA_FLOW_OPS_H_ |
@@ -2803,6 +2803,80 @@ REG_OP(AdamApplyOneAssign) | |||||
.OP_END_FACTORY_REG(AdamApplyOneAssign) | .OP_END_FACTORY_REG(AdamApplyOneAssign) | ||||
/** | /** | ||||
*@brief A fusion operator for bert lamb. \n | |||||
*@par Inputs: | |||||
*Twelve inputs, including: | |||||
* @li input0: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input1: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input2: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input3: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li steps: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li do_use_weight: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li weight_decay_rate: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li add2_y: A Tensor. Must be one of the following types: float16, float32. \n | |||||
*@par Outputs: | |||||
*One output, including: | |||||
* @li output0: A Tensor. Must be one of the following types: float16, float32. \n | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(LambApplyOptimizerAssign) | |||||
.INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(steps, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(do_use_weight, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(weight_decay_rate, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.OP_END_FACTORY_REG(LambApplyOptimizerAssign) | |||||
/** | |||||
*@brief A fusion operator for bert lamb. \n | |||||
*@par Inputs: | |||||
*Five inputs, including: | |||||
* @li input0: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input1: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input2: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input3: A Tensor. Must be one of the following types: float16, float32. | |||||
* @li input4: A Tensor. Must be one of the following types: float16, float32. \n | |||||
*@par Outputs: | |||||
*No outputs | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(LambApplyWeightAssign) | |||||
.INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) | |||||
.OP_END_FACTORY_REG(LambApplyWeightAssign) | |||||
/** | |||||
*@brief Confuse select, maximum, greater and sqrt. \n | *@brief Confuse select, maximum, greater and sqrt. \n | ||||
*@par Inputs: | *@par Inputs: | ||||
@@ -495,51 +495,51 @@ REG_OP(NextAfter) | |||||
.OP_END_FACTORY_REG(NextAfter) | .OP_END_FACTORY_REG(NextAfter) | ||||
/** | /** | ||||
* *@brief Compute element-wise finiteness, return a boolean tensor. | |||||
* | |||||
* *@par Inputs: | |||||
* *x:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *y:A Tensor. Has the same shape as x. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow IsFinite operator. | |||||
* */ | |||||
*@brief Compute element-wise finiteness, return a boolean tensor. | |||||
*@par Inputs: | |||||
*x:A Tensor. | |||||
*@par Outputs: | |||||
*y:A Tensor. Has the same shape as x. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow IsFinite operator. | |||||
*/ | |||||
REG_OP(IsFinite) | REG_OP(IsFinite) | ||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | ||||
.OUTPUT(y, TensorType({DT_BOOL})) | .OUTPUT(y, TensorType({DT_BOOL})) | ||||
.OP_END_FACTORY_REG(IsFinite) | .OP_END_FACTORY_REG(IsFinite) | ||||
/** | /** | ||||
* *@brief Compute element-wise infiniteness, return a boolean tensor. | |||||
* | |||||
* *@par Inputs: | |||||
* *x:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *y:A Tensor. Has the same shape as x. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow IsInf operator. | |||||
* */ | |||||
*@brief Compute element-wise infiniteness, return a boolean tensor. | |||||
*@par Inputs: | |||||
*x:A Tensor. | |||||
*@par Outputs: | |||||
*y:A Tensor. Has the same shape as x. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow IsInf operator. | |||||
*/ | |||||
REG_OP(IsInf) | REG_OP(IsInf) | ||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | ||||
.OUTPUT(y, TensorType({DT_BOOL})) | .OUTPUT(y, TensorType({DT_BOOL})) | ||||
.OP_END_FACTORY_REG(IsInf) | .OP_END_FACTORY_REG(IsInf) | ||||
/** | /** | ||||
* *@brief Computes the complex absolute value of a tensor. | |||||
* | |||||
* *@par Inputs: | |||||
* *x:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *y:A tensor of type `float` or `double` that is the absolute value of each element in `x`. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow ComplexAbs operator. | |||||
* */ | |||||
*@brief Computes the complex absolute value of a tensor. | |||||
*@par Inputs: | |||||
*x:A Tensor. | |||||
*@par Outputs: | |||||
*y:A tensor of type `float` or `double` that is the absolute value of each element in `x`. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow ComplexAbs operator. | |||||
*/ | |||||
REG_OP(ComplexAbs) | REG_OP(ComplexAbs) | ||||
.INPUT(x, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | .INPUT(x, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | ||||
.OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE})) | .OUTPUT(y, TensorType({DT_FLOAT, DT_DOUBLE})) | ||||
@@ -547,34 +547,34 @@ REG_OP(ComplexAbs) | |||||
.OP_END_FACTORY_REG(ComplexAbs) | .OP_END_FACTORY_REG(ComplexAbs) | ||||
/** | /** | ||||
* *@brief Returns which elements of x are NaN. | |||||
* | |||||
* *@par Inputs: | |||||
* *x:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *y:A Tensor. Has the same shape as x. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow IsNan operator. | |||||
* */ | |||||
*@brief Returns which elements of x are NaN. | |||||
*@par Inputs: | |||||
*x:A Tensor. | |||||
*@par Outputs: | |||||
*y:A Tensor. Has the same shape as x. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow IsNan operator. | |||||
*/ | |||||
REG_OP(IsNan) | REG_OP(IsNan) | ||||
.INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) | ||||
.OUTPUT(y, TensorType({DT_BOOL})) | .OUTPUT(y, TensorType({DT_BOOL})) | ||||
.OP_END_FACTORY_REG(IsNan) | .OP_END_FACTORY_REG(IsNan) | ||||
/** | /** | ||||
* *@brief Returns the real part of a complex number. | |||||
* | |||||
* *@par Inputs: | |||||
* *input:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *output:A Tensor. Has the same shape as input. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow Real operator. | |||||
* */ | |||||
*@brief Returns the real part of a complex number. | |||||
*@par Inputs: | |||||
*input:A Tensor. | |||||
*@par Outputs: | |||||
*output:A Tensor. Has the same shape as input. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow Real operator. | |||||
*/ | |||||
REG_OP(Real) | REG_OP(Real) | ||||
.INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | ||||
.OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE})) | .OUTPUT(output, TensorType({DT_FLOAT, DT_DOUBLE})) | ||||
@@ -582,17 +582,17 @@ REG_OP(Real) | |||||
.OP_END_FACTORY_REG(Real) | .OP_END_FACTORY_REG(Real) | ||||
/** | /** | ||||
* *@brief Returns the complex conjugate of a complex number. | |||||
* | |||||
* *@par Inputs: | |||||
* *input:A Tensor. | |||||
* | |||||
* *@par Outputs: | |||||
* *output:A Tensor. Has the same shape as input. | |||||
* | |||||
* *@par Third-party framework compatibility. | |||||
* *Compatible with tensorflow output operator. | |||||
* */ | |||||
*@brief Returns the complex conjugate of a complex number. | |||||
*@par Inputs: | |||||
*input:A Tensor. | |||||
*@par Outputs: | |||||
*output:A Tensor. Has the same shape as input. | |||||
*@par Third-party framework compatibility. | |||||
*Compatible with tensorflow Conj operator. | |||||
*/ | |||||
REG_OP(Conj) | REG_OP(Conj) | ||||
.INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | .INPUT(input, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | ||||
.OUTPUT(output, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | .OUTPUT(output, TensorType({DT_COMPLEX64, DT_COMPLEX128})) | ||||
@@ -698,15 +698,14 @@ REG_OP(IFMR) | |||||
*@par Inputs: | *@par Inputs: | ||||
*@li w:A Tensor of weights. \n | *@li w:A Tensor of weights. \n | ||||
*@li w_min:A Tensor of weights reduce_min. \n | |||||
*@li w_max:A Tensor of weights reduce_max. \n | |||||
*@par Attributes: | *@par Attributes: | ||||
*axes: specify channel. | |||||
*num_bits: the bits num used for quantize. | *num_bits: the bits num used for quantize. | ||||
*offset_flag: whether using offset. \n | *offset_flag: whether using offset. \n | ||||
*@par Outputs: | *@par Outputs: | ||||
*scale: quantization factor scale. | |||||
*offset: quantization factor offset. | |||||
*y: fake quantized weights. \n | *y: fake quantized weights. \n | ||||
*@par Third-party framework compatibility | *@par Third-party framework compatibility | ||||
@@ -715,10 +714,9 @@ REG_OP(IFMR) | |||||
REG_OP(WtsARQ) | REG_OP(WtsARQ) | ||||
.INPUT(w, TensorType({DT_FLOAT16, DT_FLOAT})) | .INPUT(w, TensorType({DT_FLOAT16, DT_FLOAT})) | ||||
.OUTPUT(scale, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OUTPUT(offset, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.INPUT(w_min, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.INPUT(w_max, TensorType({DT_FLOAT16, DT_FLOAT})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) | .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) | ||||
.ATTR(axes, ListInt, {0}) | |||||
.ATTR(num_bits, Int, 8) | .ATTR(num_bits, Int, 8) | ||||
.ATTR(offset_flag, Bool, false) | .ATTR(offset_flag, Bool, false) | ||||
.OP_END_FACTORY_REG(WtsARQ) | .OP_END_FACTORY_REG(WtsARQ) | ||||
@@ -582,103 +582,105 @@ REG_OP(Conv2DBackpropFilterD) | |||||
/** | /** | ||||
*@brief Computes a 2D convolution given 4D "x" and "filter" tensors. | *@brief Computes a 2D convolution given 4D "x" and "filter" tensors. | ||||
*@par Inputs: | *@par Inputs: | ||||
*@li x: A 4D tensor of input images. With "NHWC" format, the shape is | |||||
* [batch, in_height, in_width, in_channels]. | |||||
*@li filter: A 4D tensor of filters. Has the same type as "x". With "HWCN" | |||||
* format, the shape is [filter_height, filter_width, in_channels, | |||||
* out_channels]. | |||||
*@li bias: An optional 1D tensor. Shape is [out_channels]. | |||||
*@li offset_w: An optional 1D tensor for quantized convolution. Shape is | |||||
* [out_channels]. Not supported. | |||||
*@li x: A 4D tensor of input image. With the format "NHWC", the data is stored | |||||
* in the order of: [batch, in_height, in_width, in_channels]. | |||||
*@li filter: A 4D tensor of learnable filters. Must have the same type as "x". | |||||
* With the format "HWCN" , the data is stored in the order of: [filter_height, | |||||
* filter_width, in_channels / groups, out_channels]. | |||||
*@li bias: An optional 1D tensor of additive biases to the filter outputs. | |||||
* The data is stored in the order of: [out_channels]. | |||||
*@li offset_w: Reserved. | |||||
*\n | *\n | ||||
*\n | *\n | ||||
* Note that there is a strict data type mapping between the input and output | |||||
* tensors: | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | *@verbatim | ||||
|Tensor | x | filter | bias | offset_w | y | |||||
-----------|---------|---------|---------|----------|-------- | |||||
|Data Type | float16 | float16 | float16 | _ | float16 | |||||
| |---------|---------|---------|----------|-------- | |||||
| | float32 | float32 | float32 | _ | float32 | |||||
| |---------|---------|---------|----------|-------- | |||||
| | int8 | int8 | int32 | int8 | int32 | |||||
-----------|---------|---------|---------|----------|-------- | |||||
|Format | NCHW | NCHW | ND | ND | NCHW | |||||
| | NHWC | HWCN | | | NHWC | |||||
| Tensor | x | filter | bias | y | |||||
------------|---------|---------|---------|-------- | |||||
| Data Type | float16 | float16 | float16 | float16 | |||||
| |---------|---------|---------|-------- | |||||
| | float32 | float32 | float32 | float32 | |||||
| |---------|---------|---------|-------- | |||||
| | int8 | int8 | int32 | int32 | |||||
------------|---------|---------|---------|-------- | |||||
| Format | NCHW | NCHW | ND | NCHW | |||||
| | NHWC | HWCN | | NHWC | |||||
@endverbatim | @endverbatim | ||||
* Type float32 is allowed only in mixed precision (float32->float16) scenarios. | |||||
* Mixed precision is enabled by default. | |||||
* \n | |||||
* For float32 type, the actual calculation on the chip is based on | |||||
* float16. For int8, a dequant or requant operator must be followed. | |||||
*\n | |||||
* | * | ||||
*@par Attributes: | *@par Attributes: | ||||
*@li strides: Required. A list of 4 integers. Specifying the strides of the | |||||
* convolution along the height and width. The dimension order is determined | |||||
* by the data format of "x". By default the N and C dimensions are set to 1. | |||||
*@li pads: Required. A list of 4 integers. Specifying the top, bottom, left | |||||
* and right padding. | |||||
* @li dilations: Optional. A list of 4 integers. Specifying the dilation rate | |||||
* to use for dilated convolution. Has the same dimension order and value as | |||||
* "strides". Dilation > 1 is not supported for quantized convolution. Defaults | |||||
* to [1, 1, 1, 1]. | |||||
* @li groups: Optional. An integer of type int32, for the number of blocked | |||||
* connections from input channels to output channels. Input channels and output | |||||
* channels must both be divisible by "groups". "x" in_channels must be equal to | |||||
* "filter" in_channels * groups. Defaults to 1. | |||||
* @li offset_x: Optional. An integer of type int32, for quantized convolution. | |||||
* Defaults to 0. | |||||
* @li data_format: Reserved and optional. A string from: "NHWC" and "NCHW". | |||||
* Specifying the data format of the input and output images. Defaults to | |||||
* "NHWC". | |||||
*@li strides: Required. A list of 4 integers. The stride of the sliding window | |||||
* for each dimension of input. The dimension order is determined by the data | |||||
* format of "x". The N and C dimensions must be set to 1. | |||||
*@li pads: Required. A list of 4 integers. The number of pixels to add to each | |||||
* (top, bottom, left, right) side of the input. | |||||
*@li dilations: Optional. A list of 4 integers. The dilation factor for each | |||||
* dimension of input. The dimension order is determined by the data format of | |||||
* "x". The N and C dimensions must be set to 1. The H and W dimensions must be | |||||
* set to 1 for int8 type. Defaults to [1, 1, 1, 1]. | |||||
*@li groups: Optional. An integer of type int32. The number of blocked | |||||
* connections from input channels to output channels. In_channels and | |||||
* out_channels must both be divisible by "groups". Defaults to 1. | |||||
*@li offset_x: Optional. An integer of type int32. The negative offset added | |||||
* to the input image for int8 type. Ensure that the output is within the | |||||
* effective range. Defaults to 0. | |||||
*@li data_format: Reserved. | |||||
*\n | *\n | ||||
*\n | *\n | ||||
* The following value range restrictions must be met: | * The following value range restrictions must be met: | ||||
*@verbatim | *@verbatim | ||||
|Name | Field | Scope | |||||
------------------|----------|---------- | |||||
|Input Image Size | H | [1, 100000] | |||||
| | W | [1, 4096] | |||||
------------------|----------|---------- | |||||
|Filter Size | H | [1, 255] | |||||
| | W | [1, 255] | |||||
------------------|----------|---------- | |||||
|Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
------------------|----------|---------- | |||||
|Padding | top | [0, 255] | |||||
| | bottom | [0, 255] | |||||
| | left | [0, 255] | |||||
| | right | [0, 255] | |||||
------------------|----------|---------- | |||||
|Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
| Name | Field | Scope | |||||
-------------------|----------|-------------- | |||||
| Input Image Size | H | [1, 100000] | |||||
| | W | [1, 4096] | |||||
-------------------|----------|-------------- | |||||
| Filter Size | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
-------------------|----------|-------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
-------------------|----------|-------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
-------------------|----------|-------------- | |||||
| Offset_x | | [-128, 127] | |||||
@endverbatim | @endverbatim | ||||
*\n | |||||
* | * | ||||
*@par Outputs: | *@par Outputs: | ||||
*@li y: A 4D Tensor of output images. Has the same type and format as "x". With | |||||
* "NHWC" format, the shape is [batch, out_height, out_width, out_channels]. | |||||
*@li y: A 4D Tensor of output feature map. Has the same type as "x". With the | |||||
* format "NHWC", the data is stored in the order of: [batch, out_height, | |||||
* out_width, out_channels]. | |||||
*\n | *\n | ||||
* out_height = (in_height + top_pad + bottom_pad - | |||||
* dilation_h * (filter_height - 1) - 1) | |||||
* out_height = (in_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | * / stride_h + 1 | ||||
*\n | *\n | ||||
* out_width = (in_width + left_pad + right_pad - | |||||
* dilation_w * (filter_width - 1) - 1) | |||||
* / stride_w + 1 | |||||
* out_width = (in_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
* | * | ||||
*@attention Constraints: | *@attention Constraints: | ||||
*@li The following restrictions on the output must be met: | *@li The following restrictions on the output must be met: | ||||
*@verbatim | *@verbatim | ||||
| Output | Restrictions | |||||
-------------------|--------------------------- | |||||
| W dimension == 1 | H*W(input) == H*W(filter) | |||||
| H dimension == 1 | | |||||
-------------------|--------------------------- | |||||
| W dimension == 1 | Not supported | |||||
| H dimension != 1 | | |||||
| Output | Restrictions | |||||
----------|-------------------------------- | |||||
| H == 1 | H * W(input) == H * W(filter) | |||||
| W == 1 | | |||||
----------|-------------------------------- | |||||
| H != 1 | W(input) == W(filter) | |||||
| W == 1 | Only for Ascend310 Hi3796V300CS | |||||
@endverbatim | @endverbatim | ||||
* "H * W (input)" indicates the image size after padding and "H * W (filter)" | * "H * W (input)" indicates the image size after padding and "H * W (filter)" | ||||
* indicates the filter size after dilation. | |||||
* indicates the filter size after dilation. "W(input)" and "W(filter)" indicate | |||||
* the same rule on the W dimension. | |||||
*\n | *\n | ||||
* | * | ||||
*@par Quantization supported or not | *@par Quantization supported or not | ||||
@@ -767,106 +769,112 @@ REG_OP(Conv2DCompress) | |||||
.OP_END_FACTORY_REG(Conv2DCompress) | .OP_END_FACTORY_REG(Conv2DCompress) | ||||
/** | /** | ||||
*@brief Computes a 2D convolution given 4D "x", "filter" and "offsets" | |||||
* tensors. | |||||
*@brief Computes a 2D deformable convolution given 4D "x", "filter" and | |||||
* "offsets" tensors. | |||||
*@par Inputs: | *@par Inputs: | ||||
* @li x: A 4D tensor of input images. With shape of | |||||
* [batch, in_height, in_width, in_channels] when format is "NHWC". | |||||
* @li filter: A 4D tensor of filters. Must have the same type as "x". With | |||||
* shape of [filter_height, filter_width, in_channels, out_channels] when format | |||||
* is "HWCN". | |||||
* @li offsets: A 4D tensor of offsets. With shape of | |||||
* [batch, deformable_groups * filter_height * filter_width * 3, in_height, | |||||
* in_width] when format is "NCHW". | |||||
* @li bias: An optional 1D tensor. Shape is [out_channels]. | |||||
* | |||||
* The input and output tensor attributes are listed as follows: | |||||
* @verbatim | |||||
|Tensor | x | filter | offsets | bias | y | |||||
-----------|---------|---------|---------|----------|-------- | |||||
|Data Type | float16 | float16 | float16 | float16 | float16 | |||||
-----------|---------|---------|---------|----------|-------- | |||||
|Format | NCHW | NCHW | NCHW | ND | NCHW | |||||
| | NHWC | HWCN | | | NHWC | |||||
*@li x: A 4D tensor of input image. With the format "NHWC", the data is stored | |||||
* in the order of: [batch, in_height, in_width, in_channels]. | |||||
*@li filter: A 4D tensor of learnable filters. Must have the same type as "x". | |||||
* With the format "HWCN" , the data is stored in the order of: [filter_height, | |||||
* filter_width, in_channels / groups, out_channels]. | |||||
*@li offsets: A 4D tensor of x-y coordinates offset and mask. With the format | |||||
* "NHWC", the data is stored in the order of: [batch, out_height, out_width, | |||||
* deformable_groups * filter_height * filter_width * 3]. | |||||
*@li bias: An optional 1D tensor of additive biases to the filter outputs. | |||||
* The data is stored in the order of: [out_channels]. | |||||
*\n | |||||
*\n | |||||
* The following are the supported data types and data formats: | |||||
*@verbatim | |||||
| Tensor | x | filter | offsets | bias | y | |||||
------------|---------|---------|---------|----------|-------- | |||||
| Data Type | float16 | float16 | float16 | float16 | float16 | |||||
------------|---------|---------|---------|----------|-------- | |||||
| Format | NCHW | NCHW | NCHW | ND | NCHW | |||||
| | NHWC | HWCN | NHWC | | NHWC | |||||
@endverbatim | @endverbatim | ||||
* It should be noted that the data types must correspond to each other, but | |||||
* the format does not need to. | |||||
*\n | |||||
* | |||||
*@par Attributes: | *@par Attributes: | ||||
* @li strides: Required. A list of 4 integers. Specifying the strides of the | |||||
* convolution along the height and width. The dimension order is determined | |||||
* by the data format of "x". By default the N and C dimensions are set to 1. | |||||
* @li pads: Required. A list of 4 integers. Specifying the top, bottom, left | |||||
* and right padding. | |||||
* @li dilations: Optional. A list of 4 integers. Specifying the dilation rate | |||||
* to use for dilated convolution. Has the same dimension order and value as | |||||
* "strides". | |||||
* @li groups: Optional. Number of blocked connections from input channels to | |||||
* output channels. Input channels and output channels must both be divisible | |||||
* by "groups". Type is int32. | |||||
* @li data_format: Optional. An optional string from: "NHWC", "NCHW". Specifying the | |||||
* data format of the input and output images. Type is string. Defaults to | |||||
* "NHWC". Reserved. | |||||
* @li deformable_groups: Optional. Cut the c channel of input X into deformable_groups, | |||||
* each share a different offsets. Input channels must be divisible by | |||||
* "deformable_groups". Type is int32. | |||||
*@par Outputs: | |||||
* @li y: A 4D Tensor of output images. Must have the same type and format as | |||||
* "x". With shape of [batch, out_channels, out_height, out_width] when format | |||||
* is "NCHW". | |||||
* @li output_height = (in_height + top_pad + bottom_pad - | |||||
* dilation_h * (filter_height - 1) -1) / stride_h + 1 | |||||
* @li output_width = (in_width + left_pad + right_pad - | |||||
* dilation_w * (filter_width - 1) -1) / stride_w + 1 | |||||
*@attention | |||||
* @li The parameter scope is listed as follows: | |||||
* @verbatim | |||||
|Name | Field | Scope | |||||
------------------|--------------|---------------------------------------- | |||||
|Input Image Size | H dimension | 1 <= in_height * filter_height <= 4096 | |||||
| | W dimension | 1 <= in_width * filter_width <=4096 | |||||
------------------|--------------|---------------------------------------- | |||||
|Filter Size | H dimension | [1, 255] | |||||
| | W dimension | [1, 255] | |||||
------------------|--------------|---------------------------------------- | |||||
|offsets Size | C dimension | offsets_c = deformable_groups * | |||||
| | | filter_width * filter_height * 3 | |||||
| | H dimension | the same as output H dimension | |||||
| | W dimension | the same as output W dimension | |||||
------------------|--------------|---------------------------------------- | |||||
|Stride Size | H dimension | [1, 63] | |||||
| | W dimension | [1, 63] | |||||
------------------|--------------|---------------------------------------- | |||||
|Padding Size | top side | [0, 255] | |||||
| | bottom side | [0, 255] | |||||
| | left side | [0, 255] | |||||
| | right side | [0, 255] | |||||
------------------|--------------|---------------------------------------- | |||||
|Dilation Size | H dimension | [1, 255] | |||||
| | W dimension | [1, 255] | |||||
*@li strides: Required. A list of 4 integers. The stride of the sliding window | |||||
* for each dimension of input. The dimension order is interpreted according to | |||||
* the value of data_format. The N and C dimensions must be set to 1. | |||||
*@li pads: Required. A list of 4 integers. The number of pixels to add to each | |||||
* (top, bottom, left, right) side of the input. | |||||
*@li dilations: Optional. A list of 4 integers. The dilation factor for each | |||||
* dimension of input. The dimension order is interpreted according to the value | |||||
* of data_format. The N and C dimensions must be set to 1. Defaults to | |||||
* [1, 1, 1, 1]. | |||||
*@li groups: Optional. An integer of type int32. The number of blocked | |||||
* connections from input channels to output channels. In_channels and | |||||
* out_channels must both be divisible by "groups". Defaults to 1. | |||||
*@li data_format: Optional. An optional string from: "NHWC", "NCHW". Specify | |||||
* the data format of the input and output data. Defaults to "NHWC". | |||||
*@li deformable_groups: Optional. An integer of type int32. The number of | |||||
* deformable group partitions. In_channels must be divisible by | |||||
* "deformable_groups". Defaults to 1. | |||||
*\n | |||||
*\n | |||||
* The following value range restrictions must be met: | |||||
*@verbatim | |||||
| Name | Field | Scope | |||||
--------------------|--------|---------------------------- | |||||
| Input Image Size | H | [1, 100000 / H(filter)] | |||||
| | W | [1, 4096 / W(filter)] | |||||
--------------------|--------|---------------------------- | |||||
| Filter Size | H | [1, 255] | |||||
| | W | [1, 255] | |||||
--------------------|--------|---------------------------- | |||||
| Stride | H | [1, 63] | |||||
| | W | [1, 63] | |||||
--------------------|--------|---------------------------- | |||||
| Padding | Top | [0, 255] | |||||
| | Bottom | [0, 255] | |||||
| | Left | [0, 255] | |||||
| | Right | [0, 255] | |||||
--------------------|--------|---------------------------- | |||||
| Dilation | H | [1, 255] | |||||
| | W | [1, 255] | |||||
@endverbatim | @endverbatim | ||||
* @li There are restrictions for certain scenarios: | |||||
* @verbatim | |||||
| Output | Restrictions | |||||
-------------------|--------------------------- | |||||
| W dimension == 1 | HxW(input) == HxW(filter) | |||||
| H dimension == 1 | | |||||
-------------------|--------------------------- | |||||
| W dimension == 1 | Not supported | |||||
| H dimension != 1 | | |||||
* "W(input)" indicates the image width after padding and "W(filter)" indicates the | |||||
* filter width after dilation. | |||||
*\n | |||||
* | |||||
*@par Outputs: | |||||
*@li y: A 4D Tensor of output feature map. Has the same type as "x". With the | |||||
* format "NHWC", the data is stored in the order of: [batch, out_height, | |||||
* out_width, out_channels]. | |||||
*\n | |||||
* out_height = (in_height + pad_top + pad_bottom - | |||||
* (dilation_h * (filter_height - 1) + 1)) | |||||
* / stride_h + 1 | |||||
*\n | |||||
* out_width = (in_width + pad_left + pad_right - | |||||
* (dilation_w * (filter_width - 1) + 1)) | |||||
* / stride_w + 1 | |||||
* | |||||
*@attention Constraints: | |||||
*@li The following restrictions on the output must be met: | |||||
*@verbatim | |||||
| Output | Restrictions | |||||
----------|-------------------------------- | |||||
| H == 1 | H * W(input) == H * W(filter) | |||||
| W == 1 | | |||||
----------|-------------------------------- | |||||
| H != 1 | W(input) == W(filter) | |||||
| W == 1 | Only for Ascend310 Hi3796V300CS | |||||
@endverbatim | @endverbatim | ||||
* As shown above, "HxW(input)" indicates the image size after padding and | |||||
* "HxW(filter)" indicates the filter size after dilation. | |||||
* "H * W(input)" indicates the image size after padding and "H * W(filter)" | |||||
* indicates the filter size after dilation. "W(input)" and "W(filter)" indicate | |||||
* the same rule on the W dimension. | |||||
* | |||||
*@par Quantization supported or not | *@par Quantization supported or not | ||||
* Yes | |||||
*@li No | |||||
* | |||||
*@par Third-party framework compatibility | *@par Third-party framework compatibility | ||||
*@li Compatible with the TensorFlow operator "conv2d". | |||||
*@li Compatible with the Caffe operator 2D "Convolution". | |||||
*@li Compatible with the Mxnet operator "DeformableConvolution". | |||||
*@li Compatible with the Paddlepaddle operator "deformable_conv". | |||||
*@li Compatible with the Mmcv operator "deform_conv". | |||||
*/ | */ | ||||
REG_OP(DeformableConv2D) | REG_OP(DeformableConv2D) | ||||
.INPUT(x, TensorType({DT_FLOAT16})) | .INPUT(x, TensorType({DT_FLOAT16})) | ||||
@@ -1194,8 +1194,8 @@ REG_OP(MaxPoolGradWithArgmaxV2) | |||||
* @par Inputs: | * @par Inputs: | ||||
* One input: | * One input: | ||||
* x: An NC1HWC0 Tensor. Supported type:float16, float32, double, int8, int16, | |||||
* int32, int64, uint8, uint16, qint8 | |||||
* x: An NC1HWC0 Tensor. Supported type:float16, float32, double, int32, int64, | |||||
* uint8, int16, int8, uint16, qint8 | |||||
* @par Attributes: | * @par Attributes: | ||||
* @li ksize: A required list of int8, int16, int32, or int64 values, | * @li ksize: A required list of int8, int16, int32, or int64 values, | ||||
@@ -1206,14 +1206,14 @@ REG_OP(MaxPoolGradWithArgmaxV2) | |||||
* the input tensor. No default value. | * the input tensor. No default value. | ||||
* @li padding_mode: An optional string. Defaults to "CALCULATED". | * @li padding_mode: An optional string. Defaults to "CALCULATED". | ||||
* @li pads:A required list of int8, int16, int32, or int64 values, | * @li pads:A required list of int8, int16, int32, or int64 values, | ||||
* a data to calculate when padding_mode is "SAME" and "CALCULATED". | |||||
* a data to calculate when padding_mode is "CALCULATED". | |||||
* @li data_format: An optional string. Defaults to "NHWC" . | * @li data_format: An optional string. Defaults to "NHWC" . | ||||
* @li global_pooling bool, Whether to use the global pooling. | * @li global_pooling bool, Whether to use the global pooling. | ||||
* If global_pooling = true, kernel size and paddings will be ignored. | * If global_pooling = true, kernel size and paddings will be ignored. | ||||
* Default False | * Default False | ||||
* @li ceil_mode:global_pooling (bool) – (bool) Whether to use the global pooling. | |||||
* If global_pooling = true, kernel size and paddings will be ignored. | |||||
* Default False \n | |||||
* @li ceil_mode: Whether to use the ceil function to calculate output | |||||
* height and width. False is the default. If it is set to False, | |||||
* the floor function will be used. Default False \n | |||||
* @par Outputs: | * @par Outputs: | ||||
* y: A Tensor. Has the same type and format as input "x" . \n | * y: A Tensor. Has the same type and format as input "x" . \n | ||||
@@ -1230,8 +1230,8 @@ REG_OP(MaxPoolGradWithArgmaxV2) | |||||
* Compatible with the TensorFlow operator MaxPool. | * Compatible with the TensorFlow operator MaxPool. | ||||
*/ | */ | ||||
REG_OP(MaxPoolV3) | REG_OP(MaxPoolV3) | ||||
.INPUT(x,TensorType({DT_FLOAT16, DT_FLOAT32})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32})) | |||||
.INPUT(x,TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8, DT_UINT16, DT_QINT8})) | |||||
.OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16, DT_INT8, DT_UINT16, DT_QINT8})) | |||||
.REQUIRED_ATTR(ksize, ListInt) | .REQUIRED_ATTR(ksize, ListInt) | ||||
.REQUIRED_ATTR(strides, ListInt) | .REQUIRED_ATTR(strides, ListInt) | ||||
.ATTR(padding_mode, String, "CALCULATED") | .ATTR(padding_mode, String, "CALCULATED") | ||||
@@ -1258,14 +1258,14 @@ REG_OP(MaxPoolV3) | |||||
* the input tensor. No default value. | * the input tensor. No default value. | ||||
* @li padding_mode: A required string. Defaults to "CALCULATED". | * @li padding_mode: A required string. Defaults to "CALCULATED". | ||||
* @li pads:A required list of int8, int16, int32, or int64 values, | * @li pads:A required list of int8, int16, int32, or int64 values, | ||||
* a data to calculate when padding_mode is "SAME" and "CALCULATED". | |||||
* a data to calculate when padding_mode is "CALCULATED". | |||||
* @li data_format: An optional string. Defaults to "NHWC" . | * @li data_format: An optional string. Defaults to "NHWC" . | ||||
* @li global_pooling bool, Whether to use the global pooling. | * @li global_pooling bool, Whether to use the global pooling. | ||||
* If global_pooling = true, kernel size and paddings will be ignored. | * If global_pooling = true, kernel size and paddings will be ignored. | ||||
* Default False | * Default False | ||||
* @li ceil_mode:global_pooling (bool) – (bool) Whether to use the global pooling. | |||||
* If global_pooling = true, kernel size and paddings will be ignored. | |||||
* Default False \n | |||||
* @li ceil_mode: Whether to use the ceil function to calculate output | |||||
* height and width. False is the default. If it is set to False, | |||||
* the floor function will be used. Default False \n | |||||
* @par Outputs: | * @par Outputs: | ||||
* y: A mutable tensor. Has the same shape and type as "x1" . \n | * y: A mutable tensor. Has the same shape and type as "x1" . \n | ||||
@@ -403,6 +403,5 @@ REG_OP(EmbeddingRankId) | |||||
.ATTR(mode, String, "mod") | .ATTR(mode, String, "mod") | ||||
.OP_END_FACTORY_REG(EmbeddingRankId) | .OP_END_FACTORY_REG(EmbeddingRankId) | ||||
} // namespace ge | } // namespace ge | ||||
#endif // OPS_BUILT_IN_OP_PROTO_INC_PAD_OPS_H_ | #endif // OPS_BUILT_IN_OP_PROTO_INC_PAD_OPS_H_ |
@@ -0,0 +1,59 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
/*! | |||||
* \file target_crop_and_resize.h | |||||
* \brief | |||||
*/ | |||||
#ifndef GE_OP_TARGET_CROP_AND_RESIZE_H | |||||
#define GE_OP_TARGET_CROP_AND_RESIZE_H | |||||
#include "graph/operator_reg.h" | |||||
namespace ge { | |||||
/** | |||||
*@brief Performs crop and resize on images. | |||||
*@par Inputs: | |||||
*@li x: An NCHW tensor of type uint8, specifying the input to the data layer. | |||||
*@li boxes: Crop parameters of type int32. \n | |||||
*@li box_index: Batch index parameters of type int32. The batch of the input x to be cropped and resized. \n | |||||
*@par Attributes: | |||||
*output_h: An int, specifying the height of output. Registered with a default value of 224. \n | |||||
*output_w: An int, specifying the width of output. Registered with a default value of 224. \n | |||||
*input_format: A string, specifying the input format. Registered with a default value of "YUV420SP_U8". \n | |||||
*@par Outputs: | |||||
*y: The output tensor of type uint8. Only the NC1HWC0_C04 format is supported. | |||||
*@par Third-party framework compatibility | |||||
* It is a custom operator. It has no corresponding operator in Caffe. | |||||
* | |||||
*@par Restrictions: | |||||
*Warning: THIS FUNCTION IS EXPERIMENTAL. Please do not use. | |||||
*/ | |||||
REG_OP(TargetCropAndResize) | |||||
.INPUT(x, TensorType({DT_UINT8})) | |||||
.INPUT(boxes, TensorType({DT_INT32})) | |||||
.INPUT(box_index, TensorType({DT_INT32})) | |||||
.OUTPUT(y, TensorType({DT_UINT8})) | |||||
.ATTR(output_h, Int, 224) | |||||
.ATTR(output_w, Int, 224) | |||||
.ATTR(input_format, String, "YUV420SP_U8") | |||||
.OP_END_FACTORY_REG(TargetCropAndResize) | |||||
} | |||||
#endif //GE_OP_TARGET_CROP_AND_RESIZE_H |
@@ -193,6 +193,7 @@ enum { | |||||
TDT_HDC_SRV_TYPE_ERROR_CODE, | TDT_HDC_SRV_TYPE_ERROR_CODE, | ||||
TDT_TSD_CLT_OPEN_FAILED_CODE, | TDT_TSD_CLT_OPEN_FAILED_CODE, | ||||
TDT_TSD_CLT_CLOSE_FAILED_CODE, | TDT_TSD_CLT_CLOSE_FAILED_CODE, | ||||
TDT_TSD_CLT_UPDATE_PROFILING_FAILED_CODE, | |||||
TDT_TSD_CLT_INTERFACE_NOT_SUPPORT_CODE, | TDT_TSD_CLT_INTERFACE_NOT_SUPPORT_CODE, | ||||
TDT_SUPERVISOR_ILLEGAL_HEARTBEAT_TIME_CODE, | TDT_SUPERVISOR_ILLEGAL_HEARTBEAT_TIME_CODE, | ||||
TDT_SUPERVISOR_INOTIFY_READ_SIZE_ERROR_CODE, | TDT_SUPERVISOR_INOTIFY_READ_SIZE_ERROR_CODE, | ||||
@@ -697,6 +698,8 @@ TDT_DEF_ERROR_CODE(MODID_HDC_SERVER, TDT_ERROR, TDT_BIND_CPUCORE_FAILED, "thread | |||||
TDT_DEF_ERROR_CODE(MODID_HDC_SERVER, TDT_ERROR, TDT_HDC_SRV_CLOSED, "hdc server has been closed"); | TDT_DEF_ERROR_CODE(MODID_HDC_SERVER, TDT_ERROR, TDT_HDC_SRV_CLOSED, "hdc server has been closed"); | ||||
TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_OPEN_FAILED, "tsd client open failed"); | TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_OPEN_FAILED, "tsd client open failed"); | ||||
TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_CLOSE_FAILED, "tsd client close failed"); | TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_CLOSE_FAILED, "tsd client close failed"); | ||||
TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_UPDATE_PROFILING_FAILED, | |||||
"tsd client update profiling failed"); | |||||
TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_INTERFACE_NOT_SUPPORT, "tsd client func not support"); | TDT_DEF_ERROR_CODE(MODID_TSD_CLIENT, TDT_ERROR, TDT_TSD_CLT_INTERFACE_NOT_SUPPORT, "tsd client func not support"); | ||||
TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_FILELIST_NOT_EXIST, "tdt filelist open failed"); | TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_FILELIST_NOT_EXIST, "tdt filelist open failed"); | ||||
TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_SAMPLE_FILE_NOT_FOUND, "tdt sample file is empty"); | TDT_DEF_ERROR_CODE(MODID_TDT_PREFETCH, TDT_ERROR, TDT_PREFETCH_SAMPLE_FILE_NOT_FOUND, "tdt sample file is empty"); | ||||
@@ -49,7 +49,7 @@ extern "C" { | |||||
* @li tsd_client.h: Header file where the interface declaration is located. | * @li tsd_client.h: Header file where the interface declaration is located. | ||||
* @li data_common.h: Header file where 'TDT_StatusT' defined | * @li data_common.h: Header file where 'TDT_StatusT' defined | ||||
*/ | */ | ||||
TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t rankSize); | |||||
TDT_LIB_EXPORT TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t rankSize); | |||||
/** | /** | ||||
* @ingroup Close | * @ingroup Close | ||||
@@ -67,7 +67,25 @@ TDT_StatusT TsdOpen(const uint32_t phyDeviceId, const uint32_t rankSize); | |||||
* @li tsd_client.h: Header file where the interface declaration is located. | * @li tsd_client.h: Header file where the interface declaration is located. | ||||
* @li data_common.h: Header file where 'TDT_StatusT' defined | * @li data_common.h: Header file where 'TDT_StatusT' defined | ||||
*/ | */ | ||||
TDT_StatusT TsdClose(const uint32_t phyDeviceId); | |||||
TDT_LIB_EXPORT TDT_StatusT TsdClose(const uint32_t phyDeviceId); | |||||
/** | |||||
* @ingroup UpdateProfilingMode | |||||
* @brief notify TSDClient to update the profiling mode | |||||
* | |||||
* @par Function | |||||
* notify TSDClient to update the profiling mode | |||||
* | |||||
* @param phyDeviceId [IN] type uint32_t. Presumably the physical device ID - TODO confirm | |||||
* @param flag [IN] type uint32_t. Profiling mode flag; accepted values are not documented here - TODO confirm | |||||
* @retval TDT_OK Success | |||||
* @retval OtherValues Failure | |||||
* | |||||
* @par Dependency | |||||
* @li libtsdclient.so: Library to which the interface belongs. | |||||
* @li tsd_client.h: Header file where the interface declaration is located. | |||||
* @li data_common.h: Header file where 'TDT_StatusT' is defined | |||||
*/ | |||||
TDT_LIB_EXPORT TDT_StatusT UpdateProfilingMode(const uint32_t phyDeviceId, const uint32_t flag); | |||||
/** | /** | ||||
* @ingroup CreateCmdParameterObj | * @ingroup CreateCmdParameterObj | ||||