!175 synchronize latest Ascend software suite 27 Oct 2020

Merge pull request !175 from yanghaoran/r1.0.1
pull/175/MERGE
mindspore-ci-bot committed 4 years ago
commit 423c0228e8
100 changed files with 2830 additions and 1455 deletions
  1. +51 -0 inc/common/opskernel/ops_kernel_builder.h
  2. +2 -15 inc/common/opskernel/ops_kernel_info_store.h
  3. +3 -5 inc/common/opskernel/ops_kernel_info_types.h
  4. +0 -2 inc/common/optimizer/graph_optimizer.h
  5. +48 -0 inc/common/util/ai_core/aicore_manager/aicore_util_manager.h
  6. +9 -1 inc/common/util/ai_core/common/aicore_util_attr_define.h
  7. +54 -0 inc/common/util/ai_core/common/aicore_util_constants.h
  8. +46 -20 inc/common/util/ai_core/common/aicore_util_types.h
  9. +28 -26 inc/common/util/ai_core/common/graph_comm.h
  10. +54 -0 inc/common/util/ai_core/common/json_util.h
  11. +44 -0 inc/common/util/ai_core/common/l2_stream_info.h
  12. +4 -4 inc/common/util/ai_core/common/scope_allocator.h
  13. +5 -5 inc/common/util/ai_core/param_calculate/tensorsize_calculator.h
  14. +2 -0 inc/common/util/error_manager/error_manager.h
  15. +27 -27 inc/common/util/platform_info.h
  16. +87 -85 inc/common/util/platform_info_def.h
  17. +1 -1 inc/external/ge/ge_api_error_codes.h
  18. +21 -0 inc/external/ge/ge_ir_build.h
  19. +38 -0 inc/external/graph/ascend_string.h
  20. +0 -2 inc/external/graph/attr_value.h
  21. +1 -0 inc/external/graph/ge_error_codes.h
  22. +129 -0 inc/external/graph/gnode.h
  23. +24 -5 inc/external/graph/graph.h
  24. +10 -13 inc/external/graph/operator.h
  25. +0 -1 inc/external/graph/tensor.h
  26. +134 -0 inc/external/hccl/hccl.h
  27. +101 -0 inc/external/hccl/hccl_types.h
  28. +0 -2 inc/external/register/register.h
  29. +0 -1 inc/external/register/scope/scope_fusion_pass_register.h
  30. +1 -1 inc/framework/common/ge_inner_error_codes.h
  31. +0 -1 inc/framework/common/op/attr_value_util.h
  32. +2 -1 inc/framework/common/op/ge_op_utils.h
  33. +3 -5 inc/framework/common/string_util.h
  34. +1 -0 inc/framework/common/types.h
  35. +1 -1 inc/framework/common/util.h
  36. +1 -0 inc/framework/engine/dnnengine.h
  37. +1 -0 inc/framework/generator/ge_generator.h
  38. +14 -0 inc/framework/memory/memory_api.h
  39. +1 -1 inc/framework/memory/memory_assigner.h
  40. +0 -3 inc/framework/omg/omg.h
  41. +2 -3 inc/framework/omg/omg_inner_types.h
  42. +3 -3 inc/graph/buffer.h
  43. +0 -2 inc/graph/compute_graph.h
  44. +12 -4 inc/graph/debug/ge_attr_define.h
  45. +1 -1 inc/graph/detail/any_map.h
  46. +2 -2 inc/graph/detail/attributes_holder.h
  47. +3 -3 inc/graph/ge_attr_value.h
  48. +1 -1 inc/graph/ge_context.h
  49. +5 -0 inc/graph/ge_local_context.h
  50. +1 -1 inc/graph/node.h
  51. +0 -4 inc/graph/range_vistor.h
  52. +39 -5 inc/graph/utils/graph_utils.h
  53. +32 -0 inc/graph/utils/node_adapter.h
  54. +8 -0 inc/graph/utils/node_utils.h
  55. +14 -14 src/common/graph/ascend_string.cc
  56. +16 -17 src/common/graph/format_refiner.cc
  57. +13 -2 src/common/graph/ge_attr_define.cc
  58. +12 -17 src/common/graph/ge_attr_value.cc
  59. +857 -0 src/common/graph/gnode.cc
  60. +234 -1 src/common/graph/graph.cc
  61. +12 -3 src/common/graph/graph.mk
  62. +1 -0 src/common/graph/model.cc
  63. +2 -2 src/common/graph/model_serialize.cc
  64. +0 -1 src/common/graph/op_desc.cc
  65. +10 -12 src/common/graph/operator.cc
  66. +0 -2 src/common/graph/opsproto/opsproto_manager.cc
  67. +2 -0 src/common/graph/option/ge_context.cc
  68. +14 -0 src/common/graph/option/ge_local_context.cc
  69. +56 -35 src/common/graph/shape_refiner.cc
  70. +0 -6 src/common/graph/stub/Makefile
  71. +0 -578 src/common/graph/stub/gen_stubapi.py
  72. +10 -14 src/common/graph/tensor.cc
  73. +294 -39 src/common/graph/utils/graph_utils.cc
  74. +37 -0 src/common/graph/utils/node_utils.cc
  75. +4 -6 src/common/graph/utils/op_desc_utils.cc
  76. +20 -1 src/common/graph/utils/tuning_utils.cc
  77. +8 -1 src/ge/CMakeLists.txt
  78. +28 -33 src/ge/analyzer/analyzer.cc
  79. +8 -1 src/ge/analyzer/analyzer.h
  80. +18 -5 src/ge/client/ge_prof.cc
  81. +2 -4 src/ge/client/module.mk
  82. +20 -2 src/ge/common/auth/file_saver.cc
  83. +0 -248 src/ge/common/convert/pb2json.cc
  84. +0 -68 src/ge/common/convert/pb2json.h
  85. +1 -1 src/ge/common/dump/dump_properties.cc
  86. +1 -1 src/ge/common/dump/dump_properties.h
  87. +0 -36 src/ge/common/ge/tbe_plugin_manager.cc
  88. +0 -1 src/ge/common/ge/tbe_plugin_manager.h
  89. +4 -1 src/ge/common/ge_common.mk
  90. +0 -1 src/ge/common/helper/model_cache_helper.cc
  91. +1 -0 src/ge/common/op/attr_value_util.cc
  92. +1 -0 src/ge/common/op/ge_op_utils.cc
  93. +14 -21 src/ge/common/profiling/profiling_manager.cc
  94. +1 -0 src/ge/common/types.cc
  95. +3 -4 src/ge/common/util.cc
  96. +3 -3 src/ge/engine_manager/dnnengine_manager.cc
  97. +7 -0 src/ge/engine_manager/engine_conf.json
  98. +29 -13 src/ge/executor/ge_executor.cc
  99. +11 -4 src/ge/executor/module.mk
  100. +5 -0 src/ge/ge_inference.mk

+51 -0 inc/common/opskernel/ops_kernel_builder.h

@@ -0,0 +1,51 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_
#define INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_

#include "external/ge/ge_api_error_codes.h"
#include "cce/aicpu_engine_struct.h"
#include "common/opskernel/ops_kernel_info_types.h"
#include "graph/node.h"
#include "proto/task.pb.h"

namespace ge {
class OpsKernelBuilder {
public:
OpsKernelBuilder() = default;
virtual ~OpsKernelBuilder() = default;

// initialize OpsKernelBuilder
virtual Status Initialize(const std::map<std::string, std::string> &options) = 0;

// finalize OpsKernelBuilder
virtual Status Finalize() = 0;

// memory allocation requirement
virtual Status CalcOpRunningParam(Node &node) = 0;

// generate task for op
virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0;

// only call aicpu interface to generate task struct
virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return FAILED; }

// only call aicpu interface to generate task struct
virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return FAILED; }
};
} // namespace ge
#endif // INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_
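For context, a minimal sketch of what a concrete builder derived from this new interface could look like. The class name DemoKernelBuilder, its trivial bodies, and the include path (taken from the inc/ layout above) are illustrative and not part of this change.

#include <map>
#include <string>
#include <vector>

#include "common/opskernel/ops_kernel_builder.h"

namespace ge {
// Hypothetical subclass, shown only to illustrate the pure-virtual contract above.
class DemoKernelBuilder : public OpsKernelBuilder {
 public:
  Status Initialize(const std::map<std::string, std::string> &options) override { return SUCCESS; }
  Status Finalize() override { return SUCCESS; }
  Status CalcOpRunningParam(Node &node) override {
    // A real builder would fill in workspace and output memory sizes here.
    return SUCCESS;
  }
  Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) override {
    // A real builder would append one or more TaskDef entries for this node.
    return SUCCESS;
  }
};
}  // namespace ge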

+2 -15 inc/common/opskernel/ops_kernel_info_store.h

@@ -43,10 +43,10 @@ class OpsKernelInfoStore {
virtual ~OpsKernelInfoStore() {}

// initialize opsKernelInfoStore
virtual Status Initialize(const map<string, string> &options) = 0; /*lint -e148*/
virtual Status Initialize(const map<string, string> &options) = 0;

// close opsKernelInfoStore
virtual Status Finalize() = 0; /*lint -e148*/
virtual Status Finalize() = 0;

virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; }

@@ -65,24 +65,11 @@ class OpsKernelInfoStore {
// opsFlag opsFlag[0] indicates constant folding is supported or not
virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){};

// memory allocation requirement
virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/

// generate task for op。
virtual Status GenerateTask(const Node &node, RunContext &context,
std::vector<domi::TaskDef> &tasks) = 0; /*lint -e148*/

// only call fe engine interface to compile single op
virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
// load task for op
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; }

// only call aicpu interface to generate task struct
virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; }

// only call aicpu interface to generate task struct
virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; }
};
} // namespace ge
#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_

+3 -5 inc/common/opskernel/ops_kernel_info_types.h

@@ -26,13 +26,14 @@
using std::string;

namespace ge {
/*lint -e148*/
struct RunContext {
rtModel_t model;
rtStream_t stream;
uint64_t sessionId;
uint64_t dataMemSize;
uint8_t *dataMemBase;
std::map<int64_t, uint64_t> mem_type_data_mem_size;
std::map<int64_t, uint8_t *> mem_type_data_mem_base;
uint64_t weightMemSize;
uint8_t *weightMemBase;
ge::Buffer weightsBuffer;
@@ -41,8 +42,6 @@ struct RunContext {
std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...)
};

/*lint +e148*/

struct Task {
uint32_t id;
uint16_t type;
@@ -51,8 +50,7 @@ struct Task {
};

struct OpInfo {
string engine; // which engin
/*lint -e148*/
string engine; // which engin
string opKernelLib; // which opsKernelStore
int computeCost; // compute cost
bool flagPartial; // whether to support is related to shape


+0 -2 inc/common/optimizer/graph_optimizer.h

@@ -27,7 +27,6 @@
using std::map;
using std::string;

/*lint -e148*/
namespace ge {
class GraphOptimizer {
public:
@@ -67,5 +66,4 @@ class GraphOptimizer {
virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; }
};
} // namespace ge
/*lint +e148*/
#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_

+48 -0 inc/common/util/ai_core/aicore_manager/aicore_util_manager.h

@@ -0,0 +1,48 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef AICORE_UTIL_MANAGER_H_
#define AICORE_UTIL_MANAGER_H_

#include <string>
#include "register/graph_optimizer/graph_optimize_register_error_codes.h"

namespace fe {
class AICoreUtilManager {
public:
static AICoreUtilManager &Instance();
/*
* to initialize the aicore configuration
* param[in] the options of init
* param[in] engine Name
* param[in] socVersion soc version from ge
* return Status(SUCCESS/FAILED)
*/
Status Initialize(const std::map<std::string, std::string> &options, std::string &soc_version);

/*
* to release the source of fusion manager
* return Status(SUCCESS/FAILED)
*/
Status Finalize();

private:
AICoreUtilManager();
~AICoreUtilManager();
bool is_init_;
};
} // namespace fe
#endif // AICORE_UTIL_MANAGER_H
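A possible call sequence for the new singleton, assuming fe::SUCCESS is provided by the included error-code header; the option map contents, SoC version value, and include path are placeholders.

#include <map>
#include <string>

#include "common/util/ai_core/aicore_manager/aicore_util_manager.h"

// Sketch: initialize the AI Core utilities once during engine setup, release them at teardown.
void InitAndReleaseAICoreUtil(const std::map<std::string, std::string> &options) {
  std::string soc_version = "Ascend910A";  // placeholder; normally taken from GE options
  if (fe::AICoreUtilManager::Instance().Initialize(options, soc_version) != fe::SUCCESS) {
    return;  // handle the failure as the engine requires
  }
  (void)fe::AICoreUtilManager::Instance().Finalize();
}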

+9 -1 inc/common/util/ai_core/common/aicore_util_attr_define.h

@@ -36,6 +36,14 @@ static const std::string L1_OPTIMIZED = "l1_optimized";

static const std::string L2_OPTIMIZED = "l2_optimized";

static const std::string OP_SLICE_INFO = "_op_slice_info";
static const std::string ATTR_NAME_UNKNOWN_SHAPE = "_unknown_shape";

static const std::string ATTR_NAME_IS_UNKNOWN_GRAPH = "_fe_is_unknown_graph";

static const std::string ATTR_NAME_IS_UNKNOWN_SHAPE_OP = "_fe_is_unknown_shape_op";

static const std::string ATTR_NAME_TVM_CACHE_READ_MODE = "tvm_cache_read_mode";

static const std::string ATTR_NAME_TBE_KERNEL_SIZE = "_tbeKernelSize";
} // namespace fe
#endif

+54 -0 inc/common/util/ai_core/common/aicore_util_constants.h

@@ -0,0 +1,54 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_
#define INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_

#include <string>

namespace fe {
static const std::string CORE_TYPE = "_coretype";
/* engine name of AI core and vector core */
static const std::string AI_CORE_NAME = "AIcoreEngine";
static const std::string VECTOR_CORE_NAME = "VectorEngine";

static const int64_t IS_UNKNOWN_SHAPE_VALUE = 1;

static const int64_t SHAPE_UNKNOWN_DIM = -1;

static const int64_t SHAPE_UNKNOWN_DIM_NUM = -2;

static const std::string SOC_VERSION_ASCEND310 = "Ascend310";
static const std::string SOC_VERSION_ASCEND610 = "Ascend610";
static const std::string SOC_VERSION_ASCEND615 = "Ascend615";
static const std::string SOC_VERSION_ASCEND710 = "Ascend710";
static const std::string SOC_VERSION_ASCEND710P = "Ascend710Pro";
static const std::string SOC_VERSION_ASCEND910A = "Ascend910A";
static const std::string SOC_VERSION_ASCEND910B = "Ascend910B";
static const std::string SOC_VERSION_ASCEND910PROA = "Ascend910ProA";
static const std::string SOC_VERSION_ASCEND910PROB = "Ascend910ProB";
static const std::string SOC_VERSION_ASCEND910PREMIUMA = "Ascend910PremiumA";
static const std::string SOC_VERSION_HI3796CV300ES = "Hi3796CV300ES";
static const std::string SOC_VERSION_HI3796CV300CS = "Hi3796CV300CS";

static const std::vector<std::string> SOC_VERSION_CLOUD_LIST = {SOC_VERSION_ASCEND910A, SOC_VERSION_ASCEND910B,
SOC_VERSION_ASCEND910PROA, SOC_VERSION_ASCEND910PROB,
SOC_VERSION_ASCEND910PREMIUMA};

static const std::vector<std::string> SOC_VERSION_DC_LIST = {SOC_VERSION_ASCEND610, SOC_VERSION_ASCEND615,
SOC_VERSION_ASCEND710, SOC_VERSION_ASCEND710P};
} // namespace fe
#endif
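As a small illustration of how the new SoC version lists might be consulted; the helper name and include path are made up for this sketch.

#include <algorithm>
#include <string>
#include <vector>

#include "common/util/ai_core/common/aicore_util_constants.h"

// Sketch: true when the given version belongs to the cloud (Ascend 910) family.
static bool IsCloudSocVersion(const std::string &soc_version) {
  return std::find(fe::SOC_VERSION_CLOUD_LIST.begin(), fe::SOC_VERSION_CLOUD_LIST.end(), soc_version) !=
         fe::SOC_VERSION_CLOUD_LIST.end();
}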

+46 -20 inc/common/util/ai_core/common/aicore_util_types.h

@@ -42,47 +42,61 @@ struct FusionDataFlow {
std::pair<std::string, ge::AnchorPtr> node_dataindex_pair;
};

typedef struct tagL2FusionData {
typedef struct tag_l2_fusion_data {
uint32_t l2Index;
uint64_t l2Addr;
uint64_t l2PageNum;
} L2FusionData_t;
typedef std::map<uint64_t, L2FusionData_t> L2FusionDataMap_t;

typedef struct tagFeSmDesc {
typedef struct tag_fe_sm_desc {
rtL2Ctrl_t l2ctrl;
std::string nodeName[8];
uint8_t outputIndex[8];
} feSmDesc_t;
std::string node_name[8];
uint8_t output_index[8];
} fe_sm_desc_t;

typedef struct TagTaskL2FusionInfo {
std::string nodeName;
feSmDesc_t l2Info;
std::string node_name;
fe_sm_desc_t l2_info;
L2FusionDataMap_t input;
L2FusionDataMap_t output;
uint32_t isUsed;
uint32_t is_used;
} TaskL2FusionInfo_t;

using L2FusionInfoPtr = std::shared_ptr<TaskL2FusionInfo_t>;

typedef struct ToOpStruct {
int64_t opL1Space = 0;
std::vector<int64_t> opL1FusionType;
int64_t opL1WorkspaceFlag = 0; // for workspace flag
int64_t opL1WorkspaceSize = 0;
std::vector<std::vector<int64_t>> validInputShape;
std::vector<std::vector<int64_t>> validOutputShape;
std::vector<std::vector<int64_t>> sliceInputOffset; // conv & pooling & ReadSelect
std::vector<std::vector<int64_t>> sliceOutputOffset; // WriteSelect
std::vector<uint32_t> totalShape;
uint32_t splitIndex = 0;
int64_t op_l1_space = 0;
std::vector<int64_t> op_l1_fusion_type;
int64_t op_l1_workspace_flag = 0; // for workspace flag
int64_t op_l1_workspace_size = 0;
std::vector<std::vector<int64_t>> valid_input_shape;
std::vector<std::vector<int64_t>> valid_output_shape;
std::vector<std::vector<int64_t>> slice_input_offset; // conv & pooling & ReadSelect
std::vector<std::vector<int64_t>> slice_output_offset; // WriteSelect
std::vector<uint32_t> total_shape;
uint32_t split_index = 0;
ToOpStruct() {
// set invalid value for essential variable
opL1Space = -1;
opL1WorkspaceSize = -1;
op_l1_space = -1;
op_l1_workspace_size = -1;
}
} ToOpStruct_t;

enum SlicePattern {
ELEMENT_WISE = 0,
ELEMENT_WISE_BROADCAST,
BROADCAST,
SLIDING_WINDOW,
SLIDING_WINDOW_DECONV,
CUBE_MATMUL,
SLICE_PATTERN_REDUCE,
SLICE_PATTERN_RESIZE,
SLICE_PATTERN_SCATTER,
SLICE_PATTERN_SEGMENT,
PATTERN_RESERVED
};

enum OpImplType {
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op
EN_IMPL_CUSTOM_TIK, // custom tik op
@@ -99,6 +113,10 @@ enum OpImplType {
EN_RESERVED // reserved value
};

// Dont change the order, only add new mode in the end
enum L2Mode { EN_L2_CLOSE = 0, EN_L2_BUFFER_OPTIMIZE, EN_L2_CACHE_NORMAL, EN_L2_CACHE_RC };
enum BufferFusionMode { EN_OPTIMIZE_DISABLE = 0, EN_L2_BUFFER, EN_L2_FUSION };

static const std::map<ge::DataType, uint32_t> DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)},
{ge::DT_FLOAT16, sizeof(int16_t)},
{ge::DT_INT8, sizeof(int8_t)},
@@ -114,5 +132,13 @@ static const std::map<ge::DataType, uint32_t> DATATYPE_SIZE_MAP{{ge::DT_FLOAT, s
{ge::DT_DUAL, sizeof(float) + sizeof(int8_t)},
{ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)},
{ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}};

enum OpReduceType {
REDUCE_MEAN = 0,
REDUCE_ADD,
REDUCE_MAX,
REDUCE_MIN,
};

} // namespace fe
#endif
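A short sketch of how DATATYPE_SIZE_MAP is typically used to size a tensor; the helper function and include path are illustrative, not part of this change.

#include <cstdint>

#include "common/util/ai_core/common/aicore_util_types.h"

// Sketch: look up the per-element size for a data type and compute the byte size
// of a tensor with the given element count; returns 0 for unknown types.
static uint64_t CalcTensorByteSize(ge::DataType dtype, uint64_t element_count) {
  const auto iter = fe::DATATYPE_SIZE_MAP.find(dtype);
  return (iter == fe::DATATYPE_SIZE_MAP.end()) ? 0 : static_cast<uint64_t>(iter->second) * element_count;
}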

+28 -26 inc/common/util/ai_core/common/graph_comm.h

@@ -28,33 +28,34 @@

namespace fe {

using kScopeNodeMap_t = std::map<int64_t, std::vector<ge::NodePtr>>;
using kScopeNodePair_t = std::pair<int64_t, std::vector<ge::NodePtr>>;
using k_scope_node_map_t = std::map<int64_t, std::vector<ge::NodePtr>>;
using k_scope_node_pair_t = std::pair<int64_t, std::vector<ge::NodePtr>>;

class GraphCommImpl;
using GraphCommImplPtr = std::unique_ptr<GraphCommImpl>;

class GraphComm {
public:
GraphComm(const string &engineName);
GraphComm(const string &engine_name);
virtual ~GraphComm();
GraphComm(const GraphComm &in) = delete;
GraphComm &operator=(const GraphComm &in) = delete;

Status GetscopeNodeMap(ge::ComputeGraph &graph, kScopeNodeMap_t &fusionMap);
Status GetscopeNodeMap(ge::ComputeGraph &graph, k_scope_node_map_t &fusion_map);

Status CopyFusionOpNodes(vector<FusionDataFlow> &fusInputEdgeList, vector<FusionDataFlow> &fusOutputEdgeList,
vector<ge::NodePtr> &fusNodelist, ge::OpDescPtr fusionOpDesc,
ge::ComputeGraphPtr fusionGraph);
Status CopyFusionOpNodes(vector<FusionDataFlow> &fus_input_edge_list, vector<FusionDataFlow> &fus_output_edge_list,
vector<ge::NodePtr> &fus_nodelist, ge::OpDescPtr fusion_op_desc,
ge::ComputeGraphPtr fusion_graph);

Status CopyFusionOpEdges(ge::OpDescPtr fusionOpDesc, ge::ComputeGraph &origGraph, ge::ComputeGraphPtr fusionGraph);
Status CopyFusionOpEdges(ge::OpDescPtr fusion_op_desc, ge::ComputeGraph &orig_graph,
ge::ComputeGraphPtr fusion_graph);

Status GetNodeDataFlowMap(const ge::NodePtr &fusNode,
std::map<ge::NodePtr, std::map<ge::AnchorPtr, ge::AnchorPtr>> &fusionOpAnchorsMap,
ge::kFusionDataFlowVec_t &fusDataflowList, const int &mapType);
Status GetNodeDataFlowMap(const ge::NodePtr &fus_node,
std::map<ge::NodePtr, std::map<ge::AnchorPtr, ge::AnchorPtr>> &fusion_op_anchors_map,
ge::kFusionDataFlowVec_t &fus_dataflow_list, const int &map_type);

Status GetFusionNodeEdgeList(std::vector<ge::NodePtr> &fusNodelist, std::vector<FusionDataFlow> &fusInputEdgeList,
std::vector<FusionDataFlow> &fusOutputEdgeList);
Status GetFusionNodeEdgeList(std::vector<ge::NodePtr> &fus_nodelist, std::vector<FusionDataFlow> &fus_input_edge_list,
std::vector<FusionDataFlow> &fus_output_edge_list);
void ClearFusionSrc();

void ClearFusionDst();
@@ -72,25 +73,26 @@ class GraphComm {
bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index,
int32_t &fusion_dst_index);

Status GetFusionNodeCtrlEdgeList(vector<ge::NodePtr> &fusNodelist, vector<FusionDataFlow> &fusInputCtrlEdgeList,
vector<FusionDataFlow> &fusOutputCtrlEdgeList);
Status GetFusionNodeCtrlEdgeList(vector<ge::NodePtr> &fus_nodelist, vector<FusionDataFlow> &fus_input_ctrl_edge_list,
vector<FusionDataFlow> &fus_output_ctrl_edge_list);

Status MergeFusionNodeEdgeList(ge::NodePtr &fusNode, vector<ge::NodePtr> &fusNodelist,
vector<FusionDataFlow> &fusInputEdgeList, vector<FusionDataFlow> &fusOutputEdgeList);
Status MergeFusionNodeEdgeList(ge::NodePtr &fus_node, vector<ge::NodePtr> &fus_nodelist,
vector<FusionDataFlow> &fus_input_edge_list,
vector<FusionDataFlow> &fus_output_edge_list);

Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fusNode, vector<ge::NodePtr> &fusNodelist,
vector<FusionDataFlow> &fusInputEdgeList,
vector<FusionDataFlow> &fusOutputEdgeList);
Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fus_node, vector<ge::NodePtr> &fus_nodelist,
vector<FusionDataFlow> &fus_input_edge_list,
vector<FusionDataFlow> &fus_output_edge_list);

string GetEngineName();

private:
Status MergeFusionNodeInputEdgeList(ge::NodePtr fusNode, std::vector<ge::NodePtr> &fusNodelist,
std::vector<FusionDataFlow> &fusInputEdgeList);
Status MergeFusionNodeOutputEdgeList(ge::NodePtr fusNode, std::vector<ge::NodePtr> &fusNodelist,
std::vector<FusionDataFlow> &fusOutputEdgeList);
Status MergeFusionNodeInputEdgeList(ge::NodePtr fus_node, std::vector<ge::NodePtr> &fus_nodelist,
std::vector<FusionDataFlow> &fus_input_edge_list);
Status MergeFusionNodeOutputEdgeList(ge::NodePtr fus_node, std::vector<ge::NodePtr> &fus_nodelist,
std::vector<FusionDataFlow> &fus_output_edge_list);

string engineName_;
string engine_name_;

std::vector<FusionOpSrc> exist_fusion_src_list_;
std::vector<FusionOpDst> exist_fusion_dst_list_;
@@ -101,7 +103,7 @@ class GraphComm {
// std::vector<std::multimap<std::string, ge::AnchorPtr>>
ge::kFusionDataFlowVec_t fusion_output_dataflow_list_;

GraphCommImplPtr graphCommImplPtr_;
GraphCommImplPtr graph_comm_impl_ptr_;
};
} // namespace fe
#endif

+54 -0 inc/common/util/ai_core/common/json_util.h

@@ -0,0 +1,54 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PROJECT_JSON_UTIL_H
#define PROJECT_JSON_UTIL_H

#include "graph/compute_graph.h"

#include "common/aicore_util_types.h"
#include "fusion_engine/graph_tuner/graph_tuner_errorcode.h"

const std::string L1_FUSION_EXTEND_CONTENT = "_l1_fusion_extend_content";
const std::string L2_FUSION_EXTEND_CONTENT = "l2_fusion_extend_content";
const std::string TASK_L2_FUSION_INFO_EXTEND_CONTENT = "task_l2_fusion_info_extend_content";
const std::string L1_FUSION_TO_OP_STRUCT = "_l1fusion_ToOpStruct";
const std::string L2_FUSION_TO_OP_STRUCT = "_l2fusion_ToOpStruct";
const std::string TASK_L2_FUSION_INFO = "_task_L2FusionInfo";

namespace tune {
using ToOpStructPtr = std::shared_ptr<fe::ToOpStruct_t>;
using L2FusionInfoPtr = std::shared_ptr<fe::TaskL2FusionInfo_t>;

Status GetL1InfoFromJson(ge::OpDescPtr opDescPtr);

Status GetL2InfoFromJson(ge::OpDescPtr opDescPtr);

Status GetTaskL2FusionInfoFromJson(ge::OpDescPtr opDescPtr);

Status ReadGraphInfoFromJson(ge::ComputeGraph &graph);

Status WriteGraphInfoToJson(ge::ComputeGraph &graph);

void GetL2ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l2InfoPtr);

void GetL1ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l1InfoPtr);

L2FusionInfoPtr GetL2FusionInfoFromJson(ge::OpDescPtr &opDescPtr);

void SetL2FusionInfoToNode(ge::OpDescPtr &opDescPtr, L2FusionInfoPtr &l2FusionInfoPtr);
} // namespace tune
#endif // PROJECT_JSON_UTIL_H

+44 -0 inc/common/util/ai_core/common/l2_stream_info.h

@@ -0,0 +1,44 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef L2_STREAM_INFO_H_
#define L2_STREAM_INFO_H_

#include <map>
#include <string>
#include <mutex>
#include "register/graph_optimizer/graph_optimize_register_error_codes.h"
#include "runtime/base.h"
#include "cce/l2fusion_struct.hpp"

namespace fe {
class StreamL2Info {
public:
StreamL2Info(const StreamL2Info &) = delete;
StreamL2Info &operator=(const StreamL2Info &) = delete;
static StreamL2Info &Instance();
Status GetStreamL2Info(rtStream_t stream_id, string node_name, fusion::TaskL2Info_t *&l2_data);
Status SetStreamL2Info(const rtStream_t &stream_id, fusion::TaskL2InfoFEMap_t &l2_alloc_res);

private:
StreamL2Info();
~StreamL2Info();
mutable std::mutex stream_l2_mutex_;
std::map<rtStream_t, fusion::TaskL2InfoFEMap_t> stream_l2_map_;
};
} // namespace fe

#endif // L2_STREAM_INFO_H_
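An illustrative query against the new per-stream L2 bookkeeping, assuming fe::SUCCESS comes from the included error-code header; the caller, node name, and include path are placeholders.

#include <string>

#include "common/util/ai_core/common/l2_stream_info.h"

// Sketch: fetch the L2 allocation recorded for a node on a given stream.
void QueryStreamL2(rtStream_t stream, const std::string &node_name) {
  fusion::TaskL2Info_t *l2_data = nullptr;
  if (fe::StreamL2Info::Instance().GetStreamL2Info(stream, node_name, l2_data) == fe::SUCCESS &&
      l2_data != nullptr) {
    // use l2_data here
  }
}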

+4 -4 inc/common/util/ai_core/common/scope_allocator.h

@@ -32,12 +32,12 @@ class ScopeAllocator {
int64_t GetCurrentScopeId();
int64_t AllocateScopeId(void);
bool HasScopeAttr(ge::ConstOpDescPtr opdef);
bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId);
bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId);
bool ResetScopeId(int64_t scopeId);
bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scope_id);
bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scope_id);
bool ResetScopeId(int64_t scope_id);

private:
int64_t scopeId;
int64_t scope_id;
};
} // namespace fe
#endif

+5 -5 inc/common/util/ai_core/param_calculate/tensorsize_calculator.h

@@ -29,16 +29,16 @@ class TensorSizeCalculator {
public:
/**
* Calculate the tensor size of input and output of each opdesc
* @param opDesc opdesc object
* @param opImplType op impl type
* @param op_desc opdesc object
* @param op_impl_type op impl type
* @return status SUCCESS or FAILED
*/
static Status CalculateOpTensorSize(ge::OpDesc &opDesc);
static Status CalculateOpTensorSize(ge::OpDesc &op_desc);

private:
static Status CalcInputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag);
static Status CalcInputOpTensorSize(ge::OpDesc &op_desc, int32_t &output_real_calc_flag);

static Status CalcOutputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag);
static Status CalcOutputOpTensorSize(ge::OpDesc &op_desc, int32_t &output_real_calc_flag);
};
} // namespace fe



+2 -0 inc/common/util/error_manager/error_manager.h

@@ -20,6 +20,7 @@
#include <map>
#include <string>
#include <vector>
#include <mutex>

class ErrorManager {
public:
@@ -86,6 +87,7 @@ class ErrorManager {
int ReadJsonFile(const std::string &file_path, void *handle);

bool is_init_ = false;
std::mutex mutex_;
std::map<std::string, ErrorInfo> error_map_;
std::vector<std::string> error_messages_;
std::vector<std::string> warning_messages_;


+27 -27 inc/common/util/platform_info.h

@@ -36,66 +36,66 @@ class PlatformInfoManager {
uint32_t InitializePlatformInfo();
uint32_t Finalize();

uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);
uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platform_info, OptionalInfo &opti_compilation_info);

uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);
uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info, OptionalInfo &opti_compilation_info);

void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo);
void SetOptionalCompilationInfo(OptionalInfo &opti_compilation_info);

private:
PlatformInfoManager();
~PlatformInfoManager();

uint32_t LoadIniFile(string iniFileRealPath);
uint32_t LoadIniFile(string ini_file_real_path);

void Trim(string &str);

uint32_t LoadConfigFile(string realPath);
uint32_t LoadConfigFile(string real_path);

string RealPath(const std::string &path);

string GetSoFilePath();

void ParseVersion(map<string, string> &versionMap, string &socVersion, PlatformInfo &platformInfoTemp);
void ParseVersion(map<string, string> &version_map, string &soc_version, PlatformInfo &platform_info_temp);

void ParseSocInfo(map<string, string> &socInfoMap, PlatformInfo &platformInfoTemp);
void ParseSocInfo(map<string, string> &soc_info_map, PlatformInfo &platform_info_temp);

void ParseCubeOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseCubeOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp);

void ParseBufferOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseBufferOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp);

void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseUBOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp);

void ParseUnzipOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseUnzipOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp);

void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp);

void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
void ParseBufferOfAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp);

void ParseAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
void ParseAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp);

void ParseUBOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
void ParseUBOfAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp);

void ParseAICoreintrinsicDtypeMap(map<string, string> &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp);
void ParseAICoreintrinsicDtypeMap(map<string, string> &ai_coreintrinsic_dtype_map, PlatformInfo &platform_info_temp);

void ParseVectorCoreSpec(map<string, string> &vectorCoreSpecMap, PlatformInfo &platformInfoTemp);
void ParseVectorCoreSpec(map<string, string> &vector_core_spec_map, PlatformInfo &platform_info_temp);

void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
void ParseVectorCoreMemoryRates(map<string, string> &vector_core_memory_rates_map, PlatformInfo &platform_info_temp);

void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platformInfoTemp);
void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platform_info_temp);

void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap,
PlatformInfo &platformInfoTemp);
void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vector_coreintrinsic_dtype_map,
PlatformInfo &platform_info_temp);

uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &contentInfoMap, string &socVersion,
PlatformInfo &platformInfoTemp);
uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &content_info_map, string &soc_version,
PlatformInfo &platform_info_temp);

uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &contentInfoMap);
uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &content_info_map);

private:
bool initFlag_;
map<string, PlatformInfo> platformInfoMap_;
OptionalInfo optiCompilationInfo_;
bool init_flag_;
map<string, PlatformInfo> platform_info_map_;
OptionalInfo opti_compilation_info_;
};
} // namespace fe
#endif

+87 -85 inc/common/util/platform_info_def.h

@@ -30,111 +30,113 @@ enum MemoryType { DDR = 0, HBM };

enum L2Type { Cache = 0, Buff };

typedef struct tagStrInfo {
string aicVersion;
string ccecAICVersion;
string ccecAIVVersion;
string isSupportAIcpuCompiler;
typedef struct tag_str_info {
string aic_version;
string ccec_aic_version;
string ccec_aiv_version;
string is_support_ai_cpu_compiler;
} StrInfo;

typedef struct tagSoCInfo {
uint32_t aiCoreCnt;
uint32_t vectorCoreCnt;
uint32_t aiCpuCnt;
MemoryType memoryType;
uint64_t memorySize;
L2Type l2Type;
uint64_t l2Size;
typedef struct tag_so_c_info {
uint32_t ai_core_cnt;
uint32_t vector_core_cnt;
uint32_t ai_cpu_cnt;
MemoryType memory_type;
uint64_t memory_size;
L2Type l2_type;
uint64_t l2_size;
uint32_t l2PageNum;
} SoCInfo;

typedef struct tagAiCoreSpec {
double cubeFreq;
uint64_t cubeMSize;
uint64_t cubeNSize;
uint64_t cubeKSize;
uint64_t vecCalcSize;
uint64_t l0ASize;
uint64_t l0BSize;
uint64_t l0CSize;
uint64_t l1Size;
uint64_t smaskBuffer;
uint64_t ubSize;
uint64_t ubblockSize;
uint64_t ubbankSize;
uint64_t ubbankNum;
uint64_t ubburstInOneBlock;
uint64_t ubbankGroupNum;
uint32_t unzipEngines;
uint32_t unzipMaxRatios;
uint32_t unzipChannels;
uint8_t unzipIsTight;
typedef struct tag_ai_core_spec {
double cube_freq;
uint64_t cube_m_size;
uint64_t cube_n_size;
uint64_t cube_k_size;
uint64_t vec_calc_size;
uint64_t l0_a_size;
uint64_t l0_b_size;
uint64_t l0_c_size;
uint64_t l1_size;
uint64_t smask_buffer;
uint64_t ub_size;
uint64_t ubblock_size;
uint64_t ubbank_size;
uint64_t ubbank_num;
uint64_t ubburst_in_one_block;
uint64_t ubbank_group_num;
uint32_t unzip_engines;
uint32_t unzip_max_ratios;
uint32_t unzip_channels;
uint8_t unzip_is_tight;
uint8_t cube_vector_split;
} AiCoreSpec;

typedef struct tagAiCoreMemoryRates {
double ddrRate;
double ddrReadRate;
double ddrWriteRate;
double l2Rate;
double l2ReadRate;
double l2WriteRate;
double l1ToL0ARate;
double l1ToL0BRate;
double l1ToUBRate;
double l0CToUBRate;
double ubToL2Rate;
double ubToDdrRate;
double ubToL1Rate;
typedef struct tag_ai_core_memory_rates {
double ddr_rate;
double ddr_read_rate;
double ddr_write_rate;
double l2_rate;
double l2_read_rate;
double l2_write_rate;
double l1_to_l0_a_rate;
double l1_to_l0_b_rate;
double l1_to_ub_rate;
double l0_c_to_ub_rate;
double ub_to_l2_rate;
double ub_to_ddr_rate;
double ub_to_l1_rate;
} AiCoreMemoryRates;

typedef struct tagVectorCoreSpec {
double vecFreq;
uint64_t vecCalcSize;
uint64_t smaskBuffer;
uint64_t ubSize;
uint64_t ubblockSize;
uint64_t ubbankSize;
uint64_t ubbankNum;
uint64_t ubburstInOneBlock;
uint64_t ubbankGroupNum;
uint64_t vectorRegSize;
uint64_t predicateRegSize;
uint64_t addressRegSize;
typedef struct tag_vector_core_spec {
double vec_freq;
uint64_t vec_calc_size;
uint64_t smask_buffer;
uint64_t ub_size;
uint64_t ubblock_size;
uint64_t ubbank_size;
uint64_t ubbank_num;
uint64_t ubburst_in_one_block;
uint64_t ubbank_group_num;
uint64_t vector_reg_size;
uint64_t predicate_reg_size;
uint64_t address_reg_size;
uint64_t alignment_reg_size;
} VectorCoreSpec;

typedef struct tagVectorCoreMemoryRates {
double ddrRate;
double ddrReadRate;
double ddrWriteRate;
double l2Rate;
double l2ReadRate;
double l2WriteRate;
double ubToL2Rate;
double ubToDdrRate;
typedef struct tag_vector_core_memory_rates {
double ddr_rate;
double ddr_read_rate;
double ddr_write_rate;
double l2_rate;
double l2_read_rate;
double l2_write_rate;
double ub_to_l2_rate;
double ub_to_ddr_rate;
} VectorCoreMemoryRates;

typedef struct tagCPUCache {
typedef struct tag_cpu_cache {
uint32_t AICPUSyncBySW;
uint32_t TSCPUSyncBySW;
} CPUCache;

typedef struct tagPlatformInfo {
StrInfo strInfo;
SoCInfo socInfo;
AiCoreSpec aiCoreSpec;
AiCoreMemoryRates aiCoreMemoryRates;
map<string, vector<string>> aiCoreIntrinsicDtypeMap;
VectorCoreSpec vectorCoreSpec;
VectorCoreMemoryRates vectorCoreMemoryRates;
typedef struct tag_platform_info {
StrInfo str_info;
SoCInfo soc_info;
AiCoreSpec ai_core_spec;
AiCoreMemoryRates ai_core_memory_rates;
map<string, vector<string>> ai_core_intrinsic_dtype_map;
VectorCoreSpec vector_core_spec;
VectorCoreMemoryRates vector_core_memory_rates;
CPUCache cpucache;
map<string, vector<string>> vectorCoreIntrinsicDtypeMap;
map<string, vector<string>> vector_core_intrinsic_dtype_map;
} PlatformInfo;

typedef struct tagOptionalInfo {
string socVersion;
string coreType;
uint32_t aiCoreNum;
string l1FusionFlag;
typedef struct tag_optional_info {
string soc_version;
string core_type;
uint32_t ai_core_num;
string l1_fusion_flag;
} OptionalInfo;
} // namespace fe
#endif

+1 -1 inc/external/ge/ge_api_error_codes.h

@@ -70,7 +70,7 @@ using Status = uint32_t;

// General error code
GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success");
GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/
GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed");
} // namespace ge

#endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_

+21 -0 inc/external/ge/ge_ir_build.h

@@ -89,5 +89,26 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m
*/
graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version);

/**
* @ingroup AscendCL
* @brief infer shape and data type
*
* @param graph[IN] the graph ready to build
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
graphStatus aclgrphInferShapeAndType(ge::Graph &graph);

/**
* @ingroup AscendCL
* @brief dump graph
*
* @param graph[IN] the graph ready to build
* @param file[IN] file path
* @param file[IN] file path string len
* @retval GRAPH_SUCCESS The function is successfully executed.
* @retval OtherValues Failure
*/
graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len);
}; // namespace ge
#endif
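A brief sketch of how the two new calls might be used together; the dump file name is a placeholder and error handling is reduced to returning the status.

#include <cstring>

#include "ge/ge_ir_build.h"
#include "graph/graph.h"

// Sketch: run shape/dtype inference on a graph, then dump it for inspection.
ge::graphStatus InferAndDump(ge::Graph &graph) {
  ge::graphStatus ret = ge::aclgrphInferShapeAndType(graph);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  const char *file = "./built_graph";  // placeholder path
  return ge::aclgrphDumpGraph(graph, file, std::strlen(file));
}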

+38 -0 inc/external/graph/ascend_string.h

@@ -0,0 +1,38 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_ASCEND_STRING_H_
#define INC_EXTERNAL_GRAPH_ASCEND_STRING_H_

#include <string>
#include <memory>

namespace ge {
class AscendString {
public:
AscendString() = default;

~AscendString() = default;

explicit AscendString(const char* name);

const char* GetString() const;

private:
std::shared_ptr<std::string> name_;
};
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_ASCEND_STRING_H_
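A minimal sketch of the new string wrapper in use; the operator name literal is illustrative.

#include "graph/ascend_string.h"

// Sketch: AscendString wraps a C string behind an ABI-stable interface.
void AscendStringDemo() {
  ge::AscendString op_name("my_conv");    // placeholder name
  const char *raw = op_name.GetString();  // read back the stored characters
  (void)raw;
}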

+0 -2 inc/external/graph/attr_value.h

@@ -34,7 +34,6 @@ using std::vector;

namespace ge {
class AttrValueImpl;
/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue {
public:
using INT = int64_t;
@@ -70,6 +69,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue {
VALUE_SET_GET_DEC(AttrValue::FLOAT)
#undef VALUE_SET_GET_DEC
};
/*lint +e148*/
} // namespace ge
#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_

+1 -0 inc/external/graph/ge_error_codes.h

@@ -33,6 +33,7 @@ using graphStatus = uint32_t;
const graphStatus GRAPH_FAILED = 0xFFFFFFFF;
const graphStatus GRAPH_SUCCESS = 0;
const graphStatus GRAPH_PARAM_INVALID = 50331649;
const graphStatus GRAPH_NODE_WITHOUT_CONST_INPUT = 50331648;
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_

+129 -0 inc/external/graph/gnode.h

@@ -0,0 +1,129 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_EXTERNAL_GRAPH_NODE_H_
#define INC_EXTERNAL_GRAPH_NODE_H_

#include <vector>
#include <cstdint>

#include "./ge_error_codes.h"
#include "./types.h"
#include "./tensor.h"
#include "./ascend_string.h"

namespace ge {
class AttrValue;
class GNode;
class OpDesc;
class Graph;
class ComputeGraph;
using GNodePtr = std::shared_ptr<GNode>;
using GraphPtr = std::shared_ptr<Graph>;
using OpBytes = std::vector<uint8_t>;
using OpDescPtr = std::shared_ptr<OpDesc>;
using ComputeGraphPtr = std::shared_ptr<ComputeGraph>;

class NodeImpl;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode {
public:
GNode();

~GNode() = default;

graphStatus GetType(ge::AscendString &type) const;

graphStatus GetName(ge::AscendString &name) const;

std::pair<GNodePtr, int32_t> GetInDataNodesAndPortIndexs(const int32_t index) const;

std::vector<GNodePtr> GetInControlNodes() const;

std::vector<std::pair<GNodePtr, int32_t>> GetOutDataNodesAndPortIndexs(const int32_t index) const;

std::vector<GNodePtr> GetOutControlNodes() const;

graphStatus GetInputConstData(const int32_t index, Tensor &data) const;

graphStatus GetInputIndexByName(const ge::AscendString &name, int32_t &index);

graphStatus GetOutputIndexByName(const ge::AscendString &name, int32_t &index);

size_t GetInputsSize() const;

size_t GetOutputsSize() const;

graphStatus GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const;

graphStatus UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc);

graphStatus GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const;

graphStatus UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc);

graphStatus GetAttr(const ge::AscendString &name, int64_t &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, int32_t &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, uint32_t &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, float &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, bool &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, Tensor &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<int64_t> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<int32_t> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<uint32_t> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<float> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<bool> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<Tensor> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, OpBytes &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<std::vector<int64_t>> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, std::vector<ge::DataType> &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, ge::DataType &attr_value) const;
graphStatus GetAttr(const ge::AscendString &name, AttrValue &attr_value) const;

graphStatus SetAttr(const ge::AscendString &name, int64_t &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, int32_t &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, uint32_t &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, float &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, bool &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, Tensor &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<int64_t> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<int32_t> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<uint32_t> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<float> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<bool> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<Tensor> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, OpBytes &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<std::vector<int64_t>> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, std::vector<ge::DataType> &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, ge::DataType &attr_value) const;
graphStatus SetAttr(const ge::AscendString &name, AttrValue &attr_value) const;

bool HasAttr(const ge::AscendString &name);

graphStatus GetSubgraph(uint32_t index, GraphPtr graph) const;

graphStatus GetALLSubgraphs(std::vector<GraphPtr> graph_list) const;

private:
std::shared_ptr<NodeImpl> impl_;
friend class NodeAdapter;
};
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_NODE_H_
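A small sketch of reading node information and one attribute through the new GNode class; the attribute name mirrors _fe_is_unknown_shape_op defined earlier in this change, and the function itself is illustrative.

#include "graph/gnode.h"

// Sketch: query a node's type, name and a bool attribute.
void InspectNode(const ge::GNode &node) {
  ge::AscendString type;
  ge::AscendString name;
  (void)node.GetType(type);
  (void)node.GetName(name);

  bool unknown_shape_op = false;
  if (node.GetAttr(ge::AscendString("_fe_is_unknown_shape_op"), unknown_shape_op) == ge::GRAPH_SUCCESS) {
    // the attribute exists and was read into unknown_shape_op
  }
}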

+24 -5 inc/external/graph/graph.h

@@ -23,11 +23,14 @@
#include <vector>

#include "./operator.h"
#include "./gnode.h"

namespace ge {
class Graph;
class GraphImpl;

using GraphImplPtr = std::shared_ptr<GraphImpl>;
using GraphPtr = std::shared_ptr<Graph>;

class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph {
friend class GraphUtils;
@@ -53,15 +56,15 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph {

graphStatus AddOp(const ge::Operator &op);

graphStatus FindOpByName(const string &name, ge::Operator &op) const;
graphStatus FindOpByName(const std::string &name, ge::Operator &op) const;

graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const;
graphStatus FindOpByType(const std::string &type, std::vector<ge::Operator> &ops) const;

graphStatus GetAllOpName(std::vector<string> &op_name) const;
graphStatus GetAllOpName(std::vector<std::string> &op_name) const;

graphStatus SaveToFile(const string &file_name) const;
graphStatus SaveToFile(const std::string &file_name) const;

graphStatus LoadFromFile(const string &file_name);
graphStatus LoadFromFile(const std::string &file_name);

const std::string &GetName() const;

@@ -73,6 +76,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph {
///
void SetNeedIteration(bool need_iteration);

std::vector<GNode> GetAllNodes() const;

std::vector<GNode> GetDirectNode() const;

graphStatus RemoveNode(GNode &node);

graphStatus RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index);

GNode AddNodeByOp(const Operator &op);

graphStatus AddDataEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index);

graphStatus AddControlEdge(GNode &src_node, GNode &dst_node);

static GraphPtr ConstructFromInputs(const std::vector<Operator> &inputs, const ge::AscendString &name);

private:
GraphImplPtr impl_{nullptr};
};
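A sketch of the new node-level editing calls on Graph; the operator, port indexes, and the lack of status checks are for illustration only.

#include <vector>

#include "graph/graph.h"

// Sketch: add a node built from an Operator and wire it to an existing node.
void EditGraph(ge::Graph &graph, const ge::Operator &op) {
  ge::GNode new_node = graph.AddNodeByOp(op);
  std::vector<ge::GNode> nodes = graph.GetDirectNode();
  if (!nodes.empty()) {
    (void)graph.AddDataEdge(nodes[0], 0, new_node, 0);  // out port 0 -> in port 0
    (void)graph.AddControlEdge(nodes[0], new_node);
  }
}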


+10 -13 inc/external/graph/operator.h

@@ -63,7 +63,6 @@ using std::function;
using std::shared_ptr;
using std::string;

/*lint -e148*/
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
public:
friend class OperatorImpl;
@@ -91,7 +90,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

explicit Operator(const string &type);

Operator(const string &name, const string &type); // lint !e148
Operator(const string &name, const string &type);

virtual ~Operator() = default;

@@ -104,7 +103,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
// Only has one output index = 0
Operator &SetInput(const string &dst_name, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148
Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name);

Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index);

@@ -128,22 +127,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

TensorDesc GetOutputDesc(uint32_t index) const;

graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148
graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc);

TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148
graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc);

TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const;

graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148
graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc);

graphStatus InferShapeAndType(); // lint !e148
graphStatus InferShapeAndType();

void SetInferenceContext(const InferenceContextPtr &inference_context);
InferenceContextPtr GetInferenceContext() const;

graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148
graphStatus VerifyAllAttr(bool disable_common_verifier = false);

size_t GetInputsSize() const;

@@ -256,20 +255,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

void RequiredAttrRegister(const string &name);

graphStatus VerifyAll(); // lint !e148
graphStatus VerifyAll();

// Only has one output index = 0
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt);

Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt,
const string &name); // lint !e148
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name);

void SubgraphRegister(const string &ir_name, bool dynamic);
void SubgraphCountRegister(const string &ir_name, uint32_t count);
void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder);

private:
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148
Operator &SetInput(const string &dst_name, const OutHandler &out_handler);

OutHandler GetOutput(const string &name) const;

@@ -283,7 +281,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {

std::shared_ptr<const Node> GetNode() const;
};
/*lint +e148*/
} // namespace ge

#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_

+0 -1 inc/external/graph/tensor.h

@@ -126,6 +126,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor {
friend class TensorAdapter;
};
} // namespace ge
/*lint +e148*/

#endif // INC_EXTERNAL_GRAPH_TENSOR_H_

+134 -0 inc/external/hccl/hccl.h

@@ -0,0 +1,134 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* @file hccl.h
* @brief HCCL API
*/

#ifndef HCCL_H_
#define HCCL_H_

#include <hccl/hccl_types.h>
#include <acl/acl.h>

#ifdef __cplusplus
extern "C" {
#endif // __cplusplus

/**
* @brief Initialize HCCL.
*
* @param clusterInfo A string identifying the cluster info file path, include file name.
* @param rank A integer identifying the identify for the rank.
* @param comm A pointer identifying the initialized communication resource.
* @return HcclResult
* @see HcclCommDestroy()
*/
extern HcclResult HcclCommInitClusterInfo(const char *clusterInfo, uint32_t rank, HcclComm *comm);

/**
* @brief Get hccl root info.
*
* @param rootInfo A pointer identifying the hccl root info.
* @return HcclResult
*/
extern HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo);

/**
* @brief Initialize HCCL with root info.
*
* @param nRanks A integer identifying the rank size of the cluster.
* @param rootInfo A struct identifying the hccl root info.
* @param rank A integer identifying the identify for the rank.
* @param comm A pointer identifying the initialized communication resource.
* @return HcclResult
* @see HcclCommDestroy()
*/
extern HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm);

/**
* @brief AllReduce operator.
*
* @param sendBuf A pointer identifying the input data address of the operator.
* @param recvBuf A pointer identifying the output data address of the operator.
* @param count An integer(u64) identifying the number of the output data.
* @param dataType The data type of the operator, must be one of the following types: int8, int16, int32, float16,
* float32.
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod.
* @param comm A pointer identifying the communication resource based on.
* @param stream A pointer identifying the stream information.
* @return HcclResult
*/
extern HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op,
HcclComm comm, aclrtStream stream);

/**
* @brief Broadcast operator.
*
* @param buf A pointer identifying the data address of the operator.
* @param count An integer(u64) identifying the number of the data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param root An integer(u32) identifying the the root rank in the operator.
* @param comm A pointer identifying the communication resource based on
* @param stream A pointer identifying the stream information.
* @return HcclResult
*/
extern HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm,
aclrtStream stream);

/**
* @brief ReduceScatter operator.
*
* @param sendBuf A pointer identifying the input data address of the operator.
* @param recvBuf A pointer identifying the output data address of the operator.
* @param recvCount An integer(u64) identifying the number of the output data.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
* @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod.
* @param comm A pointer identifying the communication resource based on.
* @param stream A pointer identifying the stream information.
* @return HcclResult
*/
extern HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType,
HcclReduceOp op, HcclComm comm, aclrtStream stream);

/**
* @brief AllGather operator.
*
* @param sendBuf A pointer identifying the input data address of the operator.
* @param recvBuf A pointer identifying the output data address of the operator.
 * @param sendCount An integer(u64) identifying the number of input data elements.
* @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32.
 * @param comm A pointer identifying the communication resource the operator runs on.
* @param stream A pointer identifying the stream information.
* @return HcclResult
*/
extern HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm,
aclrtStream stream);

/**
 * @brief Destroy the HCCL communicator.
*
 * @param comm A pointer identifying the communication resource to be destroyed.
* @return HcclResult
* @see HcclCommInitClusterInfo()
*/
extern HcclResult HcclCommDestroy(HcclComm comm);

#ifdef __cplusplus
}
#endif // __cplusplus
#endif // HCCL_H_
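
Taken together, the interfaces above form a short life cycle: create a communicator (from a cluster-info file via HcclCommInitClusterInfo, or from root info via HcclGetRootInfo/HcclCommInitRootInfo), queue collective operators on an ACL stream, then destroy the communicator. The sketch below is illustrative only and not part of the header; the ACL runtime calls it relies on (aclInit, aclrtSetDevice, aclrtCreateStream, aclrtMalloc, aclrtSynchronizeStream) are assumed to be available from acl/acl.h, and the rank-table path, rank id and device id are placeholders normally supplied by the job launcher.

#include <hccl/hccl.h>
#include <acl/acl.h>

// Minimal sketch: one rank performing a sum all-reduce over 1024 float32 elements.
// Error handling is reduced to early returns; 0 is treated as the ACL success code.
int RunAllReduceSketch(const char *rank_table_path, uint32_t rank_id, int32_t device_id) {
  if (aclInit(nullptr) != 0 || aclrtSetDevice(device_id) != 0) {
    return -1;
  }

  HcclComm comm = nullptr;
  if (HcclCommInitClusterInfo(rank_table_path, rank_id, &comm) != HCCL_SUCCESS) {
    return -1;
  }

  aclrtStream stream = nullptr;
  (void)aclrtCreateStream(&stream);

  const uint64_t count = 1024;  // number of float32 elements per rank
  void *send_buf = nullptr;
  void *recv_buf = nullptr;
  (void)aclrtMalloc(&send_buf, count * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);
  (void)aclrtMalloc(&recv_buf, count * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST);

  // Collectives are queued asynchronously on the stream and must be synchronized.
  HcclResult ret = HcclAllReduce(send_buf, recv_buf, count, HCCL_DATA_TYPE_FP32,
                                 HCCL_REDUCE_SUM, comm, stream);
  (void)aclrtSynchronizeStream(stream);

  (void)HcclCommDestroy(comm);
  return (ret == HCCL_SUCCESS) ? 0 : -1;
}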

+ 101
- 0
inc/external/hccl/hccl_types.h View File

@@ -0,0 +1,101 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file hccl_types.h
* @brief HCCL data type definition
*
*/
#ifndef HCCL_TYPES_H_
#define HCCL_TYPES_H_
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
/**
* @brief HCCL functions return value definition
*/
typedef enum {
HCCL_SUCCESS = 0, /**< success */
HCCL_E_PARA = 1, /**< parameter error */
HCCL_E_PTR = 2, /**< empty pointer */
HCCL_E_MEMORY = 3, /**< memory error */
HCCL_E_INTERNAL = 4, /**< internal error */
    HCCL_E_NOT_SUPPORT = 5, /**< feature not supported */
    HCCL_E_NOT_FOUND = 6, /**< specific resource not found */
    HCCL_E_UNAVAIL = 7, /**< resource unavailable */
    HCCL_E_SYSCALL = 8, /**< system call error */
    HCCL_E_TIMEOUT = 9, /**< timeout */
    HCCL_E_OPEN_FILE_FAILURE = 10, /**< failed to open file */
    HCCL_E_TCP_CONNECT = 11, /**< TCP connect failed */
    HCCL_E_ROCE_CONNECT = 12, /**< RoCE connect failed */
    HCCL_E_TCP_TRANSFER = 13, /**< TCP transfer failed */
    HCCL_E_ROCE_TRANSFER = 14, /**< RoCE transfer failed */
    HCCL_E_RUNTIME = 15, /**< runtime API call failed */
    HCCL_E_DRV = 16, /**< driver API call failed */
    HCCL_E_PROFILING = 17, /**< profiling API call failed */
    HCCL_E_CCE = 18, /**< CCE API call failed */
    HCCL_E_NETWORK = 19, /**< network API call failed */
HCCL_E_RESERVED /**< reserved */
} HcclResult;
/**
* @brief handle to HCCL communicator
*/
typedef void *HcclComm;
/**
 * @brief HCCL reduction operation
*/
typedef enum {
HCCL_REDUCE_SUM = 0, /**< sum */
HCCL_REDUCE_PROD = 1, /**< prod */
HCCL_REDUCE_MAX = 2, /**< max */
HCCL_REDUCE_MIN = 3, /**< min */
HCCL_REDUCE_RESERVED /**< reserved */
} HcclReduceOp;
/**
* @brief HCCL data type
*/
typedef enum {
HCCL_DATA_TYPE_INT8 = 0, /**< int8 */
HCCL_DATA_TYPE_INT16 = 1, /**< int16 */
HCCL_DATA_TYPE_INT32 = 2, /**< int32 */
HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */
HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */
HCCL_DATA_TYPE_INT64 = 5, /**< int64 */
HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */
HCCL_DATA_TYPE_RESERVED /**< reserved */
} HcclDataType;
const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length
/**
* @brief HCCL root info
*/
typedef struct HcclRootInfoDef {
char internal[HCCL_ROOT_INFO_BYTES];
} HcclRootInfo;
#ifdef __cplusplus
}
#endif // __cplusplus
#endif // HCCL_TYPES_H_
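
Every entry point in hccl.h reports failure through this HcclResult enum, so callers usually wrap calls in a small status-check helper; the macro below is an illustrative sketch, not part of the header. HcclRootInfo itself is an opaque HCCL_ROOT_INFO_BYTES-sized blob: typically rank 0 fills it with HcclGetRootInfo() and hands it to the other ranks out of band before every rank calls HcclCommInitRootInfo().

#include <hccl/hccl_types.h>
#include <cstdio>

// Illustrative helper: log and propagate any non-success HcclResult.
#define HCCL_CHECK(call)                                          \
  do {                                                            \
    HcclResult hccl_check_status = (call);                        \
    if (hccl_check_status != HCCL_SUCCESS) {                      \
      std::fprintf(stderr, "%s failed, HcclResult=%d\n", #call,   \
                   static_cast<int>(hccl_check_status));          \
      return hccl_check_status;                                   \
    }                                                             \
  } while (0)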

+ 0
- 2
inc/external/register/register.h View File

@@ -40,7 +40,6 @@ using std::to_string;
using std::unique_ptr;
using std::vector;

/*lint -e148*/
namespace ge {
class Operator;
class TensorDesc;
@@ -159,5 +158,4 @@ namespace ge {
using OpRegistrationData = domi::OpRegistrationData;
using OpReceiver = domi::OpReceiver;
} // namespace ge
/*lint +e148*/
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_

+ 0
- 1
inc/external/register/scope/scope_fusion_pass_register.h View File

@@ -301,7 +301,6 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry {
private:
ScopeFusionPassRegistry();
class ScopeFusionPassRegistryImpl;
/*lint -e148*/
std::unique_ptr<ScopeFusionPassRegistryImpl> impl_;
friend class TensorFlowModelParser;
};


+ 1
- 1
inc/framework/common/ge_inner_error_codes.h View File

@@ -14,7 +14,6 @@
* limitations under the License.
*/

/*lint -e* */
#ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_
#define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_

@@ -304,6 +303,7 @@ GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate wei
GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory.");
GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP does not exist.");
GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE dynamic AIPP query is not supported temporarily.");
GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_P2P_MEM_FAILED, 20, "Failed to allocate P2P memory");

// Generator module error code definition
GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed.");


+ 0
- 1
inc/framework/common/op/attr_value_util.h View File

@@ -21,7 +21,6 @@
#include <unordered_map>
#include <string>

#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "proto/om.pb.h"



+ 2
- 1
inc/framework/common/op/ge_op_utils.h View File

@@ -22,7 +22,8 @@
#include <vector>

#include "common/op/attr_value_util.h"
#include "common/types.h"
#include "register/register_types.h"
#include "register/register_error_codes.h"
#include "common/util.h"
#include "graph/attr_value.h"
#include "graph/ge_tensor.h"


+ 3
- 5
inc/framework/common/string_util.h View File

@@ -36,8 +36,8 @@ class StringUtils {
#endif
return s;
}
// lint -esym(551,*)
static std::string &Rtrim(std::string &s) { /*lint !e618*/
static std::string &Rtrim(std::string &s) {
#if __cplusplus >= 201103L
(void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); }));
#else
@@ -45,7 +45,7 @@ class StringUtils {
#endif
return s;
}
// lint -esym(551,*)
///
/// @ingroup domi_common
/// @brief delete spaces at the beginning and end of a string
@@ -61,10 +61,8 @@ class StringUtils {
/// @param [in] delim separator
/// @return string array after segmentation
///
/*lint -e1077*/
static std::vector<std::string> Split(const std::string &str, char delim) {
std::vector<std::string> elems;
/*lint +e1077*/

if (str.empty()) {
elems.emplace_back("");


+ 1
- 0
inc/framework/common/types.h View File

@@ -434,6 +434,7 @@ REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter");
REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend");
REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");

REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");


+ 1
- 1
inc/framework/common/util.h View File

@@ -345,7 +345,7 @@ std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) {
/// @return Timestamp, in microseconds (US)
///
///
uint64_t GetCurrentTimestap();
uint64_t GetCurrentTimestamp();

///
/// @ingroup domi_common


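The hunk above only corrects the spelling of the declaration (GetCurrentTimestap to GetCurrentTimestamp). For readers unfamiliar with the helper, a microsecond-resolution timestamp of this shape can be sketched as follows; this is an illustrative stand-in assuming POSIX gettimeofday, not the library's actual implementation.

#include <sys/time.h>
#include <cstdint>

// Illustrative sketch of a microsecond-resolution timestamp helper.
uint64_t GetCurrentTimestampSketch() {
  struct timeval tv = {};
  if (gettimeofday(&tv, nullptr) != 0) {
    return 0;  // fall back to 0 if the clock query fails
  }
  return static_cast<uint64_t>(tv.tv_sec) * 1000000ULL + static_cast<uint64_t>(tv.tv_usec);
}
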
+ 1
- 0
inc/framework/engine/dnnengine.h View File

@@ -30,6 +30,7 @@ enum PriorityEnum {
COST_0 = 0,
COST_1,
COST_2,
COST_3,
COST_9 = 9,
COST_10 = 10,
};


+ 1
- 0
inc/framework/generator/ge_generator.h View File

@@ -86,6 +86,7 @@ class GeGenerator {
Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline = true);
Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);

class Impl;



+ 14
- 0
inc/framework/memory/memory_api.h View File

@@ -21,6 +21,7 @@
#include <vector>

#include "ge/ge_api_error_codes.h"
#include "graph//types.h"
#include "runtime/mem.h"

namespace ge {
@@ -35,6 +36,12 @@ struct HostVarInfo {
uint64_t var_size;
};

struct TensorInfo {
std::string var_name;
std::vector<int64_t> dims;
DataType data_type;
};

///
/// \param size [in] rdma pool memory size to be allocated.
/// \param mem_type [in] memory type for rdma pool.
@@ -48,6 +55,13 @@ Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM);
Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t mem_type = RT_MEMORY_HBM);

///
/// \param tensor_info [in] description of the tensor stored in shared memory.
/// \param dev_addr [out] allocated shared memory address.
/// \param memory_size [out] allocated shared memory size.
/// \return Status result of function
Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uint64_t &memory_size);

///
/// \param var_name [in] var_name name of host variable.
/// \param base_addr [out] base_addr base addr of host variable.
/// \param var_size [out] var_size memory_size of host variable.


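The new TensorInfo structure and MallocSharedMemory interface above describe a host variable by name, dims and data type, and return the device address and size of the shared block. A hedged usage sketch follows; the variable name is a placeholder, and DT_FLOAT is assumed to be the DataType constant from graph/types.h, which this header now includes.

#include "framework/memory/memory_api.h"

// Illustrative only: request shared memory for a 2 x 3 float variable.
ge::Status AllocSharedVarSketch(uint64_t &dev_addr, uint64_t &memory_size) {
  ge::TensorInfo info;
  info.var_name = "shared_var_0";  // placeholder variable name
  info.dims = {2, 3};
  info.data_type = ge::DT_FLOAT;   // DataType declared in graph/types.h

  return ge::MallocSharedMemory(info, dev_addr, memory_size);
}
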
+ 1
- 1
inc/framework/memory/memory_assigner.h View File

@@ -33,7 +33,7 @@ class MemoryAssigner {

MemoryAssigner &operator=(const MemoryAssigner &) = delete;

Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size);
Status AssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_offset, size_t &zero_copy_mem_size);

private:
ge::ComputeGraphPtr compute_graph_;


+ 0
- 3
inc/framework/omg/omg.h View File

@@ -21,7 +21,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/common/types.h"
#include "framework/omg/omg_inner_types.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "proto/ge_ir.pb.h"
@@ -92,8 +91,6 @@ void GetGroupName(ge::proto::ModelDef &model);

void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path);

Status CheckCustomAiCpuOpLib();

Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);

Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);


+ 2
- 3
inc/framework/omg/omg_inner_types.h View File

@@ -25,7 +25,6 @@
#include <utility>
#include <vector>
#include "framework/common/fmk_error_codes.h"
#include "framework/common/types.h"
#include "register/register_fmk_types.h"

using domi::DOMI_TENSOR_ND;
@@ -92,6 +91,8 @@ struct OmgContext {
std::map<std::string, std::vector<int32_t>> out_nodes_map;
  // user-designated out nodes (this is used for determining the order)
std::vector<std::pair<std::string, int32_t>> user_out_nodes;
  // default out nodes (this is used for determining the order)
std::vector<std::pair<std::string, int32_t>> default_out_nodes;
// save the output node of the network, value = topName,
// topName indicates the output name of the operator.
std::vector<std::string> user_out_nodes_top_vec;
@@ -99,8 +100,6 @@ struct OmgContext {
std::vector<std::string> net_out_nodes;
// net out nodes top names(only caffe has top)
std::vector<std::string> out_top_names;
// path for the aicpu custom operator so_file
std::vector<std::string> aicpu_op_run_paths;
// preferential format used by the entire network
domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED;
domi::FrameworkType type = domi::FRAMEWORK_RESERVED;


+ 3
- 3
inc/graph/buffer.h View File

@@ -57,11 +57,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer {

// For compatibility
inline const std::uint8_t *data() const { return GetData(); }
inline std::uint8_t *data() { return GetData(); } // lint !e659
inline std::uint8_t *data() { return GetData(); }
inline std::size_t size() const { return GetSize(); }
inline void clear() { return ClearBuffer(); }
uint8_t operator[](size_t index) const { // lint !e1022 !e1042
if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574
uint8_t operator[](size_t index) const {
if (buffer_ != nullptr && index < buffer_->size()) {
return (uint8_t)(*buffer_)[index];
}
return 0xff;


+ 0
- 2
inc/graph/compute_graph.h View File

@@ -84,7 +84,6 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A

NodePtr FindNode(const std::string &name) const;
NodePtr FindFirstNodeMatchType(const std::string &name) const;
/*lint -e504*/
// AddNode with NodePtr
NodePtr AddNode(NodePtr node);
NodePtr AddNode(OpDescPtr op);
@@ -152,7 +151,6 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
graphStatus InsertEventNodes();
bool operator==(const ComputeGraph &r_compute_graph) const;

/*lint +e504*/
const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const {
return params_share_map_;
}


+ 12
- 4
inc/graph/debug/ge_attr_define.h View File

@@ -14,7 +14,6 @@
* limitations under the License.
*/

/*lint -e618*/
#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_

@@ -33,6 +32,8 @@ namespace ge {
#define GE_FUNC_DEV_VISIBILITY
#endif
// Public attribute
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE;

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED;
@@ -1021,8 +1022,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION;
@@ -1044,6 +1043,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER;

// used for memory allocate
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE_TYPE_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MEM_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_P2P_MEMORY_SIZE;

// for unregistered op
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST;
@@ -1121,10 +1127,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VAR
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE;

// stage
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_STAGE_LEVEL;

// input_output_offset
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET;
} // namespace ge

#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_
/*lint +e618*/

+ 1
- 1
inc/graph/detail/any_map.h View File

@@ -38,7 +38,7 @@ class TypeID {
bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; }

private:
explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32
explicit TypeID(string type) : type_(std::move(type)) {}

string type_;
};


+ 2
- 2
inc/graph/detail/attributes_holder.h View File

@@ -50,7 +50,7 @@ class OpDef;
class GraphDef;
} // namespace proto

using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073
using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>;
using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>;

template <class ProtoType>
@@ -147,7 +147,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder {
protected:
graphStatus AddRequiredAttr(const std::string &name);
const std::unordered_set<string> GetAllAttrNames() const;
const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073
const std::map<string, GeAttrValue> GetAllAttrs() const;

virtual ProtoAttrMapHelper MutableAttrMap() = 0;
virtual ConstProtoAttrMapHelper GetAttrMap() const = 0;


+ 3
- 3
inc/graph/ge_attr_value.h View File

@@ -310,7 +310,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
VALUE_SET_GET_DEC(GeAttrValue::GRAPH)
VALUE_SET_GET_DEC(BYTES)
VALUE_SET_GET_DEC(NamedAttrs)
VALUE_SET_GET_DEC(ge::DataType) // lint !e665
VALUE_SET_GET_DEC(ge::DataType)
VALUE_SET_GET_DEC(vector<GeAttrValue::STR>)
VALUE_SET_GET_DEC(vector<GeAttrValue::INT>)
VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>)
@@ -320,8 +320,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue {
VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>)
VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>)
VALUE_SET_GET_DEC(vector<NamedAttrs>)
VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665
VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665
VALUE_SET_GET_DEC(vector<vector<int64_t>>)
VALUE_SET_GET_DEC(vector<ge::DataType>)
#undef VALUE_SET_GET_DEC

GeIrProtoHelper<proto::AttrDef> value_;


+ 1
- 1
inc/graph/ge_context.h View File

@@ -33,7 +33,7 @@ class GEContext {
void SetCtxDeviceId(uint32_t device_id);

private:
uint64_t session_id_ = 0;
thread_local static uint64_t session_id_;
uint32_t device_id_ = 0;
uint64_t trace_id_ = 0;
}; // class GEContext


+ 5
- 0
inc/graph/ge_local_context.h View File

@@ -33,6 +33,11 @@ class GEThreadLocalContext {
void SetSessionOption(map<std::string, string> options_map);
void SetGlobalOption(map<std::string, string> options_map);

map<string, string> GetAllGraphOptions() const;
map<string, string> GetAllSessionOptions() const;
map<string, string> GetAllGlobalOptions() const;
map<string, string> GetAllOptions() const;

private:
map<string, string> graph_options_;
map<string, string> session_options_;


+ 1
- 1
inc/graph/node.h View File

@@ -193,7 +193,7 @@ class Node : public std::enable_shared_from_this<Node> {
vector<OutDataAnchorPtr> out_data_anchors_;
InControlAnchorPtr in_control_anchor_;
OutControlAnchorPtr out_control_anchor_;
map<string, GeAttrValue> attrs_; // lint !e1073
map<string, GeAttrValue> attrs_;
bool has_init_{false};
bool host_node_{false};
bool anchor_status_updated_{false};


+ 0
- 4
inc/graph/range_vistor.h View File

@@ -22,10 +22,8 @@
template <class E, class O>
class RangeVistor {
public:
/*lint -e151*/
using Iterator = typename std::vector<E>::iterator;
using ConstIterator = typename std::vector<E>::const_iterator;
/*lint +e151*/

RangeVistor(O owner, const std::vector<E> &vs) : owner_(owner), elements_(vs) {}

@@ -43,9 +41,7 @@ class RangeVistor {

bool empty() const { return elements_.empty(); }

/*lint -e659*/
E &at(std::size_t index) { return elements_.at(index); }
/*lint +e659*/

const E &at(std::size_t index) const { return elements_.at(index); }



+ 39
- 5
inc/graph/utils/graph_utils.h View File

@@ -19,18 +19,18 @@

#include <fstream>
#include <iostream>
#include <list>
#include <map>
#include <string>
#include <vector>
#include <list>
#include <unordered_map>
#include <vector>

#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/compute_graph.h"
#include "graph/utils/anchor_utils.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "graph/node.h"
#include "graph/utils/anchor_utils.h"

#define GE_DUMP(compute_graph, name) \
do { \
@@ -206,6 +206,8 @@ class GraphUtils {
static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false,
const std::string &user_graph_name = "");

static void DumpGEGrph(const ge::ComputeGraphPtr &graph, const std::string &path, const std::string &suffix);

static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);

static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph);
@@ -214,6 +216,8 @@ class GraphUtils {

static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);

static void DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &path, const std::string &suffix);

static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);

static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);
@@ -559,7 +563,8 @@ class ComputeGraphBuilder {

class CompleteGraphBuilder : public ComputeGraphBuilder {
public:
explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
explicit CompleteGraphBuilder(std::string name, bool retval_flag = true)
: name_(std::move(name)), parent_node_(nullptr), retval_flag_(retval_flag) {}
CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
@@ -687,8 +692,37 @@ class CompleteGraphBuilder : public ComputeGraphBuilder {
///
void BuildGraphTargets(graphStatus &error_code, std::string &error_msg);

///
/// @brief Add NetOutput node
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void AddNetOutputNode(graphStatus &error_code, std::string &error_msg);

///
/// @brief Build NetOutput nodes with data & ctrl edges
/// @param [in] net_output_desc
/// @param [in] peer_out_anchors
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc,
const std::vector<OutDataAnchorPtr> &peer_out_anchors, graphStatus &error_code,
std::string &error_msg);

///
/// @brief process after build
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void PostProcess(graphStatus &error_code, std::string &error_msg);

std::string name_;
NodePtr parent_node_;
bool retval_flag_;
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
std::vector<std::string> graph_targets_;


+ 32
- 0
inc/graph/utils/node_adapter.h View File

@@ -0,0 +1,32 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef INC_GRAPH_UTILS_NODE_ADAPTER_H_
#define INC_GRAPH_UTILS_NODE_ADAPTER_H_

#include "graph/gnode.h"
#include "graph/node.h"

namespace ge {
using NodePtr = std::shared_ptr<Node>;
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodeAdapter {
public:
static GNode Node2GNode(const NodePtr &node);
static NodePtr GNode2Node(const GNode &node);
static GNodePtr Node2GNodePtr(const NodePtr &node);
};
} // namespace ge
#endif // INC_GRAPH_UTILS_NODE_ADAPTER_H_
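
NodeAdapter is the bridge between the internal NodePtr representation and the external GNode handle added in this change: Node2GNode/Node2GNodePtr wrap a node, GNode2Node unwraps it again. A hedged round-trip sketch follows (it assumes a valid ge::NodePtr is already at hand and that AscendString is default-constructible, as its use in gnode.cc implies).

#include "graph/utils/node_adapter.h"

// Illustrative only: wrap an internal NodePtr as a GNode, query it, then unwrap it again.
ge::graphStatus InspectNodeSketch(const ge::NodePtr &node) {
  ge::GNode gnode = ge::NodeAdapter::Node2GNode(node);

  ge::AscendString name;
  ge::AscendString type;
  if (gnode.GetName(name) != ge::GRAPH_SUCCESS || gnode.GetType(type) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }

  // Convert back when internal graph utilities are needed again.
  ge::NodePtr internal = ge::NodeAdapter::GNode2Node(gnode);
  return (internal != nullptr) ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED;
}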

+ 8
- 0
inc/graph/utils/node_utils.h View File

@@ -83,6 +83,7 @@ class NodeUtils {
static std::string GetNodeType(const Node &node);
static std::string GetNodeType(const NodePtr &node);

static std::vector<ComputeGraphPtr> GetAllSubgraphs(const Node &node);
static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index);
static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph);

@@ -162,6 +163,13 @@ class NodeUtils {

static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor);

///
  /// @brief Get node type across subgraphs.
/// @param [in] node
/// @return type
///
static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node);

private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;


inc/common/util/ai_core/param_calculate/aicore_param_calculator.h → src/common/graph/ascend_string.cc View File

@@ -14,20 +14,20 @@
* limitations under the License.
*/

#ifndef AICORE_PARAM_CALCULATOR
#define AICORE_PARAM_CALCULATOR
#include "external/graph/ascend_string.h"

#include "graph/node.h"
#include "graph_optimizer/graph_optimize_register_error_codes.h"
namespace ge {
AscendString::AscendString(const char* name) {
if (name != nullptr) {
name_ = std::shared_ptr<std::string>(new (std::nothrow) std::string(name));
}
}

namespace fe {
class AICoreParamCalculator {
public:
AICoreParamCalculator();
const char* AscendString::GetString() const {
if (name_ == nullptr) {
return nullptr;
}

~AICoreParamCalculator();

Status CalcOpRunningParam(ge::Node &node);
};
} // namespace fe
#endif // AICORE_PARAM_CALCULATOR
return (*name_).c_str();
}
} // namespace ge
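
As implemented above, AscendString simply keeps a shared copy of the C string it was constructed from, and GetString() returns nullptr when no string was ever set. A short illustrative sketch:

#include "external/graph/ascend_string.h"
#include <cstdio>

// Illustrative only: construction from a literal, plus the nullptr round trip.
void AscendStringSketch() {
  ge::AscendString name("Conv2D");
  if (name.GetString() != nullptr) {
    std::printf("name: %s\n", name.GetString());
  }

  ge::AscendString unset(nullptr);  // allowed: the internal string stays unset
  if (unset.GetString() == nullptr) {
    std::printf("unset AscendString yields nullptr\n");
  }
}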

+ 16
- 17
src/common/graph/format_refiner.cc View File

@@ -41,6 +41,7 @@ using namespace ge;
using namespace std;
namespace ge {
namespace {
const size_t kDimSize4d = 4;
const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE};
const string kIsGraphInferred = "_is_graph_inferred";
thread_local RefRelations reflection_builder;
@@ -410,28 +411,26 @@ graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, s
GE_CHECK_NOTNULL(data_node);
auto op_desc = data_node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0));
auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat();

auto input_desc = op_desc->MutableInputDesc(0);
auto output_desc = op_desc->MutableOutputDesc(0);
GE_CHECK_NOTNULL(input_desc);
GE_CHECK_NOTNULL(output_desc);

auto curr_format = output_desc->GetOriginFormat();
if (curr_format != FORMAT_ND) {
      // Data format has been inferred, continue
continue;
}
    // Set format for un-inferred data node
auto input_descs = op_desc->GetAllInputsDescPtr();
auto output_descs = op_desc->GetAllOutputsDescPtr();

for (const auto &input_desc : input_descs) {
if (input_desc != nullptr) {
input_desc->SetOriginFormat(data_format);
input_desc->SetFormat(data_format);
}
}
for (const auto &output_desc : output_descs) {
if (output_desc != nullptr) {
output_desc->SetOriginFormat(data_format);
output_desc->SetFormat(data_format);
}
    // keep data format ND for lack of a definition when the input dim num is smaller than 4
if (input_desc->MutableShape().GetDimNum() < kDimSize4d) {
continue;
}
    // Set format for un-inferred data node
input_desc->SetOriginFormat(data_format);
input_desc->SetFormat(data_format);
output_desc->SetOriginFormat(data_format);
output_desc->SetFormat(data_format);
uninfered_data_nodes.push_back(data_node);
}
  // Reinfer format from uninferred data nodes


+ 13
- 2
src/common/graph/ge_attr_define.cc View File

@@ -18,6 +18,8 @@

namespace ge {
// Public attribute
const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape";

const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape";

const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned";
@@ -718,6 +720,8 @@ const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size";

const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size";

const std::string ATTR_MODEL_P2P_MEMORY_SIZE = "p2p_memory_size";

const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name";

const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size";
@@ -957,8 +961,6 @@ const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key";
const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key";
const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op";
const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type";
const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type";
const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type";
const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content";
const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size";
const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison";
@@ -980,6 +982,12 @@ const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy";
const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name";
const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer";

// used for memory allocate
const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type";
const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type";
const std::string ATTR_NAME_WORKSPACE_TYPE_LIST = "_workspace_type";
const std::string ATTR_NAME_TENSOR_MEM_TYPE = "_tensor_memory_type";

// Op debug attrs
const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag";
const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode";
@@ -1080,6 +1088,9 @@ const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement";
const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type";
const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type";

// stage
const std::string ATTR_STAGE_LEVEL = "_stage_level";

// input_output_offset
const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset";
const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset";


+ 12
- 17
src/common/graph/ge_attr_value.cc View File

@@ -33,8 +33,7 @@ using std::vector;
namespace ge {
NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); }

NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg)
: named_attrs_(owner, proto_msg) {} // lint !e1744
NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {}

void NamedAttrs::SetName(const std::string &name) {
auto proto_msg = named_attrs_.GetProtoMsg();
@@ -239,7 +238,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524
ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>)
@@ -253,11 +252,9 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS)
ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>)
/*lint -e665*/
ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>)
/*lint +e665*/
ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665
ATTR_VALUE_SET_GET_IMP(vector<DataType>)
ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE)

#undef ATTR_VALUE_SET_GET_IMP

@@ -785,14 +782,14 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (graph_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
graph_def = nullptr;
return false; // lint !e665
return false;
} else {
ModelSerializeImp imp;
imp.SetProtobufOwner(graph_def);
if (!imp.UnserializeGraph(graph, *graph_def)) {
GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
return false;
} // lint !e514
}
value = graph;
}
return true;
@@ -812,7 +809,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (graph_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed");
graph_def = nullptr;
return false; // lint !e665
return false;
} else {
ComputeGraphPtr graph = nullptr;
ModelSerializeImp imp;
@@ -820,7 +817,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM
if (!imp.UnserializeGraph(graph, *graph_def)) {
GELOGE(GRAPH_FAILED, "UnserializeGraph Failed");
return false;
} // lint !e514
}
value.push_back(graph);
}
}
@@ -972,9 +969,7 @@ ATTR_UTILS_SET_IMP(Tensor, GeTensor)
ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS)
ATTR_UTILS_SET_GET_IMP(Bytes, Buffer)
ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr)
/*lint -e665*/
ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>)
/*lint +e665*/

ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>)
ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>)
@@ -989,8 +984,8 @@ ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>)
ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>)
ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>)
ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>)
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665
ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665
ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>)
ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType)

bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name,
std::initializer_list<ConstGeTensorPtr> &&value) {
@@ -1159,7 +1154,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(Con
}
for (const auto &item : bytes_vals) {
ModelSerialize serialize;
auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732
auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize());
value.push_back(op_desc);
}
return true;
@@ -1211,7 +1206,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(
op_def = ComGraphMakeShared<proto::OpDef>();
if (op_def == nullptr) {
GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed");
return nullptr; // lint !e665
return nullptr;
}
ModelSerializeImp imp;
(void)imp.SerializeOpDesc(org_op_desc, op_def.get());


+ 857
- 0
src/common/graph/gnode.cc View File

@@ -0,0 +1,857 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "graph/gnode.h"

#include <utility>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/node.h"
#include "graph/utils/node_adapter.h"
#include "graph/utils/tensor_adapter.h"
#include <graph/utils/graph_utils.h>
#include "graph/debug/ge_attr_define.h"
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"

namespace ge {
class NodeImpl {
public:
NodeImpl() = default;
~NodeImpl() = default;

NodeImpl(NodeImpl &) = delete;
NodeImpl &operator=(const NodeImpl &) = delete;

std::weak_ptr<Node> node_ptr_;
};

NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) {
if (graph_node.impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr.");
return nullptr;
}

return graph_node.impl_->node_ptr_.lock();
}

GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr");
return GNode();
}

GNode graph_node;
if (graph_node.impl_ == nullptr) {
GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
return graph_node;
}
graph_node.impl_->node_ptr_ = node;

return graph_node;
}

GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr");
return nullptr;
}

GNodePtr gnode = std::shared_ptr<GNode>(new (std::nothrow) GNode());
if (gnode == nullptr) {
GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str());
return nullptr;
}

if (gnode->impl_ == nullptr) {
GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str());
return nullptr;
}
gnode->impl_->node_ptr_ = node;

return gnode;
}

GNode::GNode() { impl_ = ComGraphMakeShared<NodeImpl>(); }

graphStatus GNode::GetType(ge::AscendString &type) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid.");
return GRAPH_FAILED;
}
std::string node_type = node_ptr->GetType();
AscendString ascend_type(node_type.c_str());
type = ascend_type;

return GRAPH_SUCCESS;
}

graphStatus GNode::GetName(ge::AscendString &name) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid.");
return GRAPH_FAILED;
}
std::string node_name = node_ptr->GetName();
AscendString ascend_name(node_name.c_str());
name = ascend_name;

return GRAPH_SUCCESS;
}

std::pair<GNodePtr, int32_t> GNode::GetInDataNodesAndPortIndexs(const int32_t index) const {
pair<GNodePtr, int32_t> gnode_idx = {nullptr, 0xFF};
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
return gnode_idx;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
return gnode_idx;
}

auto in_anchor = node_ptr->GetInDataAnchor(index);
if (in_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist", index,
node_ptr->GetName().c_str());
return gnode_idx;
}

auto out_anchor = in_anchor->GetPeerOutAnchor();
if (out_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist", index,
node_ptr->GetName().c_str());
return gnode_idx;
}

NodePtr peer_node_ptr = out_anchor->GetOwnerNode();
GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
if (gnode == nullptr) {
GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
return gnode_idx;
}

return {gnode, out_anchor->GetIdx()};
}

std::vector<GNodePtr> GNode::GetInControlNodes() const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
return {};
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
return {};
}

std::vector<GNodePtr> gnodes;
auto in_control_nodes = node_ptr->GetInControlNodes();
for (auto &in_control_node : in_control_nodes) {
GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node);
if (gnode == nullptr) {
GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
return {};
}
gnodes.emplace_back(gnode);
}

return gnodes;
}

std::vector<std::pair<GNodePtr, int32_t>> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr.");
return {};
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid.");
return {};
}

auto out_anchor = node_ptr->GetOutDataAnchor(index);
if (out_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists", index,
node_ptr->GetName().c_str());
return {};
}

vector<std::pair<GNodePtr, int32_t>> gnode_index;
auto in_data_anchors = out_anchor->GetPeerInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
if (in_data_anchor == nullptr) {
GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str());
return {};
}
NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode();
GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr);
if (gnode == nullptr) {
GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
return {};
}
gnode_index.emplace_back(std::pair<GNodePtr, int32_t>(gnode, in_data_anchor->GetIdx()));
}

return gnode_index;
}

std::vector<GNodePtr> GNode::GetOutControlNodes() const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr.");
return {};
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid.");
return {};
}

std::vector<GNodePtr> gnodes;
auto out_control_nodes = node_ptr->GetOutControlNodes();
for (auto &out_control_node : out_control_nodes) {
GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node);
if (gnode == nullptr) {
GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str());
return {};
}
gnodes.emplace_back(gnode);
}

return gnodes;
}

graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index);
bool is_const = NodeUtils::IsConst(*input_data_node);
if (!is_const) {
GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str());
return GRAPH_NODE_WITHOUT_CONST_INPUT;
}

Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node);
if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", input_data_node->GetName().c_str(),
node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

graphStatus GNode::GetInputIndexByName(const ge::AscendString &name, int32_t &index) {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

std::string node_name = ascend_name;
index = op_desc->GetInputIndexByName(node_name);

return GRAPH_SUCCESS;
}

graphStatus GNode::GetOutputIndexByName(const ge::AscendString &name, int32_t &index) {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

std::string node_name = ascend_name;
index = op_desc->GetOutputIndexByName(node_name);

return GRAPH_SUCCESS;
}

size_t GNode::GetInputsSize() const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return op_desc->GetInputsSize();
}

size_t GNode::GetOutputsSize() const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return op_desc->GetOutputsSize();
}

graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const {
if (index < 0) {
GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index);
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast<uint32_t>(index));
if (ge_tensor_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);

return GRAPH_SUCCESS;
}

graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) {
if (index < 0) {
GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index);
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
if (op_desc->UpdateInputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const {
if (index < 0) {
GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index);
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast<uint32_t>(index));
if (ge_tensor_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc);

return GRAPH_SUCCESS;
}

graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) {
if (index < 0) {
GELOGE(GRAPH_PARAM_INVALID, "Gnode: index[%d] cannot be less than zero.", index);
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid.");
return GRAPH_FAILED;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc);
if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

#define NODE_ATTR_GET_IMP(ArgType) \
graphStatus GNode::GetAttr(const ge::AscendString &name, ArgType &attr_value) const { \
const char *ascend_name = name.GetString(); \
if (ascend_name == nullptr) { \
GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error."); \
return GRAPH_PARAM_INVALID; \
} \
\
if (impl_ == nullptr) { \
GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); \
return GRAPH_FAILED; \
} \
\
std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \
if (node_ptr == nullptr) { \
GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid."); \
return GRAPH_FAILED; \
} \
\
std::string node_name = ascend_name; \
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \
if (op.GetAttr(node_name, attr_value) != GRAPH_SUCCESS) { \
GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); \
return GRAPH_FAILED; \
} \
\
return GRAPH_SUCCESS; \
}

#define NODE_ATTR_SET_IMP(ArgType) \
graphStatus GNode::SetAttr(const ge::AscendString &name, ArgType &attr_value) const { \
const char *ascend_name = name.GetString(); \
if (ascend_name == nullptr) { \
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); \
return GRAPH_PARAM_INVALID; \
} \
\
if (impl_ == nullptr) { \
GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); \
return GRAPH_FAILED; \
} \
\
std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \
if (node_ptr == nullptr) { \
GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); \
return GRAPH_FAILED; \
} \
\
std::string node_name = ascend_name; \
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \
(void)op.SetAttr(node_name, attr_value); \
return GRAPH_SUCCESS; \
}

NODE_ATTR_GET_IMP(int64_t)
NODE_ATTR_GET_IMP(int32_t)
NODE_ATTR_GET_IMP(uint32_t)
NODE_ATTR_GET_IMP(float)
NODE_ATTR_GET_IMP(bool)
NODE_ATTR_GET_IMP(Tensor)
NODE_ATTR_GET_IMP(std::vector<int64_t>)
NODE_ATTR_GET_IMP(std::vector<int32_t>)
NODE_ATTR_GET_IMP(std::vector<uint32_t>)
NODE_ATTR_GET_IMP(std::vector<float>)
NODE_ATTR_GET_IMP(std::vector<bool>)
NODE_ATTR_GET_IMP(std::vector<Tensor>)
NODE_ATTR_GET_IMP(OpBytes)
NODE_ATTR_GET_IMP(std::vector<std::vector<int64_t>>)
NODE_ATTR_GET_IMP(std::vector<ge::DataType>)
NODE_ATTR_GET_IMP(ge::DataType)
NODE_ATTR_GET_IMP(AttrValue)

NODE_ATTR_SET_IMP(int64_t)
NODE_ATTR_SET_IMP(int32_t)
NODE_ATTR_SET_IMP(uint32_t)
NODE_ATTR_SET_IMP(float)
NODE_ATTR_SET_IMP(bool)
NODE_ATTR_SET_IMP(Tensor)
NODE_ATTR_SET_IMP(std::vector<int64_t>)
NODE_ATTR_SET_IMP(std::vector<int32_t>)
NODE_ATTR_SET_IMP(std::vector<uint32_t>)
NODE_ATTR_SET_IMP(std::vector<float>)
NODE_ATTR_SET_IMP(std::vector<bool>)
NODE_ATTR_SET_IMP(std::vector<Tensor>)
NODE_ATTR_SET_IMP(OpBytes)
NODE_ATTR_SET_IMP(std::vector<std::vector<int64_t>>)
NODE_ATTR_SET_IMP(std::vector<ge::DataType>)
NODE_ATTR_SET_IMP(ge::DataType)

graphStatus GNode::SetAttr(const ge::AscendString &name, AttrValue &attr_value) const {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
return GRAPH_FAILED;
}

std::string node_name = ascend_name;
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
(void)op.SetAttr(node_name, std::move(attr_value));
return GRAPH_SUCCESS;
}

graphStatus GNode::SetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
return GRAPH_PARAM_INVALID;
}

const char *ascend_attr_value = attr_value.GetString();
if (ascend_attr_value == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
return GRAPH_FAILED;
}
std::string node_name = ascend_name;
std::string node_attr_value = ascend_attr_value;
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
(void)op.SetAttr(node_name, node_attr_value);

return GRAPH_SUCCESS;
}

graphStatus GNode::SetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error.");
return GRAPH_PARAM_INVALID;
}

for (auto &attr_val : attr_values) {
const char *ascend_attr_value = attr_val.GetString();
if (ascend_attr_value == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error.");
return GRAPH_PARAM_INVALID;
}
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid.");
return GRAPH_FAILED;
}
vector<std::string> node_attr_vals;
for (auto attr_val : attr_values) {
if (attr_val.GetString() != nullptr) {
std::string node_attr_val = attr_val.GetString();
node_attr_vals.emplace_back(node_attr_val);
}
}
std::string node_name = ascend_name;
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
(void)op.SetAttr(node_name, node_attr_vals);

return GRAPH_SUCCESS;
}

graphStatus GNode::GetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

std::string node_name = ascend_name;
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
std::string op_name;
if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

ge::AscendString attr_value_get(op_name.c_str());
attr_value = attr_value_get;

return GRAPH_SUCCESS;
}

graphStatus GNode::GetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error.");
return GRAPH_PARAM_INVALID;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

std::string node_name = ascend_name;
Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr);
vector<std::string> attr_names;
if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

for (auto &attr_name : attr_names) {
AscendString ascend_attr_name(attr_name.c_str());
attr_values.push_back(ascend_attr_name);
}

return GRAPH_SUCCESS;
}

bool GNode::HasAttr(const ge::AscendString &name) {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error.");
return false;
}

if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr.");
return false;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid.");
return false;
}

OpDescPtr op_desc = node_ptr->GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str());
return false;
}
std::string attr_name = ascend_name;
if (!op_desc->HasAttr(attr_name)) {
GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str());
return false;
}

return true;
}

graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index);
if (compute_graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed form node[%s].", index, node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr);
graph = std::make_shared<Graph>(create_graph);
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "GetSubgraph: graph make shared failed form node[%s].", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> &graph_list) const {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr.");
return GRAPH_FAILED;
}

std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock();
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid.");
return GRAPH_FAILED;
}

std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr);
if (sub_graphs.empty()) {
GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed form node[%s].", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}

for (auto &sub_graph : sub_graphs) {
if (sub_graph == nullptr) {
GELOGE(GRAPH_FAILED, "Get subgraph failed form node[%s].", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(sub_graph);
GraphPtr graph = std::make_shared<Graph>(create_graph);
if (graph == nullptr) {
GELOGE(GRAPH_FAILED, "Subgraph make shared failed form node[%s].", node_ptr->GetName().c_str());
return GRAPH_FAILED;
}
graph_list.emplace_back(graph);
}

return GRAPH_SUCCESS;
}
} // namespace ge
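
A minimal usage sketch of the GNode attribute accessors above. The helper name, attribute name and value are hypothetical, and the include paths may differ per build setup; error handling is reduced to the bare minimum.

#include "graph/gnode.h"
#include "graph/ascend_string.h"

// Hypothetical helper: tags a node with a string attribute and reads it back.
ge::graphStatus TagNode(ge::GNode &node) {
  ge::AscendString attr_name("my_attr");    // hypothetical attribute name
  ge::AscendString attr_value("my_value");  // hypothetical attribute value
  if (node.SetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;  // name/value was null or the underlying Node expired
  }
  ge::AscendString fetched;
  return node.HasAttr(attr_name) ? node.GetAttr(attr_name, fetched) : ge::GRAPH_FAILED;
}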

+ 234
- 1
src/common/graph/graph.cc View File

@@ -15,6 +15,7 @@
*/

#include "external/graph/graph.h"
#include <cstring>
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/debug/ge_attr_define.h"
@@ -22,6 +23,7 @@
#include "graph/model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/node_adapter.h"

using std::map;
using std::pair;
@@ -242,6 +244,8 @@ class GraphImpl {

const std::string &GetName() const { return name_; }

ComputeGraphPtr GetComputeGraph() const { return compute_graph_; }

private:
std::string name_;
std::string output_name_;
@@ -261,7 +265,7 @@ graphStatus Graph::AddOp(const ge::Operator &op) {
return impl_->AddOp(op);
}

graphStatus Graph::GetAllOpName(std::vector<string> &op_name) const {
graphStatus Graph::GetAllOpName(std::vector<std::string> &op_name) const {
GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED,
"GetAllOpName failed: graph can not be used, impl is nullptr.");
return impl_->GetAllOpName(op_name);
@@ -335,6 +339,235 @@ void Graph::SetNeedIteration(bool need_iteration) {
impl_->SetNeedIteration(need_iteration);
}

std::vector<GNode> Graph::GetAllNodes() const {
std::vector<GNode> graph_nodes;
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetAllNodes: graph can not be used, impl is nullptr.");
return graph_nodes;
}

ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
if (compute_graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetAllNodes: compute graph ptr is nullptr.");
return graph_nodes;
}

for (auto &node : compute_graph_ptr->GetAllNodes()) {
GNode gnode = NodeAdapter::Node2GNode(node);
graph_nodes.emplace_back(gnode);
}

return graph_nodes;
}

std::vector<GNode> Graph::GetDirectNode() const {
std::vector<GNode> graph_nodes;
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "GetDirectNode: graph can not be used, impl is nullptr.");
return graph_nodes;
}
ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
if (compute_graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "GetDirectNode: compute graph ptr is nullptr.");
return graph_nodes;
}

for (auto &node : compute_graph_ptr->GetDirectNode()) {
GNode gnode = NodeAdapter::Node2GNode(node);
graph_nodes.emplace_back(gnode);
}

return graph_nodes;
}

graphStatus Graph::RemoveNode(GNode &node) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveNode: graph can not be used, impl is nullptr.");
return GRAPH_FAILED;
}

NodePtr node_ptr = NodeAdapter::GNode2Node(node);
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveNode: gnode to node failed.");
return GRAPH_FAILED;
}

ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
if (compute_graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr.");
return GRAPH_FAILED;
}

if (compute_graph_ptr->RemoveNode(node_ptr) != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "RemoveNde: remove node failed.");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node,
const int32_t dst_port_index) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveEdge: graph can not be used, impl is nullptr.");
return GRAPH_FAILED;
}

if ((src_port_index == -1) && (dst_port_index != -1)) {
GELOGE(GRAPH_FAILED, "RemoveEdge:src control anchor link to dst data anchor not exists.");
return GRAPH_FAILED;
}

NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
if (src_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveEdge: src gnode to node failed.");
return GRAPH_FAILED;
}

NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
if (dst_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "RemoveEdge: dst gnode to node failed.");
return GRAPH_FAILED;
}

graphStatus res = GRAPH_FAILED;
if ((src_port_index == -1) && (dst_port_index == -1)) {
res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor());
if (res != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

if (src_port_index != -1 && dst_port_index == -1) {
res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor());
if (res != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index),
dst_node_ptr->GetInDataAnchor(dst_port_index));
if (res != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge failed.");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS;
}

GNode Graph::AddNodeByOp(const Operator &op) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "AddNodeByOp: graph can not be used, impl is nullptr.");
return GNode();
}

std::shared_ptr<ge::OpDesc> op_desc = ge::OpDescUtils::GetOpDescFromOperator(op);
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "AddNodeByOp: get op desc from op[%s] failed.", op.GetName().c_str());
return GNode();
}

ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph();
if (compute_graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "AddNodeByOp: compute graph ptr is nullptr.");
return GNode();
}

NodePtr node_ptr = compute_graph_ptr->AddNode(op_desc);
GNode gnode = NodeAdapter::Node2GNode(node_ptr);

return gnode;
}

graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node,
const int32_t dst_port_index) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "AddDataEdge: graph can not be used, impl is nullptr.");
return GRAPH_FAILED;
}

NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
if (src_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "AddDataEdge: src gnode to node failed.");
return GRAPH_FAILED;
}

NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
if (dst_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "AddDataEdge: dst gnode to node failed.");
return GRAPH_FAILED;
}

graphStatus res =
GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInDataAnchor(dst_port_index));
if (res != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "AddDataEdge: Add data edge failed.");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

graphStatus Graph::AddControlEdge(GNode &src_node, GNode &dst_node) {
if (impl_ == nullptr) {
GELOGE(GRAPH_FAILED, "AddControlEdge: graph can not be used, impl is nullptr.");
return GRAPH_FAILED;
}

NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node);
if (src_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "AddControlEdge: src gnode to node failed.");
return GRAPH_FAILED;
}

NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node);
if (dst_node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "AddControlEdge: dst gnode to node failed.");
return GRAPH_FAILED;
}

graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor());
if (res != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed.");
return GRAPH_FAILED;
}

return GRAPH_SUCCESS;
}

GraphPtr Graph::ConstructFromInputs(const std::vector<Operator> &inputs, const ge::AscendString &name) {
const char *ascend_name = name.GetString();
if (ascend_name == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "ConstructFromInputs: ascend string error.");
return nullptr;
}

if (inputs.empty()) {
GELOGE(GRAPH_FAILED, "ConstructFromInputs: inputs size can not be 0.");
return nullptr;
}

std::string graph_name = ascend_name;
ComputeGraphPtr compute_graph = GraphUtils::CreateGraphFromOperator(graph_name, inputs);
if (compute_graph == nullptr) {
GELOGE(GRAPH_FAILED, "ConstructFromInputs: create compute graph failed.");
return nullptr;
}

compute_graph->SetInputSize(static_cast<uint32_t>(inputs.size()));
Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph);
GraphPtr graph_ptr = std::make_shared<Graph>(graph);
if (graph_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "ConstructFromInputs: graph make shared failed.");
return nullptr;
}

return graph_ptr;
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) {
GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr);
return graph.impl_->compute_graph_;

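The graph editing entry points added above can be combined roughly as follows. This is only a sketch: op_a and op_b are assumed to be valid ge::Operator instances, and the include paths are assumptions. Note that a port index of -1 selects the control anchor, which is why RemoveEdge rejects a -1 source paired with a data destination.

#include "graph/graph.h"
#include "graph/gnode.h"
#include "graph/operator.h"

// Sketch: wire two hypothetical operators into an existing graph and edit the edges.
void EditGraph(ge::Graph &graph, const ge::Operator &op_a, const ge::Operator &op_b) {
  ge::GNode a = graph.AddNodeByOp(op_a);
  ge::GNode b = graph.AddNodeByOp(op_b);

  (void)graph.AddDataEdge(a, 0, b, 0);  // output 0 of a -> input 0 of b
  (void)graph.AddControlEdge(a, b);     // control dependency a -> b

  // Port index -1 means the control anchor, so this removes the control edge again.
  (void)graph.RemoveEdge(a, -1, b, -1);
}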

+ 12
- 3
src/common/graph/graph.mk View File

@@ -14,6 +14,8 @@ COMMON_LOCAL_SRC_FILES := \
./attr_value.cc \
./buffer.cc \
./compute_graph.cc \
./ascend_string.cc \
./gnode.cc \
./graph.cc \
./inference_context.cc \
./shape_refiner.cc \
@@ -98,11 +100,13 @@ LOCAL_CPPFLAGS += -fexceptions

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \

../../out/graph/lib64/stub/ascend_string.cc \
../../out/graph/lib64/stub/gnode.cc \

LOCAL_SHARED_LIBRARIES :=

@@ -128,7 +132,8 @@ LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \

../../out/graph/lib64/stub/ascend_string.cc \
../../out/graph/lib64/stub/gnode.cc \

LOCAL_SHARED_LIBRARIES :=

@@ -173,11 +178,13 @@ LOCAL_CFLAGS += -O2

LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES)
LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/attr_value.cc \
../../out/graph/lib64/stub/graph.cc \
../../out/graph/lib64/stub/operator.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/operator_factory.cc \

../../out/graph/lib64/stub/ascend_string.cc \
../../out/graph/lib64/stub/gnode.cc \

LOCAL_SHARED_LIBRARIES :=

@@ -206,6 +213,8 @@ LOCAL_SRC_FILES := \
../../out/graph/lib64/stub/operator_factory.cc \
../../out/graph/lib64/stub/tensor.cc \
../../out/graph/lib64/stub/inference_context.cc \
../../out/graph/lib64/stub/ascend_string.cc \
../../out/graph/lib64/stub/gnode.cc \


LOCAL_SHARED_LIBRARIES :=


+ 1
- 0
src/common/graph/model.cc View File

@@ -47,6 +47,7 @@ const int ACCESS_PERMISSION_BITS = 0400;
namespace ge {
void Model::Init() {
(void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
(void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);


+ 2
- 2
src/common/graph/model_serialize.cc View File

@@ -409,13 +409,13 @@ bool ModelSerializeImp::HandleNodeNameRef() {
item.dst_node_name.c_str(), item.dst_in_index);
return false;
}
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
} else {
// Control edge
auto src_anchor = src_node_it->second->GetOutControlAnchor();
auto dst_anchor = item.dst_node->GetInControlAnchor();
if (src_anchor != nullptr && dst_anchor != nullptr) {
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737
GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed.");
}
}
}


+ 0
- 1
src/common/graph/op_desc.cc View File

@@ -33,7 +33,6 @@ using std::shared_ptr;
using std::string;
using std::vector;

/*lint -save -e521 -e681 -e732 -e737*/
namespace ge {
const std::string ATTR_NAME_ID = "id";



+ 10
- 12
src/common/graph/operator.cc View File

@@ -56,9 +56,6 @@ using std::string;
using std::to_string;
using std::vector;

/*lint -save -e529 -e728*/
/*lint -e446 -e732*/
/*lint -e665*/
namespace ge {
class OpIO {
public:
@@ -768,6 +765,8 @@ const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = {
{GeAttrValue::VT_BYTES, "VT_BYTES"},
{GeAttrValue::VT_GRAPH, "VT_GRAPH"},
{GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"},
{GeAttrValue::VT_LIST_LIST_INT, "VT_LIST_LIST_INT"},
{GeAttrValue::VT_DATA_TYPE, "VT_DATA_TYPE"},
{GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"},
{GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"},
{GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"},
@@ -778,6 +777,7 @@ const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = {
{GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"},
{GeAttrValue::VT_GRAPH, "VT_GRAPH"},
{GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"},
{GeAttrValue::VT_LIST_DATA_TYPE, "VT_LIST_DATA_TYPE"},
};
} // namespace
const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const {
@@ -943,7 +943,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
GELOGW("set attr name %s failed.", name.c_str()); \
} \
return *this; \
} // lint !e665
}

#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \
graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \
@@ -956,7 +956,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; }
return GRAPH_FAILED; \
} \
return GRAPH_SUCCESS; \
} // lint !e665
}

void Operator::BreakConnect() const {
if (operator_impl_ == nullptr) {
@@ -977,7 +977,7 @@ void Operator::BreakConnect() const {
if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \
GELOGW("reg attr name %s failed.", name.c_str()); \
} \
} // lint !e665
}

OP_ATTR_SET_IMP(int64_t, Int)
OP_ATTR_SET_IMP(int32_t, Int)
@@ -998,22 +998,22 @@ OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt)
OP_ATTR_SET_IMP(float, Float)
OP_ATTR_GET_IMP(float &, Float)
OP_ATTR_SET_IMP(const vector<float> &, ListFloat)
OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665
OP_ATTR_GET_IMP(vector<float> &, ListFloat)

OP_ATTR_SET_IMP(bool, Bool)
OP_ATTR_GET_IMP(bool &, Bool)
OP_ATTR_SET_IMP(const vector<bool> &, ListBool)
OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665
OP_ATTR_GET_IMP(vector<bool> &, ListBool)

OP_ATTR_SET_IMP(const string &, Str)
OP_ATTR_GET_IMP(string &, Str)
OP_ATTR_SET_IMP(const vector<string> &, ListStr)
OP_ATTR_GET_IMP(vector<string> &, ListStr) // lint !e665
OP_ATTR_GET_IMP(vector<string> &, ListStr)

OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs)
OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665
OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs)

OP_ATTR_REG_IMP(int64_t, Int)
OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt)
@@ -1583,5 +1583,3 @@ void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_node
}
}
} // namespace ge
/*lint +e446 +e732*/
/*lint +e665*/

+ 0
- 2
src/common/graph/opsproto/opsproto_manager.cc View File

@@ -38,9 +38,7 @@ bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &optio
return true;
}

/*lint -e1561*/
auto proto_iter = options.find("ge.opsProtoLibPath");
/*lint +e1561*/
if (proto_iter == options.end()) {
GELOGW("ge.opsProtoLibPath option not set, return.");
return false;


+ 2
- 0
src/common/graph/option/ge_context.cc View File

@@ -31,6 +31,8 @@ GEContext &GetContext() {
return ge_context;
}

thread_local uint64_t GEContext::session_id_;

graphStatus GEContext::GetOption(const std::string &key, std::string &option) {
return GetThreadLocalContext().GetOption(key, option);
}


+ 14
- 0
src/common/graph/option/ge_local_context.cc View File

@@ -57,4 +57,18 @@ void GEThreadLocalContext::SetGraphOption(map<std::string, string> options_map)
graph_options_.clear();
graph_options_ = std::move(options_map);
}

map<string, string> GEThreadLocalContext::GetAllGraphOptions() const { return graph_options_; }

map<string, string> GEThreadLocalContext::GetAllSessionOptions() const { return session_options_; }

map<string, string> GEThreadLocalContext::GetAllGlobalOptions() const { return global_options_; }

map<string, string> GEThreadLocalContext::GetAllOptions() const {
map<string, string> options_all;
options_all.insert(graph_options_.begin(), graph_options_.end());
options_all.insert(session_options_.begin(), session_options_.end());
options_all.insert(global_options_.begin(), global_options_.end());
return options_all;
}
} // namespace ge
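
One detail worth noting about GetAllOptions(): std::map::insert never overwrites an existing key, and graph options are inserted first, so when the same key is set at several scopes the graph-level value wins, then session, then global. A standalone sketch with a hypothetical option key:

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> graph_opts{{"ge.someOption", "graph"}};    // hypothetical key
  std::map<std::string, std::string> session_opts{{"ge.someOption", "session"}};
  std::map<std::string, std::string> all;
  all.insert(graph_opts.begin(), graph_opts.end());
  all.insert(session_opts.begin(), session_opts.end());  // duplicate key is ignored
  std::cout << all["ge.someOption"] << std::endl;         // prints "graph"
  return 0;
}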

+ 56
- 35
src/common/graph/shape_refiner.cc View File

@@ -365,6 +365,37 @@ string Serial(const vector<int64_t> &dims) {
return serial_string;
}

void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) {
desc_str += "[";
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)desc->GetShapeRange(shape_range);
for (const auto &pair : shape_range) {
desc_str += "{";
desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second);
desc_str += "},";
}
desc_str += "] ";
}

void SerialShapeAndDtype(const GeTensorDescPtr &desc, bool is_origin_info, std::string &desc_str) {
desc_str += "[";
if (!is_origin_info) {
for (int64_t dim : desc->GetShape().GetDims()) {
desc_str += std::to_string(dim) + " ";
}
desc_str += "]";
desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetDataType()) + ":" +
TypeUtils::FormatToSerialString(desc->GetFormat()) + " ";
} else {
for (int64_t dim : desc->GetOriginShape().GetDims()) {
desc_str += std::to_string(dim) + " ";
}
desc_str += "]";
desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(desc->GetOriginFormat()) + " ";
}
}

graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED);
GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED);
@@ -386,9 +417,9 @@ graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
if (in_desc == nullptr) {
continue;
}
auto in_shape = in_desc->GetShape().GetDims();
auto in_shape = in_desc->MutableShape().GetDims();
auto in_dtype = in_desc->GetDataType();
auto peer_out_shape = peer_out_desc->GetShape().GetDims();
auto peer_out_shape = peer_out_desc->MutableShape().GetDims();
auto peer_out_dtype = peer_out_desc->GetDataType();
if (peer_out_dtype != in_dtype) {
GELOGW(
@@ -407,13 +438,15 @@ graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) {
}
// refresh current node input desc
in_desc->SetOriginShape(peer_out_desc->GetOriginShape());
in_desc->SetShape(peer_out_desc->GetShape());
in_desc->SetShape(peer_out_desc->MutableShape());
in_desc->SetDataType(peer_out_desc->GetDataType());
in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType());
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)peer_out_desc->GetShapeRange(shape_range);
in_desc->SetShapeRange(shape_range);
ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->GetShape().GetDims().size()));
if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) {
std::vector<std::pair<int64_t, int64_t>> shape_range;
(void)peer_out_desc->GetShapeRange(shape_range);
in_desc->SetShapeRange(shape_range);
}
ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->MutableShape().GetDims().size()));
}
return GRAPH_SUCCESS;
}
@@ -432,25 +465,19 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
if (op_desc->GetInputsSize() != 0) {
std::string input_desc_str = "input shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
for (int64_t dim : input_desc->GetShape().GetDims()) {
input_desc_str += std::to_string(dim) + " ";
}
input_desc_str += "]";
input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" +
TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " ";
SerialShapeAndDtype(input_desc, false, input_desc_str);
}
str += input_desc_str;

input_desc_str = "input origin shape: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
input_desc_str += "[";
for (int64_t dim : input_desc->GetOriginShape().GetDims()) {
input_desc_str += std::to_string(dim) + " ";
}
input_desc_str += "]";
input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " ";
SerialShapeAndDtype(input_desc, true, input_desc_str);
}
str += input_desc_str;
input_desc_str = "input shape range: ";
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) {
SerialShapeRange(input_desc, input_desc_str);
}
str += input_desc_str;
}
@@ -461,13 +488,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
if (output_desc == nullptr) {
continue;
}
output_desc_str += "[";
for (int64_t dim : output_desc->GetShape().GetDims()) {
output_desc_str += std::to_string(dim) + " ";
}
output_desc_str += "]";
output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" +
TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " ";
SerialShapeAndDtype(output_desc, false, output_desc_str);
}
str += output_desc_str;

@@ -476,13 +497,13 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str
if (output_desc == nullptr) {
continue;
}
output_desc_str += "[";
for (int64_t dim : output_desc->GetOriginShape().GetDims()) {
output_desc_str += std::to_string(dim) + " ";
}
output_desc_str += "]";
output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" +
TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " ";
SerialShapeAndDtype(output_desc, true, output_desc_str);
}
str += output_desc_str;
output_desc_str = "output shape range: ";
for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) {
SerialShapeRange(output_desc, output_desc_str);
}
str += output_desc_str;
}
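
For illustration, with hypothetical values: a tensor description whose current shape is (1, 224) with data type DT_FLOAT and format NCHW is rendered by SerialShapeAndDtype as "[1 224 ]:DT_FLOAT:NCHW ", and a shape range of {1,16} and {224,224} is rendered by SerialShapeRange as "[{1,16},{224,224},] ".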


+ 0
- 6
src/common/graph/stub/Makefile View File

@@ -1,6 +0,0 @@
inc_path := $(shell pwd)/metadef/inc/external/
out_path := $(shell pwd)/out/graph/lib64/stub/
stub_path := $(shell pwd)/metadef/graph/stub/

mkdir_stub := $(shell mkdir -p $(out_path))
graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path))

+ 0
- 578
src/common/graph/stub/gen_stubapi.py View File

@@ -1,578 +0,0 @@
import os
import re
import sys
import logging

logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s',
level=logging.INFO)

"""
this attr is used for symbol table visible
"""
GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY'

"""
generate stub func body by return type
"""
RETURN_STATEMENTS = {
'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n '
' << "environment variables and compilation options to make sure you use the correct library."\n'
' << std::endl;\n'
' return ACL_ERROR_COMPILING_STUB_MODE;',
'Status': ' return SUCCESS;',
'Graph': ' return Graph();',
'Graph&': ' return *this;',
'Format': ' return Format();',
'Format&': ' return *this;',
'Shape': ' return Shape();',
'Shape&': ' return *this;',
'TensorDesc': ' return TensorDesc();',
'TensorDesc&': ' return *this;',
'Tensor': ' return Tensor();',
'Tensor&': ' return *this;',
'Operator': ' return Operator();',
'Operator&': ' return *this;',
'Ptr': ' return nullptr;',
'std::string': ' return "";',
'std::string&': ' return "";',
'string': ' return "";',
'int': ' return 0;',
'DataType': ' return DT_FLOAT;',
'InferenceContextPtr': ' return nullptr;',
'SubgraphBuilder': ' return nullptr;',
'OperatorImplPtr': ' return nullptr;',
'OutHandler': ' return nullptr;',
'std::vector<std::string>': ' return {};',
'std::vector<int64_t>': ' return {};',
'std::map': ' return {};',
'uint32_t': ' return 0;',
'int64_t': ' return 0;',
'uint64_t': ' return 0;',
'size_t': ' return 0;',
'float': ' return 0.0f;',
'bool': ' return false;',
}

"""
max code len per line in Huawei software programming specifications
"""
max_code_len_per_line = 100

"""
white_list_for_debug and include_dir_key_words determine
which header files to generate cc files from
when DEBUG is on
"""
white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", "inference_context.h",
"ge_ir_build.h", "ge_api.h", "ascend_string.h", "gnode.h"]
include_dir_key_words = ["ge", "graph"]
DEBUG = True


def need_generate_func(func_line):
"""
:param func_line:
:return:
"""
if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \
or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"):
return False
return True


def file_endswith_white_list_suffix(file):
"""
:param file:
:return:
"""
if DEBUG:
for suffix in white_list_for_debug:
if file.endswith(suffix):
return True
return False
else:
return True


"""
belows are patterns used for analyse .h file
"""
# pattern function
pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after
([a-zA-Z~_] # void int likely
.*
[)] #we find )
(?!.*{) # we do not want the case int abc() const
.*)
(;.*) #we want to find ; and after for we will replace these later
\n$
""", re.VERBOSE | re.MULTILINE | re.DOTALL)

# pattern comment
pattern_comment = re.compile(r'^\s*//')
pattern_comment_2_start = re.compile(r'^\s*/[*]')
pattern_comment_2_end = re.compile(r'[*]/\s*$')
# pattern define
pattern_define = re.compile(r'^\s*#define')
pattern_define_return = re.compile(r'\\\s*$')
# blank line
pattern_blank_line = re.compile(r'^\s*$')
# virtual,explicit,friend,static
pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)')
# lead space
pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]')
# functions will have patterns such as func ( or func(
# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist
# format like :"operator = ()"
pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]')
# template
pattern_template = re.compile(r'^\s*template')
pattern_template_end = re.compile(r'>\s*$')
# namespace
pattern_namespace = re.compile(r'namespace.*{')
# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with
pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR)
# {}
pattern_start = re.compile('{')
pattern_end = re.compile('}')

line_index = 0


class H2CC(object):
def __init__(self, input_file, output_file, shared_includes_content):
"""
:param input_file:
:param output_file:
:param shared_includes_content:
"""
self.input_file = input_file
self.output_file = output_file
self.shared_includes_content = shared_includes_content
self.line_index = 0
self.input_fd = open(self.input_file, 'r')
self.input_content = self.input_fd.readlines()
self.output_fd = open(self.output_file, 'w')

# The state may be normal_now(in the middle of {}),class_now,namespace_now
self.stack = []
self.stack_class = []
self.stack_template = []
# record funcs generated by h2cc func
self.func_list_exist = []

def __del__(self):
self.input_fd.close()
self.output_fd.close()
del self.stack
del self.stack_class
del self.stack_template
del self.func_list_exist

def just_skip(self):
# skip blank line or comment
if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search(
self.input_content[self.line_index]): # /n or comment using //
self.line_index += 1
if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /*
while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */
self.line_index += 1
self.line_index += 1
# skip define
if pattern_define.search(self.input_content[self.line_index]):
while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search(
self.input_content[self.line_index]):
self.line_index += 1
self.line_index += 1

def write_inc_content(self):
for shared_include_content in self.shared_includes_content:
self.output_fd.write(shared_include_content)

def h2cc(self):
"""
:return:
"""
logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file)
global pattern_comment
global pattern_comment_2_start
global pattern_comment_2_end
global pattern_blank_line
global pattern_func
global pattern_keyword
global pattern_leading_space
global pattern_func_name
global pattern_template
global pattern_template_end
global pattern_namespace
global pattern_class
global pattern_start
global pattern_end
global line_index
# write inc content
self.write_inc_content()
# core processing cycle, process the input .h file by line
while self.line_index < len(self.input_content):
# handle comment and blank line
self.just_skip()

# match namespace
self.handle_namespace()

# match template
template_string = self.handle_template()
# match class
line = self.input_content[self.line_index]
match_class = pattern_class.search(line)
match_start = pattern_start.search(line)
handle_class_result = self.handle_class(template_string, line, match_start, match_class)
if handle_class_result == "continue":
continue

# match "}"
handle_stack_result = self.handle_stack(match_start)
if handle_stack_result == "continue":
continue
# handle func
handle_func1_result, line, start_i = self.handle_func1(line)
if handle_func1_result == "continue":
continue

# here means func is found
# delete key word
line = pattern_keyword.sub('', line)
logging.info("line[%s]", line)

# Class member function
# if friend we will not add class name
friend_match = re.search('friend ', line)
if len(self.stack_class) > 0 and not friend_match:
line, func_name = self.handle_class_member_func(line, template_string)
# Normal functions
else:
line, func_name = self.handle_normal_func(line, template_string)

need_generate = need_generate_func(line)
# func body
line += self.implement_function(line)
# comment
line = self.gen_comment(start_i) + line
# write to out file
self.write_func_content(line, func_name, need_generate)
# next loop
self.line_index += 1

logging.info('Added %s functions', len(self.func_list_exist))
logging.info('Successfully converted,please see ' + self.output_file)

def handle_func1(self, line):
"""
:param line:
:return:
"""
find1 = re.search('[(]', line)
if not find1:
self.line_index += 1
return "continue", line, None
find2 = re.search('[)]', line)
start_i = self.line_index
space_match = pattern_leading_space.search(line)
# deal with
# int abc(int a,
# int b)
if find1 and (not find2):
self.line_index += 1
line2 = self.input_content[self.line_index]
if space_match:
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2
while self.line_index < len(self.input_content) and (not re.search('[)]', line2)):
self.line_index += 1
line2 = self.input_content[self.line_index]
line2 = re.sub('^' + space_match.group(1), '', line2)
line += line2

match_start = pattern_start.search(self.input_content[self.line_index])
match_end = pattern_end.search(self.input_content[self.line_index])
if match_start: # like ) { or ) {} in the last line
if not match_end:
self.stack.append('normal_now')
ii = start_i
while ii <= self.line_index:
ii += 1
self.line_index += 1
return "continue", line, start_i
logging.info("line[%s]", line)
# ' int abc();'->'int abc()'
(line, match) = pattern_func.subn(r'\2\n', line)
logging.info("line[%s]", line)
# deal with case:
# 'int \n abc(int a, int b)'
if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]):
line = self.input_content[start_i - 1] + line
line = line.lstrip()
if not match:
self.line_index += 1
return "continue", line, start_i
return "pass", line, start_i

def handle_stack(self, match_start):
"""
:param match_start:
:return:
"""
line = self.input_content[self.line_index]
match_end = pattern_end.search(line)
if match_start:
self.stack.append('normal_now')
if match_end:
top_status = self.stack.pop()
if top_status == 'namespace_now':
self.output_fd.write(line + '\n')
elif top_status == 'class_now':
self.stack_class.pop()
self.stack_template.pop()
if match_start or match_end:
self.line_index += 1
return "continue"

if len(self.stack) > 0 and self.stack[-1] == 'normal_now':
self.line_index += 1
return "continue"
return "pass"

def handle_class(self, template_string, line, match_start, match_class):
"""
:param template_string:
:param line:
:param match_start:
:param match_class:
:return:
"""
if match_class: # we face a class
self.stack_template.append(template_string)
self.stack.append('class_now')
class_name = match_class.group(3)

# class template specializations: class A<u,Node<u> >
if '<' in class_name:
k = line.index('<')
fit = 1
for ii in range(k + 1, len(line)):
if line[ii] == '<':
fit += 1
if line[ii] == '>':
fit -= 1
if fit == 0:
break
class_name += line[k + 1:ii + 1]
logging.info('class_name[%s]', class_name)
self.stack_class.append(class_name)
while not match_start:
self.line_index += 1
line = self.input_content[self.line_index]
match_start = pattern_start.search(line)
self.line_index += 1
return "continue"
return "pass"

def handle_template(self):
line = self.input_content[self.line_index]
match_template = pattern_template.search(line)
template_string = ''
if match_template:
match_template_end = pattern_template_end.search(line)
template_string = line
while not match_template_end:
self.line_index += 1
line = self.input_content[self.line_index]
template_string += line
match_template_end = pattern_template_end.search(line)
self.line_index += 1
return template_string

def handle_namespace(self):
line = self.input_content[self.line_index]
match_namespace = pattern_namespace.search(line)
if match_namespace: # we face namespace
self.output_fd.write(line + '\n')
self.stack.append('namespace_now')
self.line_index += 1

def handle_normal_func(self, line, template_string):
template_line = ''
self.stack_template.append(template_string)
if self.stack_template[-1] != '':
template_line = re.sub(r'\s*template', 'template', self.stack_template[-1])
# change '< class T = a, class U = A(3)>' to '<class T, class U>'
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = re.sub(r'\s*=.*,', ',', line)
line = re.sub(r'\s*=.*\)', ')', line)
line = template_line + line
self.stack_template.pop()
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name

def handle_class_member_func(self, line, template_string):
template_line = ''
x = ''
if template_string != '':
template_string = re.sub(r'\s*template', 'template', template_string)
template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string)
template_string = re.sub(r'\s*=.*,', ',', template_string)
template_string = re.sub(r'\s*=.*', '', template_string)
if self.stack_template[-1] != '':
if not (re.search(r'<\s*>', self.stack_template[-1])):
template_line = re.sub(r'^\s*template', 'template', self.stack_template[-1])
if not (re.search(r'<.*>', self.stack_class[-1])):
# for x we get like template<class T, typename U> -> <T,U>
x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U>
x = re.sub(r'\n', '', x)
x = re.sub(r'\s*=.*,', ',', x)
x = re.sub(r'\s*=.*\>', '>', x)
x = x.rstrip() # remove \n
x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '',
x) # remove class,typename -> <T, U>
x = re.sub(r'<\s+', '<', x)
x = re.sub(r'\s+>', '>', x)
x = re.sub(r'\s+,', ',', x)
x = re.sub(r',\s+', ', ', x)
line = re.sub(r'\s*=\s+0', '', line)
line = re.sub(r'\s*=\s+.*,', ',', line)
line = re.sub(r'\s*=\s+.*\)', ')', line)
logging.info("x[%s]\nline[%s]", x, line)
# if the function is long, void ABC::foo()
# breaks into two lines void ABC::\n foo()
temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1)
if len(temp_line) > max_code_len_per_line:
line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1)
else:
line = temp_line
logging.info("line[%s]", line)
# add template as the above if there is one
template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line)
template_line = re.sub(r'\s*=.*,', ',', template_line)
template_line = re.sub(r'\s*=.*', '', template_line)
line = template_line + template_string + line
func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group()
logging.info("line[%s]", line)
logging.info("func_name[%s]", func_name)
return line, func_name

def write_func_content(self, content, func_name, need_generate):
if not (func_name in self.func_list_exist) and need_generate:
self.output_fd.write(content)
self.func_list_exist.append(func_name)
logging.info('add func:[%s]', func_name)

def gen_comment(self, start_i):
comment_line = ''
# Function comments are on top of function declarations, copy them over
k = start_i - 1 # one line before this func start
if pattern_template.search(self.input_content[k]):
k -= 1
if pattern_comment_2_end.search(self.input_content[k]):
comment_line = self.input_content[k].lstrip()
while not pattern_comment_2_start.search(self.input_content[k]):
k -= 1
comment_line = self.input_content[k].lstrip() + comment_line
else:
for j in range(k, 0, -1):
c_line = self.input_content[j]
if pattern_comment.search(c_line):
c_line = re.sub(r'\s*//', '//', c_line)
comment_line = c_line + comment_line
else:
break
return comment_line

@staticmethod
def implement_function(func):
function_def = ''
function_def += '{\n'

all_items = func.split()
start = 0
return_type = all_items[start]
if return_type == "const":
start += 1
return_type = all_items[start]
if return_type.startswith(('std::map', 'std::set', 'std::vector')):
return_type = "std::map"
if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')):
return_type = "Ptr"
if len(all_items) > start + 1 and all_items[start + 1].startswith('&'):
return_type += "&"
if RETURN_STATEMENTS.__contains__(return_type):
function_def += RETURN_STATEMENTS[return_type]
else:
logging.warning("Unhandled return type[%s]", return_type)

function_def += '\n'
function_def += '}\n'
function_def += '\n'
return function_def


def collect_header_files(path):
"""
:param path:
:return:
"""
header_files = []
shared_includes_content = []
for root, dirs, files in os.walk(path):
files.sort()
for file in files:
if file.find("git") >= 0:
continue
if not file.endswith('.h'):
continue
file_path = os.path.join(root, file)
file_path = file_path.replace('\\', '/')
header_files.append(file_path)
include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:])
shared_includes_content.append(include_str)
# for acl error code
shared_includes_content.append('#include <iostream>\n')
shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n')
return header_files, shared_includes_content


def generate_stub_file(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
target_header_files, shared_includes_content = collect_header_files(inc_dir)
for header_file in target_header_files:
if not file_endswith_white_list_suffix(header_file):
continue
cc_file = re.sub('.h*$', '.cc', header_file)
h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content)
h_2_cc.h2cc()


def gen_code(inc_dir, out_cc_dir):
"""
:param inc_dir:
:param out_cc_dir:
:return:
"""
if not inc_dir.endswith('/'):
inc_dir += '/'
if not out_cc_dir.endswith('/'):
out_cc_dir += '/'
for include_dir_key_word in include_dir_key_words:
generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir)


if __name__ == '__main__':
inc_dir = sys.argv[1]
out_cc_dir = sys.argv[2]
gen_code(inc_dir, out_cc_dir)

+ 10
- 14
src/common/graph/tensor.cc View File

@@ -178,18 +178,16 @@ int64_t Shape::GetShapeSize() const {
return 0;
}

TensorDesc::TensorDesc() {
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
}
TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); }

TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) {
impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665
impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt);
SetRealDimCnt(shape.GetDimNum());
}

TensorDesc::TensorDesc(const TensorDesc &desc) {
// Copy
impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665
impl = ComGraphMakeShared<TensorDescImpl>();
if (desc.impl != nullptr && impl != nullptr) {
*impl = *desc.impl;
}
@@ -360,9 +358,7 @@ void TensorDesc::SetName(const std::string &name) {

Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); }

Tensor::Tensor(const TensorDesc &tensor_desc) {
impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665
}
Tensor::Tensor(const TensorDesc &tensor_desc) { impl = ComGraphMakeShared<TensorImpl>(tensor_desc); }

Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) {
uint64_t shape_size = tensor_desc.GetShape().GetShapeSize();
@@ -384,7 +380,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data)
}
}
}
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data);
}

Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) {
@@ -406,7 +402,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size)
}
}

impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665
impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size);
}

Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) {
@@ -429,7 +425,7 @@ Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) {
}
}
}
impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665
impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data));
}

TensorDesc Tensor::GetTensorDesc() const {
@@ -643,7 +639,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_
GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone());
}
return ge_tensor;
}
@@ -659,7 +655,7 @@ Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) {
ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor);
}
return ge_tensor;
}
@@ -667,7 +663,7 @@ ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) {
GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) {
GeTensorPtr ge_tensor;
if (tensor.impl != nullptr) {
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665
ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor);
}
return ge_tensor;
}


+ 294
- 39
src/common/graph/utils/graph_utils.cc View File

@@ -58,8 +58,10 @@ namespace {
const int32_t kBaseOfIntegerValue = 10;
#ifdef FMK_SUPPORT_DUMP
const char *const kDumpGeGraph = "DUMP_GE_GRAPH";
const int kDumpGraphIndexWidth = 5;
const int kDumpGraphIndexWidth = 8;
#endif

const char *const kDumpGraphPath = "DUMP_GRAPH_PATH";
const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL";
const char *const kDumpStrBuild = "Build";
const char *const kDumpStrPartition = "partition";
@@ -588,6 +590,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
}

std::stringstream stream_file_name;
char *dump_graph_path = std::getenv(kDumpGraphPath);
if (dump_graph_path != nullptr) {
std::string dump_graph_path_str(dump_graph_path);
stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/");
}
stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index;
stream_file_name << "_" << suffix << ".txt";
std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name;
@@ -598,7 +605,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
Buffer buffer;
const int64_t kDumpLevel =
(dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP;
model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL);
model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL && !is_always_dump);

// Write file
ge::proto::ModelDef ge_proto;
@@ -620,6 +627,54 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons
#endif
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGrph(const ge::ComputeGraphPtr &graph,
const std::string &path,
const std::string &suffix) {
// file name
static std::atomic_long atomic_file_index(0);
auto file_index = atomic_file_index.fetch_add(1);
GELOGD("Start to dump om txt: %ld", file_index);

thread_local long max_dump_file_num = 0;
if (max_dump_file_num == 0) {
string opt = "0";
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dump_file_num != 0 && file_index > max_dump_file_num) {
GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num);
return;
}

std::stringstream stream_file_name;
stream_file_name << path.c_str() << "/ge_proto_" << std::setw(5) << std::setfill('0') << file_index;
stream_file_name << "_" << suffix << ".txt";
std::string proto_file = stream_file_name.str();

// Create buffer
ge::Model model("", "");
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(graph)));
Buffer buffer;
const int64_t kDumpLevel = ge::OnnxUtils::NO_DUMP;
model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL);

// Write file
ge::proto::ModelDef ge_proto;
if (buffer.GetData() != nullptr) {
std::string str(reinterpret_cast<const char *>(buffer.GetData()), buffer.GetSize());
if (!ge_proto.ParseFromString(str)) {
GELOGE(GRAPH_FAILED, "parse from string failed.");
return;
}
char real_path[PATH_MAX] = {0x00};
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= PATH_MAX, return, "file path is too long!");
GE_IF_BOOL_EXEC(realpath(proto_file.c_str(), real_path) == nullptr,
GELOGI("file %s does not exist, it will be created.", proto_file.c_str()));

GraphUtils::WriteProtoToTextFile(ge_proto, real_path);
}
}
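
With this change DumpGEGraph honours the DUMP_GRAPH_PATH environment variable as an output directory for the dumped ge_proto_*.txt files (falling back to the current working directory when it is unset or empty), while the new DumpGEGrph variant takes the target directory explicitly through its path argument.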

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file,
ge::ComputeGraph &compute_graph) {
ge::proto::ModelDef model_def;
@@ -722,7 +777,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText
}
GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed");
#else
GELOGW("need to define FMK_SUPPORT_DUMP for dump graph.");
GELOGW("Need to define FMK_SUPPORT_DUMP for dump graph.");
#endif
}

@@ -789,6 +844,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn
}

std::stringstream stream_file_name;
char *dump_graph_path = std::getenv(kDumpGraphPath);
if (dump_graph_path != nullptr) {
std::string dump_graph_path_str(dump_graph_path);
stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/");
}
stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index;
stream_file_name << "_graph_" << compute_graph.GetGraphID();
stream_file_name << "_" << suffix << ".pbtxt";
@@ -822,6 +882,66 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn
#endif
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGrphToOnnx(const ge::ComputeGraph &compute_graph,
const std::string &path,
const std::string &suffix) {
// 1.Get ge::onnx::ModelProto from ge::Model
ge::Model model("GE", "");
std::shared_ptr<ge::ComputeGraph> compute_graph_ptr = ComGraphMakeShared<ge::ComputeGraph>(compute_graph);
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(compute_graph_ptr)));
onnx::ModelProto model_proto;
if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) {
GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed.");
return;
}

// 2.Set file name
static std::atomic_long atomic_file_index(0);
auto file_index = atomic_file_index.fetch_add(1);
GELOGD("Start to dump ge onnx file: %ld", file_index);

thread_local long max_dump_file_num = 0;
if (max_dump_file_num == 0) {
string opt = "0";
(void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt);
max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue);
}
if (max_dump_file_num != 0 && file_index > max_dump_file_num) {
GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num);
return;
}

std::stringstream stream_file_name;
stream_file_name << path.c_str() << "/ge_onnx_" << std::setw(5) << std::setfill('0') << file_index;
stream_file_name << "_graph_" << compute_graph.GetGraphID();
stream_file_name << "_" << suffix << ".pbtxt";
std::string proto_file = stream_file_name.str();
if ((proto_file.length()) >= NAME_MAX) {
GELOGE(GRAPH_FAILED, "File name is too longer!");
return;
}
std::unique_ptr<char[]> real_path(new (std::nothrow) char[PATH_MAX]{0});
if (real_path == nullptr) {
GELOGE(GRAPH_FAILED, "New real_path failed.");
return;
}
/// Returning nullptr means one of 3 cases as follows:
/// a. the path is PATH_MAX chars or more
/// b. the file does not exist
/// c. the path has no permissions
/// The last two cases are distinguished in WriteProtoToTextFile when it calls open()
if (realpath(proto_file.c_str(), real_path.get()) == nullptr) {
// For case a
if (errno == ENAMETOOLONG) {
GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more.");
return;
}
}

// 3. Serialize to file in current path
GraphUtils::WriteProtoToTextFile(model_proto, real_path.get());
}

GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file,
ge::ComputeGraph &compute_graph) {
if (file == nullptr) {
@@ -1419,7 +1539,7 @@ GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix,
return nullptr;
}

op_desc->SetName(prefix + n->GetName());
op_desc->SetName(n->GetName() + prefix);
NodePtr node = new_graph->AddNode(op_desc);
GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str());
all_new_nodes[node->GetName()] = node;
@@ -1445,6 +1565,17 @@ GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix,
return nullptr;
}
}

// copy info of output nodes from old graph to new graph.
std::vector<std::pair<NodePtr, int32_t>> out_nodes_info = graph->GetGraphOutNodesInfo();
std::vector<std::pair<NodePtr, int32_t>> new_out_nodes_info;
for (const auto &info : out_nodes_info) {
auto it = all_new_nodes.find(info.first->GetName());
if (it != all_new_nodes.end()) {
new_out_nodes_info.emplace_back(it->second, info.second);
}
}
new_graph->SetGraphOutNodesInfo(new_out_nodes_info);
return new_graph;
}

@@ -1501,7 +1632,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref
return GRAPH_FAILED;
}

auto it = all_nodes.find(prefix + node->GetName());
auto it = all_nodes.find(node->GetName() + prefix);
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str());
return GRAPH_FAILED;
@@ -1517,7 +1648,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref
}
GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null");

it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName());
it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix);
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str());
return GRAPH_FAILED;
@@ -1535,7 +1666,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref
GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str());
GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null");

it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName());
it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix);
if (it == all_nodes.end()) {
GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str());
return GRAPH_FAILED;
@@ -1736,7 +1867,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node,
if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) {
ComputeGraphPtr graph = node->GetOwnerComputeGraph();
GE_CHECK_NOTNULL(graph);
ge::NodePtr next_node = graph->FindNode(next_name);
ge::NodePtr next_node = FindNodeFromAllNodes(graph, next_name);
GE_CHECK_NOTNULL(next_node);
// NextIteration has and only has one output
peer_out_anchor = next_node->GetOutDataAnchor(0);
@@ -2332,15 +2463,12 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map<uint
///
ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) {
owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_));
if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) {
if (owner_graph_ == nullptr) {
error_code = GRAPH_FAILED;
error_msg = "graph / parent_node is NULL.";
error_msg = "graph is NULL.";
return nullptr;
}

owner_graph_->SetParentNode(parent_node_);
owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph());

BuildNodes(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
@@ -2361,37 +2489,27 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string
return nullptr;
}

AddRetValNodes(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
if (retval_flag_) {
AddRetValNodes(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
}
BuildGraphTargets(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
}
} else {
AddNetOutputNode(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
}
}

BuildGraphTargets(error_code, error_msg);
PostProcess(error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return nullptr;
}

// ATTR_NAME_SESSION_GRAPH_ID
std::string graph_id;
if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) {
error_code = GRAPH_FAILED;
error_msg = "Get attr session_graph_id failed.";
return nullptr;
}
if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) {
error_code = GRAPH_FAILED;
error_msg = "Set attr session_graph_id failed.";
return nullptr;
}

// refresh node name
for (const NodePtr &node : owner_graph_->GetDirectNode()) {
if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) {
continue;
}
node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName());
}

return owner_graph_;
}

@@ -2586,7 +2704,144 @@ void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::strin
target_nodes.emplace_back(target_iter->second);
}
owner_graph_->SetGraphTargetNodesInfo(target_nodes);
return;
}

///
/// @brief Add NetOutput node
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) {
std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT;
OpDescPtr net_output_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT));
if (net_output_desc == nullptr) {
error_code = GRAPH_FAILED;
error_msg = log_msg + " failed: op_desc is NULL.";
return;
}

size_t output_num = graph_outputs_.size();
std::vector<OutDataAnchorPtr> peer_out_anchors(output_num);
for (size_t i = 0; i < output_num; i++) {
int32_t index = graph_outputs_[i].second;
auto out_iter = node_names_.find(graph_outputs_[i].first);
if (out_iter == node_names_.end()) {
error_code = GRAPH_FAILED;
error_msg = "AddNetOutputNode failed: node " + graph_outputs_[i].first + " not exist in graph.";
return;
}
NodePtr node = out_iter->second;
if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
error_code = GRAPH_FAILED;
error_msg = "AddNetOutputNode failed: node is NULL.";
return;
}

ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index);
uint32_t update_index = i;
auto iter = output_mapping_.find(i);
if (iter != output_mapping_.end()) {
update_index = iter->second;
}
if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, update_index)) {
error_code = GRAPH_FAILED;
error_msg = "AddNetOutputNode failed: set attr PARENT_NODE_INDEX failed.";
return;
}
if (net_output_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) {
error_code = GRAPH_FAILED;
error_msg = "AddNetOutputNode failed: add input_desc ailed.";
return;
}
peer_out_anchors[i] = node->GetOutDataAnchor(index);
}

BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, error_code, error_msg);
if (error_code != GRAPH_SUCCESS) {
return;
}

GELOGD("%s succ.", log_msg.c_str());
}
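The only non-obvious step above is the parent-index remapping: output i of the graph is tagged with ATTR_NAME_PARENT_NODE_INDEX, and output_mapping_ may redirect it to a different parent slot. A hypothetical standalone form of that lookup:

#include <cstdint>
#include <map>

static uint32_t ResolveParentIndex(uint32_t output_index, const std::map<uint32_t, uint32_t> &output_mapping) {
  auto iter = output_mapping.find(output_index);
  return (iter == output_mapping.end()) ? output_index : iter->second;  // identity unless remapped
}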

///
/// @brief Build NetOutput node with data & ctrl edges
/// @param [in] net_output_desc
/// @param [in] peer_out_anchors
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void CompleteGraphBuilder::BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc,
const std::vector<OutDataAnchorPtr> &peer_out_anchors,
graphStatus &error_code, std::string &error_msg) {
std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT;
NodePtr net_output = owner_graph_->AddNode(net_output_desc);
if (net_output == nullptr) {
error_code = GRAPH_FAILED;
error_msg = log_msg + " failed: add NetOutput node failed.";
return;
}

size_t output_num = graph_outputs_.size();
for (size_t i = 0; i < output_num; i++) {
if (GraphUtils::AddEdge(peer_out_anchors[i], net_output->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
error_code = GRAPH_FAILED;
error_msg = "AddNetOutputNode failed: add data-edge " + peer_out_anchors[i]->GetOwnerNode()->GetName() + ":" +
std::to_string(peer_out_anchors[i]->GetIdx()) + "->" + NODE_NAME_NET_OUTPUT + ":" +
std::to_string(i) + " failed.";
return;
}
}
for (const std::string &target_name : graph_targets_) {
auto target_iter = node_names_.find(target_name);
if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) {
error_code = GRAPH_FAILED;
error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph.";
return;
}
const auto &target_node = target_iter->second;
if (GraphUtils::AddEdge(target_node->GetOutControlAnchor(), net_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
error_code = GRAPH_FAILED;
error_msg =
"AddNetOutputNode failed: add ctrl-edge " + target_node->GetName() + "->" + NODE_NAME_NET_OUTPUT + " failed.";
return;
}
}
}

///
/// @brief process after build
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void CompleteGraphBuilder::PostProcess(graphStatus &error_code, std::string &error_msg) {
if (parent_node_ != nullptr) {
owner_graph_->SetParentNode(parent_node_);
owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph());
// ATTR_NAME_SESSION_GRAPH_ID
std::string graph_id;
if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) {
error_code = GRAPH_FAILED;
error_msg = "Get attr session_graph_id failed.";
return;
}
if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) {
error_code = GRAPH_FAILED;
error_msg = "Set attr session_graph_id failed.";
return;
}
}

// refresh node name
for (const NodePtr &node : owner_graph_->GetDirectNode()) {
if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) {
continue;
}
node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName());
}
}

///


+ 37
- 0
src/common/graph/utils/node_utils.cc

@@ -391,7 +391,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInpu
GELOGE(GRAPH_FAILED, "Add input desc failed");
return GRAPH_FAILED;
}
}

for (size_t i = node->in_data_anchors_.size(); i < num; ++i) {
auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
if (anchor == nullptr) {
GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
@@ -444,7 +446,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutp
GELOGE(GRAPH_FAILED, "Add output desc failed");
return GRAPH_FAILED;
}
}

for (size_t i = node->out_data_anchors_.size(); i < num; ++i) {
auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
if (anchor == nullptr) {
GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
@@ -644,6 +648,20 @@ std::string NodeUtils::GetNodeType(const Node &node) {

std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); }

std::vector<ComputeGraphPtr> NodeUtils::GetAllSubgraphs(const Node &node) {
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to get op desc from node %s ", node.GetName().c_str());
return {};
}
auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
if (root_graph == nullptr) {
GELOGE(GRAPH_FAILED, "Failed to find root graph from node %s ", node.GetName().c_str());
return {};
}
return root_graph->GetAllSubgraphs();
}
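Hedged usage sketch for the new NodeUtils::GetAllSubgraphs: it returns every subgraph registered under the root graph that owns the node, so callers can iterate them directly (logging only, the node variable is illustrative):

for (const auto &subgraph : ge::NodeUtils::GetAllSubgraphs(*node)) {
  if (subgraph != nullptr) {
    GELOGI("Subgraph under root graph: %s", subgraph->GetName().c_str());
  }
}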

ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
auto op_desc = node.GetOpDesc();
if (op_desc == nullptr) {
@@ -1002,4 +1020,23 @@ vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByInd
}

ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); }

std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) {
NodePtr input_node = node;
while (input_node != nullptr) {
if (input_node->GetType() != DATA) {
return input_node->GetType();
}

auto owner_graph = input_node->GetOwnerComputeGraph();
auto parent_node = owner_graph->GetParentNode();
if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0)) {
return node->GetType(); // not in subgraph or while subgraph.
}

input_node = GetParentInput(input_node);
}

return "";
}
} // namespace ge
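Illustrative caller of GetInConstNodeTypeCrossSubgraph: starting from a Data node inside a subgraph, it reports the type of the real producer outside the subgraph, or an empty string if the chain cannot be followed. The constant type strings checked below are assumptions made for the example:

std::string producer_type = ge::NodeUtils::GetInConstNodeTypeCrossSubgraph(data_node);
if ((producer_type == "Const") || (producer_type == "Constant")) {
  // the value feeding this subgraph input is a compile-time constant
}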

+ 4
- 6
src/common/graph/utils/op_desc_utils.cc

@@ -28,7 +28,6 @@

using std::vector;

/*lint -e512 -e737 -e752*/
namespace ge {
const char OP_DESC_QUANT_PARAMS[] = "quantize_factor";
static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1;
@@ -133,11 +132,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) {
GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr");
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
}

graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) {
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732
return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant));
}

GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) {
@@ -255,7 +254,7 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) {
continue;
}
}
return input_num; // lint !e712
return input_num;
} else {
GE_IF_BOOL_EXEC(
node.GetInDataNodes().size() < GetConstInputs(node).size(),
@@ -360,7 +359,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) {
bool ret = false;
if (index < node.GetAllInDataAnchors().size()) {
if (NodeUtils::IsAnchorStatusSet(node)) {
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712
ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA);
} else {
for (const auto &anchor : node.GetAllInDataAnchors()) {
if (anchor->GetIdx() != static_cast<int>(index)) {
@@ -822,4 +821,3 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgr
return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name);
}
} // namespace ge
/*lint +e512 +e737 +e752*/

+ 20
- 1
src/common/graph/utils/tuning_utils.cc

@@ -17,8 +17,10 @@
#include "graph/tuning_utils.h"
#include "../debug/ge_util.h"
#include "../debug/ge_op_types.h"
#include "framework/common/scope_guard.h"

namespace ge {
namespace {
const std::string peer_node_name_attr = "_peerNodeName";
const std::string parent_node_name_attr = "_parentNodeName";
const std::string alias_name_attr = "_aliasName";
@@ -28,6 +30,7 @@ const std::string tuning_subgraph_prefix = "/aicore_subgraph_";
const std::string non_tuning_subgraph_prefix = "/subgraph_";
const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END};
const std::set<std::string> kExeTypes = {DATA, NETOUTPUT};
} // namespace
NodeNametoNodeNameMap TuningUtils::data_2_netoutput_;
NodetoNodeNameMap TuningUtils::data_node_2_netoutput_;
NodetoNodeMap TuningUtils::data_node_2_netoutput_node_;
@@ -116,6 +119,10 @@ graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_
// +---------------+
graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) {
GE_CHECK_NOTNULL(exe_graph);

// clear graph id
GELOGI("TUU:clear [%s] session_graph_id %s", exe_graph->GetName().c_str(),
(AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "") ? "success" : "not success"));
// if not make exe, just dump and return
if (!help_info.exe_flag) {
DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path);
@@ -346,7 +353,9 @@ graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node)
AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr)
? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor())
: Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0));
GE_CHECK_NOTNULL(end_in_anchor);
auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1
GE_CHECK_NOTNULL(src_anchor);
if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s",
GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(),
@@ -447,6 +456,14 @@ graphStatus TuningUtils::HandleEnd(NodePtr &node) {

// part 2
graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) {
std::function<void()> callback = [&]() {
data_2_netoutput_.clear();
data_node_2_netoutput_.clear();
data_node_2_netoutput_node_.clear();
netoutput_nodes_.clear();
merged_graph_nodes_.clear();
};
GE_MAKE_GUARD(release, callback);
// 1. get all subgraph object
std::vector<ComputeGraphPtr> graphs;
// options format like {index:"subgraph_path"}
@@ -666,7 +683,9 @@ graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_
GE_CHECK_NOTNULL(src_anchor);
auto src_node = src_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(src_node);
if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) {
std::string src_node_name = src_node->GetName();
if (src_node_name.find(netoutput_input_name) != src_node_name.npos &&
src_anchor->GetIdx() == parent_node_anchor_index) {
dest_in_anchor = in_anchor;
src_out_anchor = src_anchor;
GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s",


+ 8
- 1
src/ge/CMakeLists.txt

@@ -39,7 +39,7 @@ ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}
# include directories
include_directories(${CMAKE_CURRENT_LIST_DIR})
include_directories(${GE_SOURCE_DIR})
include_directories(${GE_SOURCE_DIR}/src)
include_directories(${GE_SOURCE_DIR}/src/ge)
include_directories(${GE_SOURCE_DIR}/src/ge/analyzer)
include_directories(${GE_SOURCE_DIR}/inc)
include_directories(${GE_SOURCE_DIR}/inc/common/util)
@@ -109,6 +109,8 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/manager/graph_mem_allocator.cc"
"graph/manager/graph_caching_allocator.cc"
"graph/manager/graph_var_manager.cc"
"graph/manager/host_mem_manager.cc"
"graph/manager/memory_api.cc"
"graph/manager/model_manager/event_manager.cc"
"graph/manager/rdma_pool_allocator.cc"
"graph/manager/trans_var_data_utils.cc"
@@ -127,6 +129,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/partition/dynamic_shape_partition.cc"
"graph/partition/engine_place.cc"
"graph/partition/graph_partition.cc"
"graph/partition/stage_partition.cc"
"graph/passes/*.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
@@ -200,6 +203,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"model/ge_root_model.cc"
"omm/csa_interact.cc"
"opskernel_manager/ops_kernel_manager.cc"
"opskernel_manager/ops_kernel_builder_manager.cc"
"session/inner_session.cc"
"session/session_manager.cc"
"single_op/*.cc"
@@ -283,6 +287,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/manager/graph_manager.cc"
"graph/manager/graph_manager_utils.cc"
"graph/manager/graph_mem_allocator.cc"
"graph/manager/host_mem_manager.cc"
"graph/manager/trans_var_data_utils.cc"
"graph/manager/graph_var_manager.cc"
"graph/manager/model_manager/event_manager.cc"
@@ -296,6 +301,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"graph/partition/dynamic_shape_partition.cc"
"graph/partition/engine_place.cc"
"graph/partition/graph_partition.cc"
"graph/partition/stage_partition.cc"
"graph/passes/*.cc"
"graph/preprocess/graph_preprocess.cc"
"graph/preprocess/insert_op/ge_aipp_op.cc"
@@ -349,6 +355,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"model/ge_root_model.cc"
"omm/csa_interact.cc"
"opskernel_manager/ops_kernel_manager.cc"
"opskernel_manager/ops_kernel_builder_manager.cc"
"session/inner_session.cc"
"session/session_manager.cc"
"single_op/*.cc"


+ 28
- 33
src/ge/analyzer/analyzer.cc

@@ -75,9 +75,8 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) {
std::lock_guard<std::recursive_mutex> lg(mutex_);
auto iter = graph_infos_.find(session_id);
if (iter == graph_infos_.end()) {
auto p = new (std::nothrow) GraphInfo();
GE_CHECK_NOTNULL(p);
std::shared_ptr<GraphInfo> graph_info(p);
std::shared_ptr<GraphInfo> graph_info(new (std::nothrow) GraphInfo());
GE_CHECK_NOTNULL(graph_info);
std::map<uint64_t, std::shared_ptr<GraphInfo>> graph_map;
graph_map[graph_id] = graph_info;
graph_info->session_id = session_id;
@@ -86,9 +85,8 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) {
} else {
auto iter1 = (iter->second).find(graph_id);
if (iter1 == (iter->second).end()) {
auto p = new (std::nothrow) GraphInfo();
GE_CHECK_NOTNULL(p);
std::shared_ptr<GraphInfo> graph_info(p);
std::shared_ptr<GraphInfo> graph_info(new (std::nothrow) GraphInfo());
GE_CHECK_NOTNULL(graph_info);
graph_info->session_id = session_id;
graph_info->graph_id = graph_id;
(iter->second).insert({graph_id, graph_info});
@@ -100,7 +98,14 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) {
}

ge::Status Analyzer::Initialize() {
ClearHistoryFile();
// Initialize file
string real_path = RealPath(kFilePath.c_str());
if (real_path.empty()) {
GELOGE(FAILED, "File path is invalid.");
return FAILED;
}
json_file_name_ = real_path + "/" + kAnalyzeFile;

return SUCCESS;
}

@@ -138,6 +143,7 @@ void Analyzer::DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id) {
if (iter1 == (iter->second).end()) {
GELOGW("Can not find the graph json object by session_id[%lu] and graph_id[%lu]. Do nothing.", session_id,
graph_id);
return;
}
(iter->second).erase(iter1);
}
@@ -174,15 +180,8 @@ ge::Status Analyzer::CreateAnalyzerFile() {
return SUCCESS;
}
GELOGD("start to create analyzer file!");
// Check whether the manifest exists, if not, create it.
string real_path = RealPath(kFilePath.c_str());
if (real_path.empty()) {
GELOGE(FAILED, "File path is invalid.");
return FAILED;
}

std::lock_guard<std::mutex> lg(file_mutex_);
json_file_name_ = real_path + "/" + kAnalyzeFile;
GELOGD("Created analyzer file:[%s]", json_file_name_.c_str());
int fd = open(json_file_name_.c_str(), O_WRONLY | O_CREAT | O_TRUNC, kFileAuthority);
if (fd < 0) {
GELOGE(INTERNAL_ERROR, "Fail to open the file: %s.", json_file_name_.c_str());
@@ -198,25 +197,27 @@ ge::Status Analyzer::CreateAnalyzerFile() {
return SUCCESS;
}

ge::Status Analyzer::SaveAnalyzerDataToFile() {
ge::Status Analyzer::SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id) {
GELOGD("start to save analyze file!");

auto graph_info = GetJsonObject(session_id, graph_id);
GE_CHECK_NOTNULL(graph_info);
if (graph_info->op_info.size() == 0) {
GELOGD("session_id:%lu graph_id:%lu does not owner op info, break it!", session_id, graph_id);
return SUCCESS;
}
std::lock_guard<std::mutex> lg(file_mutex_);
json_file_.open(json_file_name_, std::ios::out);
json_file_.open(json_file_name_, std::ios::app);
if (!json_file_.is_open()) {
GELOGE(FAILED, "analyzer file does not exist[%s]", json_file_name_.c_str());
return PARAM_INVALID;
}

std::lock_guard<std::recursive_mutex> lk(mutex_);
for (auto &ele : graph_infos_) {
for (auto &ele2 : ele.second) {
json jsn;
GraphInfoToJson(jsn, *(ele2.second));
json_file_ << jsn.dump(kJsonDumpLevel) << std::endl;
}
}

json jsn;
GraphInfoToJson(jsn, *graph_info);
json_file_ << jsn.dump(kJsonDumpLevel) << std::endl;
json_file_.close();

return SUCCESS;
}
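The reworked SaveAnalyzerDataToFile serializes only the requested (session_id, graph_id) entry and opens the file in append mode, so each call adds one JSON record instead of rewriting every buffered graph. A hedged sketch of that append pattern with nlohmann::json (file name, fields and dump level are placeholders):

#include <cstdint>
#include <fstream>
#include <string>
#include "nlohmann/json.hpp"

void AppendGraphRecord(const std::string &file_name, uint64_t session_id, uint64_t graph_id) {
  nlohmann::json record;
  record["session_id"] = session_id;            // one record per (session, graph) pair
  record["graph_id"] = graph_id;
  std::ofstream out(file_name, std::ios::app);  // append, keep earlier records intact
  out << record.dump(1) << std::endl;           // indent level 1 stands in for kJsonDumpLevel
}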

@@ -237,13 +238,7 @@ ge::Status Analyzer::DoAnalyze(DataInfo &data_info) {
return FAILED;
}
// create json file
status = CreateAnalyzerFile();
if (status != SUCCESS) {
GELOGE(status, "create analyzer file failed!");
return status;
}
// save data to file
return SaveAnalyzerDataToFile();
return CreateAnalyzerFile();
}

ge::Status Analyzer::SaveOpInfo(ge::OpDescPtr desc, DataInfo &data_info,


+ 8
- 1
src/ge/analyzer/analyzer.h

@@ -156,6 +156,14 @@ class Analyzer {
*/
ge::Status DoAnalyze(analyzer::DataInfo &data_info);

/**
* @ingroup ge
* @brief: Buffer analyzed data and output it to the json file
* @param [in]: session id, graph id
* @return: 0: SUCCESS, other: FAILED
*/
ge::Status SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id);

Analyzer(const Analyzer &) = delete;
Analyzer &operator=(const Analyzer &) = delete;
Analyzer(Analyzer &&) = delete;
@@ -166,7 +174,6 @@ class Analyzer {
void OpInfoToJson(nlohmann::json &j, const analyzer::OpInfo &op_info);
void GraphInfoToJson(nlohmann::json &j, const analyzer::GraphInfo &graph_info);

ge::Status SaveAnalyzerDataToFile();
ge::Status SaveOpInfo(ge::OpDescPtr desc, analyzer::DataInfo &data_info,
std::shared_ptr<analyzer::GraphInfo> graph_info);



+ 18
- 5
src/ge/client/ge_prof.cc

@@ -324,10 +324,17 @@ Status aclgrphProfStop(aclgrphProfConfig *profiler_config) {
return GE_PROF_NOT_INIT;
}

Status ret = ProfStopProfiling(&profiler_config->config);
if (ret != SUCCESS) {
GELOGE(ret, "Stop profiling failed, prof result = %d", ret);
return ret;
for (uint32_t i = 0; i < profiler_config->config.devNums; i++) {
uint64_t data_type_config;
Status status = ProfGetDataTypeConfig(profiler_config->config.devIdList[i], data_type_config);
if (status != SUCCESS) {
GELOGE(status, "Prof get data type config failed, prof result = %d", status);
return status;
}
if (data_type_config != profiler_config->config.dataTypeConfig) {
GELOGE(FAILED, "data type config verify failed");
return FAILED;
}
}

std::vector<string> prof_params;
@@ -344,12 +351,18 @@ Status aclgrphProfStop(aclgrphProfConfig *profiler_config) {
command.module_index = profiler_config->config.dataTypeConfig;
GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(),
prof_params[kDeviceListIndex].c_str(), command.module_index);
ret = graph_loader.CommandHandle(command);
Status ret = graph_loader.CommandHandle(command);
if (ret != SUCCESS) {
GELOGE(ret, "Handle profiling command failed");
return FAILED;
}

ret = ProfStopProfiling(&profiler_config->config);
if (ret != SUCCESS) {
GELOGE(ret, "Stop profiling failed, prof result = %d", ret);
return ret;
}

GELOGI("Successfully execute GraphProfStopProfiling.");
return SUCCESS;
}
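Before tearing profiling down, the reworked aclgrphProfStop now cross-checks every device's active data type configuration against the config object it was asked to stop. A standalone restatement of that guard; the struct and callback below are hypothetical stand-ins for the runtime types used in the hunk:

#include <cstdint>
#include <vector>

struct ProfConfigView {            // hypothetical mirror of the fields used above
  uint32_t devNums;
  std::vector<uint32_t> devIdList;
  uint64_t dataTypeConfig;
};

// get_active_config stands in for ProfGetDataTypeConfig.
bool ConfigMatchesAllDevices(const ProfConfigView &config,
                             bool (*get_active_config)(uint32_t device_id, uint64_t &active)) {
  for (uint32_t i = 0; i < config.devNums; ++i) {
    uint64_t active = 0;
    if (!get_active_config(config.devIdList[i], active)) {
      return false;  // cannot confirm what this device is collecting
    }
    if (active != config.dataTypeConfig) {
      return false;  // the running config is not the one the caller asked to stop
    }
  }
  return true;
}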


+ 2
- 4
src/ge/client/module.mk

@@ -70,10 +70,9 @@ LOCAL_SHARED_LIBRARIES := \
libregister \
libge_compiler \
libge_common \
libmsprof \
stub/libascend_hal
libmsprof

LOCAL_STATIC_LIBRARIES := libmsprofiler

LOCAL_LDFLAGS := -lrt -ldl

@@ -108,7 +107,6 @@ LOCAL_SHARED_LIBRARIES := \
libge_common \
libmsprof

LOCAL_STATIC_LIBRARIES := libmsprofiler

LOCAL_LDFLAGS := -lrt -ldl
LOCAL_CFLAGS += \


+ 20
- 2
src/ge/common/auth/file_saver.cc

@@ -55,9 +55,26 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) {

Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) {
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID);

mmSsize_t write_count;
uint32_t size_2g = ((uint32_t)0x1 << 31);
uint32_t size_1g = ((uint32_t)0x1 << 30);
// Write data
int32_t write_count = mmWrite(fd, const_cast<void *>(data), size);
if (size > size_2g) {
auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data));
while (size > size_1g) {
write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size_1g);
if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) {
GELOGE(FAILED, "Write data failed. mmpa_errorno = %d, %s", write_count, strerror(errno));
return FAILED;
}
size -= size_1g;
seek += size_1g;
}
write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size);
} else {
write_count = mmWrite(fd, const_cast<void *>(data), size);
}

// -1: Failed to write to file; - 2: Illegal parameter
if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) {
GELOGE(FAILED, "Write data failed. mmpa_errorno = %d, %s", write_count, strerror(errno));
@@ -117,6 +134,7 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi
WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break);
// Write partition data
for (const auto &partitionData : partition_datas) {
GELOGI("GC:size[%zu]", partitionData.size);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
WriteData(static_cast<const void *>(partitionData.data), partitionData.size, fd) != SUCCESS, ret = FAILED;
break);


+ 0
- 248
src/ge/common/convert/pb2json.cc

@@ -1,248 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// File: pb2json.h
// Description: This imply file for protobuf message and json interconversion

#include "common/convert/pb2json.h"
#include <set>
#include <string>
#include "securec.h"
#include "framework/common/fmk_types.h"
#include "framework/common/debug/ge_log.h"

using std::set;
using std::string;

namespace ge {
namespace {
const int kSignificantDigits = 10;
}
// JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message,
const set<string> &black_fields, Json &json,
bool enum2str) {
auto descriptor = message.GetDescriptor();
auto reflection = message.GetReflection();
if (descriptor == nullptr || reflection == nullptr) {
return;
}

auto count = descriptor->field_count();

for (auto i = 0; i < count; ++i) {
const auto field = descriptor->field(i);
if (field == nullptr) {
return;
}

// Do not display weight data
if (black_fields.find(field->name()) != black_fields.end()) {
continue;
}

if (field->is_repeated()) {
if (reflection->FieldSize(message, field) > 0) {
RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str);
}
continue;
}

if (!reflection->HasField(message, field)) {
continue;
}

OneField2Json(message, field, reflection, black_fields, json, enum2str);
}
}

void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetMessage(message, field);
if (0 != tmp_message.ByteSize()) {
Message2Json(tmp_message, black_fields, json[field->name()], enum2str);
}
break;
}

case ProtobufFieldDescriptor::TYPE_BOOL:
json[field->name()] = reflection->GetBool(message, field);
break;

case ProtobufFieldDescriptor::TYPE_ENUM: {
auto *enum_value_desc = reflection->GetEnum(message, field);
Enum2Json(enum_value_desc, field, enum2str, json);
break;
}

case ProtobufFieldDescriptor::TYPE_INT32:
case ProtobufFieldDescriptor::TYPE_SINT32:
case ProtobufFieldDescriptor::TYPE_SFIXED32:
json[field->name()] = reflection->GetInt32(message, field);
break;

case ProtobufFieldDescriptor::TYPE_UINT32:
case ProtobufFieldDescriptor::TYPE_FIXED32:
json[field->name()] = reflection->GetUInt32(message, field);
break;

case ProtobufFieldDescriptor::TYPE_INT64:
case ProtobufFieldDescriptor::TYPE_SINT64:
case ProtobufFieldDescriptor::TYPE_SFIXED64:
json[field->name()] = reflection->GetInt64(message, field);
break;

case ProtobufFieldDescriptor::TYPE_UINT64:
case ProtobufFieldDescriptor::TYPE_FIXED64:
json[field->name()] = reflection->GetUInt64(message, field);
break;

case ProtobufFieldDescriptor::TYPE_FLOAT:
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) {
json[field->name()] = str;
} else {
json[field->name()] = reflection->GetFloat(message, field);
}

break;

case ProtobufFieldDescriptor::TYPE_STRING:
json[field->name()] = reflection->GetString(message, field);
break;

case ProtobufFieldDescriptor::TYPE_BYTES: {
string field_name = field->name();
string type_bytes = reflection->GetString(message, field);
json[field_name] = TypeBytes2String(field_name, type_bytes);
break;
}

default:
break;
}
}

string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) {
if (field_name != "offset") {
return type_bytes;
}
string result = "";
for (char temp_value : type_bytes) {
uint8_t *value = 0;
value = reinterpret_cast<uint8_t *>(&temp_value);
char str[kSignificantDigits];
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) {
GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str());
continue;
}
result += str;
}
return result;
}

void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const set<string> &black_fields, Json &json,
bool enum2str) {
if ((field == nullptr) || (reflection == nullptr)) {
Message2Json(message, black_fields, json, enum2str);
return;
}

for (auto i = 0; i < reflection->FieldSize(message, field); ++i) {
Json tmp_json;
switch (field->type()) {
case ProtobufFieldDescriptor::TYPE_MESSAGE: {
const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i);
if (0 != tmp_message.ByteSize()) {
Message2Json(tmp_message, black_fields, tmp_json, enum2str);
}
} break;

case ProtobufFieldDescriptor::TYPE_BOOL:
tmp_json = reflection->GetRepeatedBool(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_ENUM: {
auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i);
RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json);
} break;

case ProtobufFieldDescriptor::TYPE_INT32:
case ProtobufFieldDescriptor::TYPE_SINT32:
case ProtobufFieldDescriptor::TYPE_SFIXED32:
tmp_json = reflection->GetRepeatedInt32(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_UINT32:
case ProtobufFieldDescriptor::TYPE_FIXED32:
tmp_json = reflection->GetRepeatedUInt32(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_INT64:
case ProtobufFieldDescriptor::TYPE_SINT64:
case ProtobufFieldDescriptor::TYPE_SFIXED64:
tmp_json = reflection->GetRepeatedInt64(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_UINT64:
case ProtobufFieldDescriptor::TYPE_FIXED64:
tmp_json = reflection->GetRepeatedUInt64(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_FLOAT:
tmp_json = reflection->GetRepeatedFloat(message, field, i);
break;

case ProtobufFieldDescriptor::TYPE_STRING:
case ProtobufFieldDescriptor::TYPE_BYTES:
tmp_json = reflection->GetRepeatedString(message, field, i);
break;

default:
break;
}
json += tmp_json;
}
}

void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json) {
if (enum_value_desc != nullptr) {
if (field == nullptr) {
return;
}
if (enum2str) {
json[field->name()] = enum_value_desc->name();
} else {
json[field->name()] = enum_value_desc->number();
}
}
}

void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) {
if (enum_value_desc != nullptr) {
if (enum2str) {
json = enum_value_desc->name();
} else {
json = enum_value_desc->number();
}
}
}
} // namespace ge

+ 0
- 68
src/ge/common/convert/pb2json.h

@@ -1,68 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// File: pb2json.h
// Description: This header file for protobuf message and json interconversion

#ifndef GE_COMMON_CONVERT_PB2JSON_H_
#define GE_COMMON_CONVERT_PB2JSON_H_
#include <functional>
#include <memory>
#include <set>
#include <string>
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "nlohmann/json.hpp"

namespace ge {
using Json = nlohmann::json;
using ProtobufMsg = ::google::protobuf::Message;
using ProtobufReflection = ::google::protobuf::Reflection;
using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor;
using ProtobufDescriptor = ::google::protobuf::Descriptor;
using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor;

class Pb2Json {
public:
/**
* @ingroup domi_omg
* @brief Transfer protobuf object to JSON object
* @param [out] json Converted JSON object
* @return void success
* @author
*/
static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json,
bool enum2str = false);

protected:
static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields,
Json &json, bool enum2str);

static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field,
bool enum2str, Json &json);

static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json);

static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field,
const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json,
bool enum2str);

static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes);
};
} // namespace ge

#endif // GE_COMMON_CONVERT_PB2JSON_H_

+ 1
- 1
src/ge/common/dump/dump_properties.cc

@@ -201,7 +201,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperti
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch(
const std::string &dump_op_switch) {
const std::string dump_op_switch) {
dump_op_switch_ = dump_op_switch;
}



+ 1
- 1
src/ge/common/dump/dump_properties.h

@@ -65,7 +65,7 @@ class DumpProperties {

const std::string &GetDumpStatus() const;

void SetDumpOpSwitch(const std::string &dump_op_switch);
void SetDumpOpSwitch(const std::string dump_op_switch);

const std::string &GetDumpOpSwitch() const;



+ 0
- 36
src/ge/common/ge/tbe_plugin_manager.cc

@@ -94,13 +94,6 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff
full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(),
caffe_parser_so_suff) == 0) {
caffe_parser_path = full_name;
} else if ((full_name.size() >= aicpu_so_suff.size() &&
full_name.compare(full_name.size() - aicpu_so_suff.size(), aicpu_so_suff.size(), aicpu_so_suff) == 0) ||
(full_name.size() >= aicpu_host_so_suff.size() &&
full_name.compare(full_name.size() - aicpu_host_so_suff.size(), aicpu_host_so_suff.size(),
aicpu_host_so_suff) == 0)) {
// aicpu so, Put the file path into the omgcontext and save into the model in the builder stage.
domi::GetContext().aicpu_op_run_paths.push_back(full_name);
} else {
// Save parser so path into file_list vector
file_list.push_back(full_name);
@@ -230,39 +223,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPlug
}
}

Status TBEPluginManager::CheckCustomAiCpuOpLib() {
std::vector<std::string> vec_op_type;

domi::OpRegistry::Instance()->GetOpTypeByImplyType(vec_op_type, domi::ImplyType::CUSTOM);
for (size_t i = 0; i < vec_op_type.size(); i++) {
bool aicpu_so_exist = false;
std::string ai_cpu_so_name = "lib" + vec_op_type[i] + "_aicpu.so";
for (size_t j = 0; j < domi::GetContext().aicpu_op_run_paths.size(); j++) {
string bin_file_path = domi::GetContext().aicpu_op_run_paths[j];
if (bin_file_path.size() >= ai_cpu_so_name.size() &&
bin_file_path.compare(bin_file_path.size() - ai_cpu_so_name.size(), ai_cpu_so_name.size(), ai_cpu_so_name) ==
0) {
aicpu_so_exist = true;
break;
}
}
if (!aicpu_so_exist) {
GELOGE(FAILED, "Can't find aicpu run so(%s), please check the plugin path!", ai_cpu_so_name.c_str());
return FAILED;
}
}
return SUCCESS;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::InitPreparation(
const std::map<string, string> &options) {
options_.insert(options.begin(), options.end());
// Load TBE plugin
TBEPluginManager::Instance().LoadCustomOpLib();
Status ret = CheckCustomAiCpuOpLib();
if (ret != SUCCESS) {
GELOGE(ret, "Check custom aicpu run so failed!");
return;
}
}
} // namespace ge

+ 0
- 1
src/ge/common/ge/tbe_plugin_manager.h

@@ -62,7 +62,6 @@ class TBEPluginManager {
static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path);
static void GetCustomOpPath(std::string &customop_path);
void LoadCustomOpLib();
static Status CheckCustomAiCpuOpLib();

SoHandlesVec handles_vec_;
static std::map<string, string> options_;


+ 4
- 1
src/ge/common/ge_common.mk

@@ -71,7 +71,10 @@ GE_COMMON_LOCAL_C_INCLUDES := \
$(TOPDIR)third_party/openssl/include/x86/include \
$(TOPDIR)framework/domi \
$(TOPDIR)framework/domi/common \
$(TOPDIR)framework/domi/common/op
$(TOPDIR)framework/domi/common/op \
$(TOPDIR)graphengine/ge \
$(TOPDIR)graphengine/ge/common \
$(TOPDIR)graphengine/ge/common/op \

#compile host libge_common
include $(CLEAR_VARS)


+ 0
- 1
src/ge/common/helper/model_cache_helper.cc

@@ -1497,7 +1497,6 @@ Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map<rtMemTyp
}
mem_resource.clear();
for (const Json &mem_resource_json : json) {
MemResource var_addr_mgr;
try {
rtMemType_t mem_type = mem_resource_json[kMemType].get<rtMemType_t>();
uint64_t var_mem_size = mem_resource_json[kVarMemSize].get<int64_t>();


+ 1
- 0
src/ge/common/op/attr_value_util.cc

@@ -17,6 +17,7 @@
#include "framework/common/op/attr_value_util.h"
#include "framework/common/debug/log.h"
#include "framework/common/util.h"
#include "register/register_types.h"

namespace ge {
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \


+ 1
- 0
src/ge/common/op/ge_op_utils.cc

@@ -27,6 +27,7 @@
#include "framework/common/ge_inner_error_codes.h"
#include "framework/common/op/attr_value_util.h"
#include "framework/common/util.h"
#include "framework/common/types.h"
#include "graph/anchor.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/op_desc_utils.h"


+ 14
- 21
src/ge/common/profiling/profiling_manager.cc

@@ -353,20 +353,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf
}
uint64_t module = GetProfilingModule();
int32_t device_num = static_cast<int32_t>(device_id_.size());
uint32_t *device_id_ptr = new (std::nothrow) uint32_t[device_num];
auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
if (device_id_ptr == nullptr) {
GELOGE(FAILED, "Stop profiling device id ptr is null.");
GELOGE(FAILED, "Stop profiling: device id ptr is null.");
return;
}
for (int32_t i = 0; i < device_num; i++) {
device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]);
}
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr);
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) {
GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret);
}
delete[] device_id_ptr;
device_id_ptr = nullptr;

for (size_t i = 0; i < prof_handle_vec_.size(); ++i) {
int result = ProfMgrStop(prof_handle_vec_[i]);
@@ -732,23 +730,21 @@ ProfilingManager::ProfStartProfiling(uint64_t module, const std::map<std::string
GELOGE(FAILED, "Prof start parse param failed.");
return FAILED;
}
auto *device_id = new (std::nothrow) uint32_t[device_num];
if (device_id == nullptr) {
GELOGE(FAILED, "Prof start parse param failed.");

auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
if (device_id_ptr == nullptr) {
GELOGE(FAILED, "Prof start: device id ptr is null.");
return FAILED;
}
for (int32_t i = 0; i < device_num; i++) {
device_id[i] = static_cast<uint32_t>(device_list[i]);
device_id_ptr[i] = static_cast<uint32_t>(device_list[i]);
}
GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num);
rtError_t rt_ret = rtProfilerStart(module, device_num, device_id);
rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) {
delete[] device_id;
GELOGE(FAILED, "Runtime profiler config proc failed.");
return FAILED;
}
delete[] device_id;
device_id = nullptr;
if ((module & PROF_MODEL_EXECUTE_MASK) == PROF_MODEL_EXECUTE_MASK) {
for (int32_t i = 0; i < device_num; i++) {
if (std::find(device_id_.begin(), device_id_.end(), device_list[i]) == device_id_.end()) {
@@ -776,23 +772,20 @@ ProfilingManager::ProfStopProfiling(uint64_t module, const std::map<std::string,
GELOGE(FAILED, "Prof stop parse param failed.");
return FAILED;
}
auto *device_id = new (std::nothrow) uint32_t[device_num];
if (device_id == nullptr) {
GELOGE(FAILED, "Prof stop parse param failed.");
auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
if (device_id_ptr == nullptr) {
GELOGE(FAILED, "Prof stop: device id ptr is null.");
return FAILED;
}
for (int32_t i = 0; i < device_num; i++) {
device_id[i] = static_cast<uint32_t>(device_list[i]);
device_id_ptr[i] = static_cast<uint32_t>(device_list[i]);
}
GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num);
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id);
rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get());
if (rt_ret != RT_ERROR_NONE) {
delete[] device_id;
GELOGE(FAILED, "Prof stop: runtime profiler config proc failed.");
return FAILED;
}
delete[] device_id;
device_id = nullptr;
uint64_t execute_model_mask = module & PROF_MODEL_EXECUTE_MASK;
if (execute_model_mask == PROF_MODEL_EXECUTE_MASK) {
for (int32_t i = 0; i < device_num; i++) {
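The profiling hunks above replace raw new[]/delete[] pairs with std::unique_ptr<uint32_t[]>, so the device-id buffer is released on every early-return path. A reduced sketch of that idiom; the commented runtime call and the fill logic are placeholders:

#include <cstdint>
#include <memory>
#include <new>

bool StartProfilingSketch(int32_t device_num) {
  auto device_ids = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]);
  if (device_ids == nullptr) {
    return false;                      // allocation failed, nothing to free
  }
  for (int32_t i = 0; i < device_num; ++i) {
    device_ids[i] = static_cast<uint32_t>(i);
  }
  // pass the raw view to a C-style API; ownership stays with device_ids
  // rtProfilerStart(module, device_num, device_ids.get());
  return true;                         // buffer is freed automatically on all paths
}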


+ 1
- 0
src/ge/common/types.cc

@@ -384,6 +384,7 @@ REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter");
REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend");
REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");

REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");


+ 3
- 4
src/ge/common/util.cc

@@ -54,8 +54,7 @@ const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M

/// The maximum length of the file.
/// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1
const int kMaxFileSizeLimit = INT_MAX;
const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now
const int kMaxBuffSize = 256;
const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character";
constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024;
@@ -186,7 +185,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co
std::streamsize size = file.tellg();

GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid.");
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > kMaxFileSizeLimit, file.close();
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t>(kMaxFileSizeLimit), file.close();
return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit);

file.seekg(0, std::ios::beg); // [no need to check value]
@@ -304,7 +303,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha
return ret;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestap() {
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() {
struct timeval tv {};
int ret = gettimeofday(&tv, nullptr);
GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret);


+ 3
- 3
src/ge/engine_manager/dnnengine_manager.cc

@@ -216,9 +216,9 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) {
if (kernel_info_store != kernel_map.end()) {
std::string unsupported_reason;
// It will be replaced by engine' checksupport
uint64_t start_time = GetCurrentTimestap();
uint64_t start_time = GetCurrentTimestamp();
if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) {
checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time;
checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time;
op_desc->SetOpEngineName(it.engine);
op_desc->SetOpKernelLibName(kernel_name);
// set attrs for taking information when load txt to graph object
@@ -228,7 +228,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) {
it.engine.c_str(), op_desc->GetName().c_str());
return it.engine;
} else {
checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time;
checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time;
bool is_custom_op = false;
if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) {
ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"},


+ 7
- 0
src/ge/engine_manager/engine_conf.json

@@ -42,6 +42,13 @@
"attach": true
},
{
"id": "DNN_VM_AICPU_ASCEND",
"name": "AICPU_ASCEND",
"independent": false,
"skip_assign_stream": false,
"attach": true
},
{
"id": "DNN_HCCL",
"name": "HCCL",
"independent": true,


+ 29
- 13
src/ge/executor/ge_executor.cc

@@ -38,6 +38,7 @@
#include "single_op/single_op_manager.h"
#include "graph/manager/graph_var_manager.h"
#include "graph/load/new_model_manager/davinci_model.h"
#include "opskernel_manager/ops_kernel_builder_manager.h"

using std::string;
using std::vector;
@@ -241,12 +242,16 @@ Status GeExecutor::Initialize() {
}

std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM);
mem_type.push_back(RT_MEMORY_P2P_DDR);
auto ret = MemManager::Instance().Initialize(mem_type);
if (ret != SUCCESS) {
GELOGE(ret, "Memory Manager init failed.");
return ret;
}

GE_CHK_STATUS_RET(OpsKernelBuilderManager::Instance().Initialize({}, false),
"Failed to initialize OpsKernelBuilders");

// Start profiling
Options profiling_options;
profiling_options.device_id = 0;
@@ -265,6 +270,8 @@ Status GeExecutor::Finalize() {
return ge::SUCCESS;
}

(void)OpsKernelBuilderManager::Instance().Finalize();

// Stop profiling
if (ProfilingManager::Instance().ProfilingOn()) {
ProfilingManager::Instance().StopProfiling();
@@ -282,11 +289,14 @@ Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_ad
return PARAM_INVALID;
}

uint64_t size = sizeof(uint64_t);
uint64_t size = sizeof(uint32_t);
if (length < size) {
GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, size);
return PARAM_INVALID;
}
if (length >= sizeof(uint64_t)) {
size = sizeof(uint64_t);
}

// Verify whether the input dynamic batch matches the model gear
std::vector<std::vector<int64_t>> batch_info;
@@ -324,12 +334,15 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad
return PARAM_INVALID;
}

uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint64_t);
uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint32_t);
if (length < dynamic_input_size) {
GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size);
return PARAM_INVALID;
}

uint64_t size = sizeof(uint32_t);
if (length >= kDynamicImageSizeInputSize * sizeof(uint64_t)) {
size = sizeof(uint64_t);
}
// Verify whether the input dynamic resolution matches the model gear
std::vector<std::vector<int64_t>> batch_info;
std::vector<uint64_t> batch_num{image_height, image_width};
@@ -350,18 +363,18 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad
GELOGE(ret, "Set dynamic size failed");
return ret;
}

// Memcpy dynamic resolution height from host to device
rtError_t rt_ret =
rtMemcpy(dynamic_input_addr, sizeof(uint64_t), &image_height, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE);
rtError_t rt_ret = rtMemcpy(dynamic_input_addr, size, &image_height, size, RT_MEMCPY_HOST_TO_DEVICE);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "memcpy dynamic resolution input data failed! ret: 0x%X", rt_ret);
return RT_ERROR_TO_GE_STATUS(rt_ret);
}

uint64_t remain_size = length - sizeof(uint64_t);
uint64_t remain_size = length - size;
// Memcpy dynamic resolution width from host to device
if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + sizeof(uint64_t)),
remain_size, &image_width, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) {
if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + size), remain_size,
&image_width, size, RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) {
GELOGE(FAILED, "memcpy dynamic resolution input data failed!");
return FAILED;
}
@@ -401,16 +414,19 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u
}

size_t dynamic_dim_num = cur_dynamic_dims.size();
uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint64_t));
uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint32_t));
if (length < dynamic_input_size) {
GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size);
return FAILED;
}
uint64_t size = sizeof(uint32_t);
if (length >= dynamic_dim_num * sizeof(uint64_t)) {
size = sizeof(uint64_t);
}
for (uint32_t i = 0; i < dynamic_dim_num; ++i) {
// Memcpy dynamic dim[i] from host to device
if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + sizeof(uint64_t) * i),
length - sizeof(uint64_t) * i, &cur_dynamic_dims[i], sizeof(uint64_t),
RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) {
if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + size * i),
length - size * i, &cur_dynamic_dims[i], size, RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) {
GELOGE(FAILED, "memcpy dynamic resolution input data failed!");
return FAILED;
}
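The dynamic-shape setters above now derive the element width from the caller's buffer: 4-byte values are the floor, and 8-byte values are used only when length leaves room for them. A hedged restatement of that selection for a buffer expected to hold count values:

#include <cstdint>

static uint64_t SelectElementSize(uint64_t length, uint64_t count) {
  uint64_t elem = sizeof(uint32_t);           // minimum accepted width
  if (length >= count * sizeof(uint64_t)) {
    elem = sizeof(uint64_t);                  // buffer is large enough for 64-bit values
  }
  return elem;
}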
@@ -1113,7 +1129,7 @@ Status GeExecutor::SetDump(const DumpConfig &dump_config) {
GELOGE(ret, "Set dump conf failed");
return ret;
}
GELOGI("Set dump config succ.");
GELOGI("Set dump config successfully");
return SUCCESS;
}
} // namespace ge

+ 11
- 4
src/ge/executor/module.mk

@@ -50,6 +50,7 @@ local_ge_executor_src_files := \
../graph/load/new_model_manager/task_info/end_graph_task_info.cc \
../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \
../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \
../opskernel_manager/ops_kernel_builder_manager.cc \
../single_op/single_op_manager.cc \
../single_op/single_op_model.cc \
../single_op/single_op.cc \
@@ -74,6 +75,7 @@ local_ge_executor_c_include := \
$(TOPDIR)inc/framework \
$(TOPDIR)inc \
$(LOCAL_PATH)/../ \
$(TOPDIR)graphengine/ge \
$(TOPDIR)libc_sec/include \
third_party/protobuf/include \
third_party/json/include \
@@ -89,7 +91,6 @@ local_ge_executor_shared_library := \
libregister \
libmsprof \
liberror_manager \
libascend_hal

local_ge_executor_ldflags := -lrt -ldl \

@@ -105,7 +106,12 @@ LOCAL_SRC_FILES := $(local_ge_executor_src_files)
LOCAL_C_INCLUDES := $(local_ge_executor_c_include)

LOCAL_SHARED_LIBRARIES := $(local_ge_executor_shared_library)
LOCAL_STATIC_LIBRARIES := libmsprofiler

LOCAL_SHARED_LIBRARIES += libascend_hal

LOCAL_STATIC_LIBRARIES := \
libmsprofiler \

ifeq ($(device_os),android)
LOCAL_LDFLAGS += -ldl
LOCAL_LDLIBS += -L$(PWD)/prebuilts/clang/linux-x86/aarch64/android-ndk-r21/sysroot/usr/lib/aarch64-linux-android/29 -llog
@@ -142,9 +148,10 @@ LOCAL_SHARED_LIBRARIES := \
libregister \
libmsprof \
liberror_manager \
stub/libascend_hal
stub/libascend_hal \

LOCAL_STATIC_LIBRARIES := libmsprofiler
LOCAL_STATIC_LIBRARIES := \
libmsprofiler \

LOCAL_LDFLAGS += $(local_ge_executor_ldflags)



+ 5
- 0
src/ge/ge_inference.mk

@@ -42,6 +42,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \
session/session_manager.cc \
engine_manager/dnnengine_manager.cc \
opskernel_manager/ops_kernel_manager.cc \
opskernel_manager/ops_kernel_builder_manager.cc \
graph/manager/graph_manager.cc \
graph/manager/graph_manager_utils.cc \
graph/manager/graph_context.cc \
@@ -57,9 +58,11 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \
graph/partition/engine_place.cc \
graph/partition/graph_partition.cc \
graph/partition/dynamic_shape_partition.cc \
graph/partition/stage_partition.cc \
generator/ge_generator.cc \
generator/generator_api.cc \
graph/manager/graph_var_manager.cc \
graph/manager/host_mem_manager.cc \
graph/manager/rdma_pool_allocator.cc \
graph/manager/graph_mem_allocator.cc \
graph/manager/graph_caching_allocator.cc \
@@ -178,6 +181,7 @@ OMG_HOST_SRC_FILES := \
graph/passes/multi_batch_pass.cc \
graph/passes/multi_batch_clone_pass.cc \
graph/passes/subexpression_migration_pass.cc \
graph/passes/subgraph_const_migration_pass.cc \
graph/passes/unused_args_clean_pass.cc \
graph/passes/next_iteration_pass.cc \
graph/passes/control_trigger_pass.cc \
@@ -343,6 +347,7 @@ DEVICE_LOCAL_C_INCLUDES := \
$(TOPDIR)inc/runtime \
$(TOPDIR)ops/built-in/op_proto/inc \
$(TOPDIR)framework/domi \
$(TOPDIR)graphengine/ge \
$(TOPDIR)toolchain/ide/ide-daemon/external \
third_party/json/include \
third_party/protobuf/include \


Some files were not shown because too many files changed in this diff
