You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_variable_v2_parser.cc 9.1 kB

4 years ago

  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include <cstdint>
#include <limits>

#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "framework/common/op/ge_op_utils.h"
#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "parser/common/op_def/variable_op.h"
#include "parser/common/op_parser_factory.h"
#include "parser/tensorflow/tensorflow_op_parser.h"
#include "parser/tensorflow/tensorflow_parser_register.h"
  30. using domi::tensorflow::AttrValue;
  31. using domi::tensorflow::NodeDef;
  32. using domi::tensorflow::TensorShapeProto;
  33. namespace ge {
  34. const std::string SERIALIZE_FORMAT = "serialize_format";
  35. /* Original definition of variablev2 operator
  36. node_def {
  37. name: "Variable_7/Momentum"
  38. op: "VariableV2"
  39. device: "/job:localhost/replica:0/task:0/device:CPU:0"
  40. attr {
  41. key: "_class"
  42. value {
  43. list {
  44. s: "loc:@Variable_7"
  45. }
  46. }
  47. }
  48. attr {
  49. key: "_var_format"
  50. value {
  51. s: "4D"
  52. }
  53. }
  54. attr {
  55. key: "container"
  56. value {
  57. s: ""
  58. }
  59. }
  60. attr {
  61. key: "dtype"
  62. value {
  63. type: DT_FLOAT
  64. }
  65. }
  66. attr {
  67. key: "shape"
  68. value {
  69. shape {
  70. dim {
  71. size: 10
  72. }
  73. }
  74. }
  75. }
  76. attr {
  77. key: "shared_name"
  78. value {
  79. s: ""
  80. }
  81. }
  82. }
  83. */
  84. static Status ParseSrcType(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  85. // The upper caller guarantees input params is not empty.
  86. domi::tensorflow::AttrValue attr;
  87. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_DTYPE, attr),
  88. GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.",
  89. VAR_ATTR_DTYPE.c_str(), node->name().c_str());
  90. return PARAM_INVALID);
  91. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_TYPE),
  92. "check Attr type failed");
  93. domi::tensorflow::DataType tf_type = attr.type();
  94. ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type);
  95. CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED, GELOGE(FAILED, "Data type %s of node %s is not supported.",
  96. DataType_Name(tf_type).c_str(), node->name().c_str());
  97. return PARAM_INVALID);
  98. op->SrcType(type);
  99. return SUCCESS;
  100. }
  101. Status ParseContainer(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  102. // The upper caller guarantees input params is not empty.
  103. domi::tensorflow::AttrValue attr;
  104. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_CONTAINER, attr),
  105. GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.",
  106. VAR_ATTR_CONTAINER.c_str(), node->name().c_str());
  107. return PARAM_INVALID);
  108. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING),
  109. "check Attr s failed");
  110. std::string container = attr.s();
  111. op->Container(container);
  112. return SUCCESS;
  113. }
  114. Status ParseSharedName(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  115. // The upper caller guarantees input params is not empty.
  116. domi::tensorflow::AttrValue attr;
  117. CHECK_FALSE_EXEC(
  118. TensorFlowUtil::FindAttrValue(node, VAR_ATTR_SHARED_NAME, attr),
  119. GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", VAR_ATTR_SHARED_NAME.c_str(), node->name().c_str());
  120. return PARAM_INVALID);
  121. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING),
  122. "check Attr s failed");
  123. std::string shared_name = attr.s();
  124. op->SharedName(shared_name);
  125. return SUCCESS;
  126. }
  127. static Status ParseVarName(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  128. // The upper caller guarantees input params is not empty.
  129. domi::tensorflow::AttrValue attr;
  130. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, ge::VAR_ATTR_NAME, attr),
  131. GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", ge::VAR_ATTR_NAME.c_str(),
  132. node->name().c_str()); return PARAM_INVALID);
  133. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, TENSORFLOW_ATTR_TYPE_STRING),
  134. "check Attr s failed");
  135. std::string var_name = attr.s();
  136. op->SharedName(var_name);
  137. return SUCCESS;
  138. }
  139. static Status InitOutTensor(const vector<int64_t> &shape, int64_t data_type, ge::GeTensorDesc &out_tensor_desc,
  140. ge::Format format) {
  141. out_tensor_desc.SetFormat(format);
  142. out_tensor_desc.SetDataType((ge::DataType)data_type);
  143. ge::TensorUtils::SetReuseInput(out_tensor_desc, false);
  144. ge::TensorUtils::SetRealDimCnt(out_tensor_desc, shape.size());
  145. out_tensor_desc.SetShape(ge::GeShape(shape));
  146. int64_t size = out_tensor_desc.GetShape().GetShapeSize();
  147. size *= sizeof(float);
  148. ge::TensorUtils::SetSize(out_tensor_desc, size);
  149. return SUCCESS;
  150. }
  151. static Status ParseVarShape(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  152. // The upper caller guarantees input params is not empty.
  153. string node_src_name = node->name();
  154. domi::tensorflow::AttrValue attr_value;
  155. if (!TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, attr_value)) {
  156. GELOGE(FAILED, "In NodeDef %s Attr %s is not exist.", node_src_name.c_str(),
  157. ge::ATTR_NAME_OUTPUT_TENSOR_DESC.c_str());
  158. return FAILED;
  159. }
  160. ge::GeTensorDesc infer_shape_domi_desc;
  161. domi::tensorflow::AttrValue_ListValue attr_list = attr_value.list();
  162. int32_t tf_datatype = 0;
  163. GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(infer_shape_domi_desc, attr_list, 0, tf_datatype),
  164. PARAM_INVALID, "parse domi_desc failed.");
  165. ge::Format src_format = ge::FORMAT_ND;
  166. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, VAR_ATTR_SHAPE, attr_value),
  167. GELOGE(FAILED, "Attr %s does not exist in NodeDef %s.", VAR_ATTR_SHAPE.c_str(),
  168. node->name().c_str()); return PARAM_INVALID);
  169. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr_value, TENSORFLOW_ATTR_TYPE_SHAPE),
  170. "check Attr s failed");
  171. const TensorShapeProto &data_shape = attr_value.shape();
  172. vector<int64_t> var_dims_v;
  173. for (int32_t i = 0; i < data_shape.dim_size(); i++) {
  174. var_dims_v.push_back(data_shape.dim(i).size());
  175. }
  176. op->VarShape(var_dims_v);
  177. ge::GeTensorDesc out_tensor_desc;
  178. GE_RETURN_WITH_LOG_IF_ERROR(InitOutTensor(var_dims_v, op->GetVarSrcType(), out_tensor_desc, src_format),
  179. "Init Output Tensor failed");
  180. op->OutputTensorDesc(out_tensor_desc);
  181. return SUCCESS;
  182. }
  183. static void ParsePlacement(const domi::tensorflow::NodeDef *node, VariableOperator *op) {
  184. // The upper caller guarantees input params is not empty.
  185. string node_src_name = node->name();
  186. domi::tensorflow::AttrValue attr_value;
  187. GELOGI("Start to parse placement, %s", node_src_name.c_str());
  188. if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_VARIABLE_PLACEMENT, attr_value)) {
  189. std::string placement = attr_value.s();
  190. op->Placement(placement);
  191. }
  192. }
  193. Status ParseParams(const Message *op_src, VariableOperator *op) {
  194. GE_CHECK_NOTNULL(op_src);
  195. const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src);
  196. GE_CHECK_NOTNULL(node);
  197. GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
  198. string node_op = node->op();
  199. if (node_op == TEMPORARYVARIABLE) {
  200. GE_RETURN_IF_ERROR(ParseVarName(node, op));
  201. } else {
  202. GE_RETURN_IF_ERROR(ParseContainer(node, op));
  203. GE_RETURN_IF_ERROR(ParseSharedName(node, op));
  204. }
  205. GE_RETURN_IF_ERROR(ParseSrcType(node, op));
  206. GE_RETURN_IF_ERROR(ParseVarShape(node, op));
  207. ParsePlacement(node, op);
  208. GELOGD("VariabeV2 OP parser params success.op name : %s.", node->name().c_str());
  209. return SUCCESS;
  210. }
  211. DOMI_REGISTER_TENSORFLOW_PARSER(VARIABLE, VariableOperator).SetParseParamsFn(ParseParams);
  212. DOMI_REGISTER_TENSORFLOW_PARSER(VARHANDLEOP, VariableOperator).SetParseParamsFn(ParseParams);
  213. DOMI_REGISTER_TENSORFLOW_PARSER(TEMPORARYVARIABLE, VariableOperator).SetParseParamsFn(ParseParams);
  214. } // namespace ge

Ascend CANN Parser（简称parser）配合TF_Adapter、ATC工具、IR构图等使用，开发者通过以上工具，借助parser能方便地将第三方框架的算法表示转换成Ascend IR，充分利用昇腾AI处理器卓越的运算能力。