/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ #define INC_EXTERNAL_GRAPH_OPERATOR_H_ #include #include #include #include #include #include "./ge_error_codes.h" #include "./inference_context.h" #include "./tensor.h" #ifndef USER_GE_LOGI #define USER_GE_LOGI(...) #endif // USER_GE_LOGI #ifndef USER_GE_LOGW #define USER_GE_LOGW(...) #endif // USER_GE_LOGW #ifndef USER_GE_LOGE #define USER_GE_LOGE(...) #endif // USER_GE_LOGE #define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") #define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") namespace ge { class Operator; class OperatorImpl; class NodeUtils; class NamedAttrs; class Graph; class AttrValue; class Node; using SubgraphBuilder = std::function; using OperatorImplPtr = std::shared_ptr; using OperatorPtr = std::shared_ptr; class OpIO; using OutHandler = std::shared_ptr; using InHandler = std::shared_ptr; using std::function; using std::shared_ptr; using std::string; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { public: friend class OperatorImpl; friend class GraphBuilderImpl; friend class NodeUtils; using OpInt = int64_t; using OpFloat = float; using OpString = string; using OpAscendString = AscendString; using OpBool = bool; using OpTensor = Tensor; using OpType = ge::DataType; using OpNamedAttrs = ge::NamedAttrs; using OpListInt = std::vector; using OpListFloat = std::vector; using OpListString = std::vector; using OpListAcendString = std::vector; using OpListBool = std::vector; using OpListTensor = std::vector; using OpBytes = std::vector; using OpListListInt = std::vector>; using OpListType = std::vector; using OpListNamedAttrs = std::vector; Operator() {} ATTRIBUTED_DEPRECATED(Operator(const char *)) explicit Operator(const string &type); explicit Operator(const char *type); ATTRIBUTED_DEPRECATED(Operator(const AscendString &, const AscendString &)) Operator(const string &name, const string &type); Operator(const AscendString &name, const AscendString &type); Operator(const char *name, const char *type); virtual ~Operator() = default; bool IsEmpty() const; ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) string GetName() const; graphStatus GetName(AscendString &name) const; ATTRIBUTED_DEPRECATED(graphStatus GetOpType(AscendString &) const) string GetOpType() const; graphStatus GetOpType(AscendString &type) const; // Only has one output index = 0 ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &)) Operator &SetInput(const string &dst_name, const Operator &src_oprt); Operator &SetInput(const char *dst_name, const Operator &src_oprt); ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &, const char *)) Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); Operator &SetInput(const char *dst_name, const Operator &src_oprt, const char *name); ATTRIBUTED_DEPRECATED(Operator &SetInput(const char *, const Operator &, uint32_t)) Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); Operator &SetInput(const char *dst_name, const Operator &src_oprt, uint32_t index); Operator &AddControlInput(const Operator &src_oprt); ATTRIBUTED_DEPRECATED(graphStatus GetInputConstData(const char *, Tensor &) const) graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; graphStatus GetInputConstData(const char *dst_name, Tensor &data) const; ATTRIBUTED_DEPRECATED(TensorDesc GetInputDesc(const char *, uint32_t len) const) TensorDesc GetInputDesc(const string &name) const; TensorDesc GetInputDesc(const char *name, uint32_t len) const; TensorDesc GetInputDesc(uint32_t index) const; ATTRIBUTED_DEPRECATED(int GetDynamicOutputNum(const char *) const) int GetDynamicOutputNum(const string &name) const; int GetDynamicOutputNum(const char *name) const; ATTRIBUTED_DEPRECATED(int GetDynamicInputNum(const char *)) int GetDynamicInputNum(const string &name) const; int GetDynamicInputNum(const char *name) const; ATTRIBUTED_DEPRECATED(graphStatus TryGetInputDesc(const char *, TensorDesc &) const) graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; graphStatus TryGetInputDesc(const char *name, TensorDesc &tensor_desc) const; ATTRIBUTED_DEPRECATED(graphStatus UpdateInputDesc(const char *, const TensorDesc &)) graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); graphStatus UpdateInputDesc(const char *name, const TensorDesc &tensor_desc); ATTRIBUTED_DEPRECATED(TensorDesc GetOutputDesc(const char *, uint32_t len) const) TensorDesc GetOutputDesc(const string &name) const; TensorDesc GetOutputDesc(const char *name, uint32_t len) const; TensorDesc GetOutputDesc(uint32_t index) const; ATTRIBUTED_DEPRECATED(graphStatus UpdateOutputDesc(const char *, const TensorDesc &tensor_desc)) graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); graphStatus UpdateOutputDesc(const char *name, const TensorDesc &tensor_desc); ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicInputDesc(const char *, uint32_t) const) TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; TensorDesc GetDynamicInputDesc(const char *name, uint32_t index) const; ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicInputDesc(const char *, uint32_t, const TensorDesc &)) graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); graphStatus UpdateDynamicInputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc); ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicOutputDesc(const char *, uint32_t) const) TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; TensorDesc GetDynamicOutputDesc(const char *name, uint32_t index) const; ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicOutputDesc(const char *, uint32_t, const TensorDesc &)) graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); graphStatus UpdateDynamicOutputDesc(const char *name, uint32_t index, const TensorDesc &tensor_desc); graphStatus InferShapeAndType(); void SetInferenceContext(const InferenceContextPtr &inference_context); InferenceContextPtr GetInferenceContext() const; graphStatus VerifyAllAttr(bool disable_common_verifier = false); size_t GetInputsSize() const; size_t GetOutputsSize() const; ATTRIBUTED_DEPRECATED(graphStatus GetAllAttrNamesAndTypes(std::map &) const) const std::map GetAllAttrNamesAndTypes() const; graphStatus GetAllAttrNamesAndTypes(std::map &attr_name_types) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, int64_t)) Operator &SetAttr(const string &name, int64_t attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, int32_t)) Operator &SetAttr(const string &name, int32_t attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, uint32_t)) Operator &SetAttr(const string &name, uint32_t attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, int64_t &) const) graphStatus GetAttr(const string &name, int64_t &attr_value) const; ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, int32_t &) const) graphStatus GetAttr(const string &name, int32_t &attr_value) const; ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, uint32_t &) const) graphStatus GetAttr(const string &name, uint32_t &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, std::initializer_list &&)) Operator &SetAttr(const string &name, std::initializer_list &&attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *name, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *name, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const string &, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, float attr_value)) Operator &SetAttr(const string &name, float attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, float &) const) graphStatus GetAttr(const string &name, float &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, AttrValue &&)) Operator &SetAttr(const string &name, AttrValue &&attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, AttrValue &) const) graphStatus GetAttr(const string &name, AttrValue &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const AscendString &)) Operator &SetAttr(const string &name, const string &attr_value); ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const string &)) graphStatus GetAttr(const string &name, string &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const string &, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, bool)) Operator &SetAttr(const string &name, bool attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, bool &) const) graphStatus GetAttr(const string &name, bool &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const Tensor &)) Operator &SetAttr(const string &name, const Tensor &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, Tensor &) const) graphStatus GetAttr(const string &name, Tensor &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; // Bytes type ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const OpBytes &)) Operator &SetAttr(const string &name, const OpBytes &attr_value); // Bytes type ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, OpBytes &) const) graphStatus GetAttr(const string &name, OpBytes &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector> &)) Operator &SetAttr(const string &name, const std::vector> &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector> &) const) graphStatus GetAttr(const string &name, std::vector> &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const ge::DataType &)) Operator &SetAttr(const string &name, const ge::DataType &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, ge::DataType &) const) graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; // func type ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const ge::NamedAttrs &)) Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, ge::NamedAttrs &) const) graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char *, const std::vector &)) Operator &SetAttr(const string &name, const std::vector &attr_value); ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char *, std::vector &) const) graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const char *name, int64_t attr_value); Operator &SetAttr(const char *name, int32_t attr_value); Operator &SetAttr(const char *name, uint32_t attr_value); graphStatus GetAttr(const char *name, int64_t &attr_value) const; graphStatus GetAttr(const char *name, int32_t &attr_value) const; graphStatus GetAttr(const char *name, uint32_t &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); Operator &SetAttr(const char *name, const std::vector &attr_value); Operator &SetAttr(const char *name, const std::vector &attr_value); Operator &SetAttr(const char *name, std::initializer_list &&attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; graphStatus GetAttr(const char *name, std::vector &attr_value) const; graphStatus GetAttr(const char *name, std::vector &attr_value) const; Operator &SetAttr(const char *name, float attr_value); graphStatus GetAttr(const char *name, float &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; Operator &SetAttr(const char *name, AttrValue &&attr_value); graphStatus GetAttr(const char *name, AttrValue &attr_value) const; Operator &SetAttr(const char *name, const char *attr_value); Operator &SetAttr(const char *name, const AscendString &attr_value); graphStatus GetAttr(const char *name, AscendString &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_values); graphStatus GetAttr(const char *name, std::vector &attr_values) const; Operator &SetAttr(const char *name, bool attr_value); graphStatus GetAttr(const char *name, bool &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; Operator &SetAttr(const char *name, const Tensor &attr_value); graphStatus GetAttr(const char *name, Tensor &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; // Bytes type Operator &SetAttr(const char *name, const OpBytes &attr_value); // Bytes type graphStatus GetAttr(const char *name, OpBytes &attr_value) const; Operator &SetAttr(const char *name, const std::vector> &attr_value); graphStatus GetAttr(const char *name, std::vector> &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; Operator &SetAttr(const char *name, const ge::DataType &attr_value); graphStatus GetAttr(const char *name, ge::DataType &attr_value) const; // func type Operator &SetAttr(const char *name, const ge::NamedAttrs &attr_value); graphStatus GetAttr(const char *name, ge::NamedAttrs &attr_value) const; Operator &SetAttr(const char *name, const std::vector &attr_value); graphStatus GetAttr(const char *name, std::vector &attr_value) const; void BreakConnect() const; size_t GetSubgraphNamesCount() const; ATTRIBUTED_DEPRECATED(graphStatus GetSubgraphNames(std::vector &) const) std::vector GetSubgraphNames() const; graphStatus GetSubgraphNames(std::vector &names) const; ATTRIBUTED_DEPRECATED(SubgraphBuilder GetSubgraphBuilder(const char *) const) SubgraphBuilder GetSubgraphBuilder(const string &name) const; SubgraphBuilder GetSubgraphBuilder(const char *name) const; ATTRIBUTED_DEPRECATED(Graph GetSubgraph(const char *) const) Graph GetSubgraph(const string &name) const; Graph GetSubgraph(const char *name) const; ATTRIBUTED_DEPRECATED(SubgraphBuilder GetDynamicSubgraphBuilder(const char *, uint32_t) const) SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; SubgraphBuilder GetDynamicSubgraphBuilder(const char *name, uint32_t index) const; ATTRIBUTED_DEPRECATED(Graph GetDynamicSubgraph(const char *, uint32_t) const) Graph GetDynamicSubgraph(const string &name, uint32_t index) const; Graph GetDynamicSubgraph(const char *name, uint32_t index) const; protected: void AttrRegister(const string &name, float attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, int64_t attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const string &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, bool attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const Tensor &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const OpBytes &attr_value); void AttrRegister(const string &name, const std::vector> &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const ge::DataType &attr_value); void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); explicit Operator(OperatorImplPtr &&op_impl); void InputRegister(const string &name); void OptionalInputRegister(const string &name); void InferFuncRegister(const std::function &func); void VerifierFuncRegister(const std::function &func); void InferFormatFuncRegister(const std::function &func); void OutputRegister(const string &name); void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); void RequiredAttrRegister(const string &name); graphStatus VerifyAll(); // Only has one output index = 0 Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); void SubgraphRegister(const string &ir_name, bool dynamic); void SubgraphCountRegister(const string &ir_name, uint32_t count); void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); Graph GetSubgraphImpl(const string &name) const; private: Operator &SetInput(const string &dst_name, const OutHandler &out_handler); OutHandler GetOutput(const string &name) const; OutHandler GetOutput(uint32_t index) const; OperatorImplPtr GetOperatorImplPtr() const; OperatorImplPtr operator_impl_{nullptr}; graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; std::shared_ptr GetNode() const; }; } // namespace ge #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_