@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key( | |||||
* other config not inclue in config and networkIO, ParseInfoFunc can fill it | * other config not inclue in config and networkIO, ParseInfoFunc can fill it | ||||
* with the information in json, now support: | * with the information in json, now support: | ||||
* "device_id" : int, default 0 | * "device_id" : int, default 0 | ||||
* "number_threads" : size_t, default 1 | |||||
* "number_threads" : uint32_t, default 1 | |||||
* "is_inplace_model" : bool, default false | * "is_inplace_model" : bool, default false | ||||
* "use_tensorrt" : bool, default false | * "use_tensorrt" : bool, default false | ||||
*/ | */ | ||||
@@ -149,28 +149,42 @@ private: | |||||
*/ | */ | ||||
class LITE_API LiteAny { | class LITE_API LiteAny { | ||||
public: | public: | ||||
enum Type { | |||||
STRING = 0, | |||||
INT32 = 1, | |||||
UINT32 = 2, | |||||
UINT8 = 3, | |||||
INT8 = 4, | |||||
INT64 = 5, | |||||
UINT64 = 6, | |||||
BOOL = 7, | |||||
VOID_PTR = 8, | |||||
FLOAT = 9, | |||||
NONE_SUPPORT = 10, | |||||
}; | |||||
LiteAny() = default; | LiteAny() = default; | ||||
template <class T> | template <class T> | ||||
LiteAny(T value) : m_holder(new AnyHolder<T>(value)) { | LiteAny(T value) : m_holder(new AnyHolder<T>(value)) { | ||||
m_is_string = std::is_same<std::string, T>(); | |||||
m_type = get_type<T>(); | |||||
} | } | ||||
LiteAny(const LiteAny& any) { | LiteAny(const LiteAny& any) { | ||||
m_holder = any.m_holder->clone(); | m_holder = any.m_holder->clone(); | ||||
m_is_string = any.is_string(); | |||||
m_type = any.m_type; | |||||
} | } | ||||
LiteAny& operator=(const LiteAny& any) { | LiteAny& operator=(const LiteAny& any) { | ||||
m_holder = any.m_holder->clone(); | m_holder = any.m_holder->clone(); | ||||
m_is_string = any.is_string(); | |||||
m_type = any.m_type; | |||||
return *this; | return *this; | ||||
} | } | ||||
bool is_string() const { return m_is_string; } | |||||
template <class T> | |||||
Type get_type() const; | |||||
class HolderBase { | class HolderBase { | ||||
public: | public: | ||||
virtual ~HolderBase() = default; | virtual ~HolderBase() = default; | ||||
virtual std::shared_ptr<HolderBase> clone() = 0; | virtual std::shared_ptr<HolderBase> clone() = 0; | ||||
virtual size_t type_length() const = 0; | |||||
}; | }; | ||||
template <class T> | template <class T> | ||||
@@ -180,7 +194,6 @@ public: | |||||
virtual std::shared_ptr<HolderBase> clone() override { | virtual std::shared_ptr<HolderBase> clone() override { | ||||
return std::make_shared<AnyHolder>(m_value); | return std::make_shared<AnyHolder>(m_value); | ||||
} | } | ||||
virtual size_t type_length() const override { return sizeof(T); } | |||||
public: | public: | ||||
T m_value; | T m_value; | ||||
@@ -188,14 +201,21 @@ public: | |||||
//! if type is miss matching, it will throw | //! if type is miss matching, it will throw | ||||
void type_missmatch(size_t expect, size_t get) const; | void type_missmatch(size_t expect, size_t get) const; | ||||
//! only check the storage type and the visit type length, so it's not safe | |||||
template <class T> | template <class T> | ||||
T unsafe_cast() const { | |||||
if (sizeof(T) != m_holder->type_length()) { | |||||
type_missmatch(m_holder->type_length(), sizeof(T)); | |||||
T safe_cast() const { | |||||
if (get_type<T>() != m_type) { | |||||
type_missmatch(m_type, get_type<T>()); | |||||
} | } | ||||
return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value; | return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value; | ||||
} | } | ||||
template <class T> | |||||
bool try_cast() const { | |||||
if (get_type<T>() == m_type) { | |||||
return true; | |||||
} else { | |||||
return false; | |||||
} | |||||
} | |||||
//! only check the storage type and the visit type length, so it's not safe | //! only check the storage type and the visit type length, so it's not safe | ||||
void* cast_void_ptr() const { | void* cast_void_ptr() const { | ||||
return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value; | return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value; | ||||
@@ -203,7 +223,7 @@ public: | |||||
private: | private: | ||||
std::shared_ptr<HolderBase> m_holder; | std::shared_ptr<HolderBase> m_holder; | ||||
bool m_is_string = false; | |||||
Type m_type = NONE_SUPPORT; | |||||
}; | }; | ||||
/*********************** special tensor function ***************/ | /*********************** special tensor function ***************/ | ||||
@@ -127,7 +127,8 @@ int LITE_register_parse_info_func( | |||||
separate_config_map["device_id"] = device_id; | separate_config_map["device_id"] = device_id; | ||||
} | } | ||||
if (nr_threads != 1) { | if (nr_threads != 1) { | ||||
separate_config_map["nr_threads"] = nr_threads; | |||||
separate_config_map["nr_threads"] = | |||||
static_cast<uint32_t>(nr_threads); | |||||
} | } | ||||
if (is_cpu_inplace_mode != false) { | if (is_cpu_inplace_mode != false) { | ||||
separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode; | separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode; | ||||
@@ -352,19 +352,19 @@ void NetworkImplDft::load_model( | |||||
//! config some flag get from json config file | //! config some flag get from json config file | ||||
if (separate_config_map.find("device_id") != separate_config_map.end()) { | if (separate_config_map.find("device_id") != separate_config_map.end()) { | ||||
set_device_id(separate_config_map["device_id"].unsafe_cast<int>()); | |||||
set_device_id(separate_config_map["device_id"].safe_cast<int>()); | |||||
} | } | ||||
if (separate_config_map.find("number_threads") != separate_config_map.end() && | if (separate_config_map.find("number_threads") != separate_config_map.end() && | ||||
separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) { | |||||
separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) { | |||||
set_cpu_threads_number( | set_cpu_threads_number( | ||||
separate_config_map["number_threads"].unsafe_cast<size_t>()); | |||||
separate_config_map["number_threads"].safe_cast<uint32_t>()); | |||||
} | } | ||||
if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() && | if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() && | ||||
separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) { | |||||
separate_config_map["enable_inplace_model"].safe_cast<bool>()) { | |||||
set_cpu_inplace_mode(); | set_cpu_inplace_mode(); | ||||
} | } | ||||
if (separate_config_map.find("use_tensorrt") != separate_config_map.end() && | if (separate_config_map.find("use_tensorrt") != separate_config_map.end() && | ||||
separate_config_map["use_tensorrt"].unsafe_cast<bool>()) { | |||||
separate_config_map["use_tensorrt"].safe_cast<bool>()) { | |||||
use_tensorrt(); | use_tensorrt(); | ||||
} | } | ||||
@@ -84,7 +84,7 @@ bool default_parse_info( | |||||
} | } | ||||
if (device_json.contains("number_threads")) { | if (device_json.contains("number_threads")) { | ||||
separate_config_map["number_threads"] = | separate_config_map["number_threads"] = | ||||
static_cast<size_t>(device_json["number_threads"]); | |||||
static_cast<uint32_t>(device_json["number_threads"]); | |||||
} | } | ||||
if (device_json.contains("enable_inplace_model")) { | if (device_json.contains("enable_inplace_model")) { | ||||
separate_config_map["enable_inplace_model"] = | separate_config_map["enable_inplace_model"] = | ||||
@@ -277,10 +277,28 @@ void Tensor::update_from_implement() { | |||||
void LiteAny::type_missmatch(size_t expect, size_t get) const { | void LiteAny::type_missmatch(size_t expect, size_t get) const { | ||||
LITE_THROW(ssprintf( | LITE_THROW(ssprintf( | ||||
"The type store in LiteAny is not match the visit type, type of " | "The type store in LiteAny is not match the visit type, type of " | ||||
"storage length is %zu, type of visit length is %zu.", | |||||
"storage enum is %zu, type of visit enum is %zu.", | |||||
expect, get)); | expect, get)); | ||||
} | } | ||||
namespace lite { | |||||
#define GET_TYPE(ctype, ENUM) \ | |||||
template <> \ | |||||
LiteAny::Type LiteAny::get_type<ctype>() const { \ | |||||
return ENUM; \ | |||||
} | |||||
GET_TYPE(std::string, STRING) | |||||
GET_TYPE(int32_t, INT32) | |||||
GET_TYPE(uint32_t, UINT32) | |||||
GET_TYPE(int8_t, INT8) | |||||
GET_TYPE(uint8_t, UINT8) | |||||
GET_TYPE(int64_t, INT64) | |||||
GET_TYPE(uint64_t, UINT64) | |||||
GET_TYPE(float, FLOAT) | |||||
GET_TYPE(bool, BOOL) | |||||
GET_TYPE(void*, VOID_PTR) | |||||
} // namespace lite | |||||
std::shared_ptr<Tensor> TensorUtils::concat( | std::shared_ptr<Tensor> TensorUtils::concat( | ||||
const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device, | const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device, | ||||
int dst_device_id) { | int dst_device_id) { | ||||