diff --git a/lite/include/lite/global.h b/lite/include/lite/global.h
index 2fde9d8b..e681ee7e 100644
--- a/lite/include/lite/global.h
+++ b/lite/include/lite/global.h
@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key(
  * other config not inclue in config and networkIO, ParseInfoFunc can fill it
  * with the information in json, now support:
  * "device_id" : int, default 0
- * "number_threads" : size_t, default 1
+ * "number_threads" : uint32_t, default 1
  * "is_inplace_model" : bool, default false
  * "use_tensorrt" : bool, default false
  */
diff --git a/lite/include/lite/tensor.h b/lite/include/lite/tensor.h
index 44495ce0..9ccc80c2 100644
--- a/lite/include/lite/tensor.h
+++ b/lite/include/lite/tensor.h
@@ -149,28 +149,42 @@ private:
  */
 class LITE_API LiteAny {
 public:
+    enum Type {
+        STRING = 0,
+        INT32 = 1,
+        UINT32 = 2,
+        UINT8 = 3,
+        INT8 = 4,
+        INT64 = 5,
+        UINT64 = 6,
+        BOOL = 7,
+        VOID_PTR = 8,
+        FLOAT = 9,
+        NONE_SUPPORT = 10,
+    };
     LiteAny() = default;
     template <class T>
     LiteAny(T value) : m_holder(new AnyHolder<T>(value)) {
-        m_is_string = std::is_same<std::string, T>();
+        m_type = get_type<T>();
     }
 
     LiteAny(const LiteAny& any) {
         m_holder = any.m_holder->clone();
-        m_is_string = any.is_string();
+        m_type = any.m_type;
     }
     LiteAny& operator=(const LiteAny& any) {
         m_holder = any.m_holder->clone();
-        m_is_string = any.is_string();
+        m_type = any.m_type;
         return *this;
     }
-    bool is_string() const { return m_is_string; }
+
+    template <class T>
+    Type get_type() const;
 
     class HolderBase {
     public:
         virtual ~HolderBase() = default;
         virtual std::shared_ptr<HolderBase> clone() = 0;
-        virtual size_t type_length() const = 0;
     };
 
     template <class T>
@@ -180,7 +194,6 @@ public:
         virtual std::shared_ptr<HolderBase> clone() override {
             return std::make_shared<AnyHolder>(m_value);
         }
-        virtual size_t type_length() const override { return sizeof(T); }
 
     public:
         T m_value;
@@ -188,14 +201,21 @@ public:
     //! if type is miss matching, it will throw
     void type_missmatch(size_t expect, size_t get) const;
-    //! only check the storage type and the visit type length, so it's not safe
     template <class T>
-    T unsafe_cast() const {
-        if (sizeof(T) != m_holder->type_length()) {
-            type_missmatch(m_holder->type_length(), sizeof(T));
+    T safe_cast() const {
+        if (get_type<T>() != m_type) {
+            type_missmatch(m_type, get_type<T>());
         }
         return static_cast<AnyHolder<T>*>(m_holder.get())->m_value;
     }
+    template <class T>
+    bool try_cast() const {
+        if (get_type<T>() == m_type) {
+            return true;
+        } else {
+            return false;
+        }
+    }
     //! only check the storage type and the visit type length, so it's not safe
     void* cast_void_ptr() const {
         return &static_cast<AnyHolder<char>*>(m_holder.get())->m_value;
     }
@@ -203,7 +223,7 @@ public:
 
 private:
     std::shared_ptr<HolderBase> m_holder;
-    bool m_is_string = false;
+    Type m_type = NONE_SUPPORT;
 };
 
 /*********************** special tensor function ***************/
diff --git a/lite/lite-c/src/global.cpp b/lite/lite-c/src/global.cpp
index cc50e676..48703a44 100644
--- a/lite/lite-c/src/global.cpp
+++ b/lite/lite-c/src/global.cpp
@@ -127,7 +127,8 @@ int LITE_register_parse_info_func(
         separate_config_map["device_id"] = device_id;
     }
     if (nr_threads != 1) {
-        separate_config_map["nr_threads"] = nr_threads;
+        separate_config_map["nr_threads"] =
+                static_cast<uint32_t>(nr_threads);
     }
     if (is_cpu_inplace_mode != false) {
         separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
diff --git a/lite/src/mge/network_impl.cpp b/lite/src/mge/network_impl.cpp
index a8a98d86..a64c8bd0 100644
--- a/lite/src/mge/network_impl.cpp
+++ b/lite/src/mge/network_impl.cpp
@@ -352,19 +352,19 @@ void NetworkImplDft::load_model(
     //! config some flag get from json config file
     if (separate_config_map.find("device_id") != separate_config_map.end()) {
-        set_device_id(separate_config_map["device_id"].unsafe_cast<int>());
+        set_device_id(separate_config_map["device_id"].safe_cast<int>());
     }
     if (separate_config_map.find("number_threads") !=
                 separate_config_map.end() &&
-        separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) {
+        separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
         set_cpu_threads_number(
-                separate_config_map["number_threads"].unsafe_cast<size_t>());
+                separate_config_map["number_threads"].safe_cast<uint32_t>());
     }
     if (separate_config_map.find("enable_inplace_model") !=
                 separate_config_map.end() &&
-        separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) {
+        separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
         set_cpu_inplace_mode();
     }
     if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
-        separate_config_map["use_tensorrt"].unsafe_cast<bool>()) {
+        separate_config_map["use_tensorrt"].safe_cast<bool>()) {
         use_tensorrt();
     }
diff --git a/lite/src/parse_info/default_parse.h b/lite/src/parse_info/default_parse.h
index 09dcb89d..7873e10e 100644
--- a/lite/src/parse_info/default_parse.h
+++ b/lite/src/parse_info/default_parse.h
@@ -84,7 +84,7 @@ bool default_parse_info(
     }
     if (device_json.contains("number_threads")) {
         separate_config_map["number_threads"] =
-                static_cast<size_t>(device_json["number_threads"]);
+                static_cast<uint32_t>(device_json["number_threads"]);
     }
     if (device_json.contains("enable_inplace_model")) {
         separate_config_map["enable_inplace_model"] =
diff --git a/lite/src/tensor.cpp b/lite/src/tensor.cpp
index 3a81bfbc..9becbb98 100644
--- a/lite/src/tensor.cpp
+++ b/lite/src/tensor.cpp
@@ -277,10 +277,28 @@ void Tensor::update_from_implement() {
 void LiteAny::type_missmatch(size_t expect, size_t get) const {
     LITE_THROW(ssprintf(
             "The type store in LiteAny is not match the visit type, type of "
-            "storage length is %zu, type of visit length is %zu.",
+            "storage enum is %zu, type of visit enum is %zu.",
             expect, get));
 }
 
+namespace lite {
+#define GET_TYPE(ctype, ENUM)                        \
+    template <>                                      \
+    LiteAny::Type LiteAny::get_type<ctype>() const { \
+        return ENUM;                                 \
+    }
+GET_TYPE(std::string, STRING)
+GET_TYPE(int32_t, INT32)
+GET_TYPE(uint32_t, UINT32)
+GET_TYPE(int8_t, INT8)
+GET_TYPE(uint8_t, UINT8)
+GET_TYPE(int64_t, INT64)
+GET_TYPE(uint64_t, UINT64)
+GET_TYPE(float, FLOAT)
+GET_TYPE(bool, BOOL)
+GET_TYPE(void*, VOID_PTR)
+}  // namespace lite
+
 std::shared_ptr<Tensor> TensorUtils::concat(
         const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
         int dst_device_id) {