You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensorflow_fill_parser.cc 2.1 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. // Copyright (c) <2018>, <Huawei Technologies Co., Ltd>
  17. #include "common/debug/log.h"
  18. #include "common/op/attr_value_util.h"
  19. #include "parser/common/op_def/fill_op.h"
  20. #include "common/util.h"
  21. #include "parser/tensorflow/tensorflow_parser_register.h"
  22. namespace ge {
  23. /*
  24. node {
  25. name: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/zeros"
  26. op: "Fill"
  27. input: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/concat"
  28. input: "model_with_buckets/bidirectional_rnn/fw/fw/BasicLSTMCellZeroState/zeros/Const"
  29. device: "/device:GPU:2"
  30. attr {
  31. key: "T"
  32. value {
  33. type: DT_FLOAT
  34. }
  35. }
  36. }
  37. */
  38. domi::Status ParseParams(const NodeDef *node, FillOperator *op) {
  39. GE_CHECK_NOTNULL(node);
  40. GE_CHECK_NOTNULL(op);
  41. op->Name(node->name());
  42. domi::tensorflow::DataType data_type;
  43. GE_RETURN_IF_ERROR(TensorFlowUtil::ParseDataType(node, TENSORFLOW_ATTR_T, data_type));
  44. ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(data_type);
  45. CHECK_FALSE_EXEC(
  46. type != ge::DataType::DT_UNDEFINED,
  47. GELOGE(PARAM_INVALID, "Data type %s of node %s is not supported.", DataType_Name(data_type).c_str(),
  48. node->name().c_str());
  49. return PARAM_INVALID);
  50. op->DataType(type);
  51. op->Alpha(ge::ALPHA_DEFAULT_VALUE);
  52. op->Beta(ge::BETA_DEFAULT_VALUE);
  53. return domi::SUCCESS;
  54. }
  55. DOMI_REGISTER_TENSORFLOW_PARSER(FILL, FillOperator).SetParseParamsFn(ParseParams);
  56. } // namespace ge

Ascend CANN Parser(简称parser)配合TF_Adapter、ATC工具、IR构图等使用,开发者通过以上工具,借助parser能方便地将第三方框架的算法表示转换成Ascend IR,充分利用昇腾AI处理器卓越的运算能力。