@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key( | |||
* other config not inclue in config and networkIO, ParseInfoFunc can fill it | |||
* with the information in json, now support: | |||
* "device_id" : int, default 0 | |||
* "number_threads" : size_t, default 1 | |||
* "number_threads" : uint32_t, default 1 | |||
* "is_inplace_model" : bool, default false | |||
* "use_tensorrt" : bool, default false | |||
*/ | |||
@@ -149,28 +149,42 @@ private: | |||
*/ | |||
class LITE_API LiteAny { | |||
public: | |||
enum Type { | |||
STRING = 0, | |||
INT32 = 1, | |||
UINT32 = 2, | |||
UINT8 = 3, | |||
INT8 = 4, | |||
INT64 = 5, | |||
UINT64 = 6, | |||
BOOL = 7, | |||
VOID_PTR = 8, | |||
FLOAT = 9, | |||
NONE_SUPPORT = 10, | |||
}; | |||
LiteAny() = default; | |||
template <class T> | |||
LiteAny(T value) : m_holder(new AnyHolder<T>(value)) { | |||
m_is_string = std::is_same<std::string, T>(); | |||
m_type = get_type<T>(); | |||
} | |||
LiteAny(const LiteAny& any) { | |||
m_holder = any.m_holder->clone(); | |||
m_is_string = any.is_string(); | |||
m_type = any.m_type; | |||
} | |||
LiteAny& operator=(const LiteAny& any) { | |||
m_holder = any.m_holder->clone(); | |||
m_is_string = any.is_string(); | |||
m_type = any.m_type; | |||
return *this; | |||
} | |||
bool is_string() const { return m_is_string; } | |||
template <class T> | |||
Type get_type() const; | |||
class HolderBase { | |||
public: | |||
virtual ~HolderBase() = default; | |||
virtual std::shared_ptr<HolderBase> clone() = 0; | |||
virtual size_t type_length() const = 0; | |||
}; | |||
template <class T> | |||
@@ -180,7 +194,6 @@ public: | |||
virtual std::shared_ptr<HolderBase> clone() override { | |||
return std::make_shared<AnyHolder>(m_value); | |||
} | |||
virtual size_t type_length() const override { return sizeof(T); } | |||
public: | |||
T m_value; | |||
@@ -188,14 +201,21 @@ public: | |||
//! if type is miss matching, it will throw | |||
void type_missmatch(size_t expect, size_t get) const; | |||
//! only check the storage type and the visit type length, so it's not safe | |||
template <class T> | |||
T unsafe_cast() const { | |||
if (sizeof(T) != m_holder->type_length()) { | |||
type_missmatch(m_holder->type_length(), sizeof(T)); | |||
T safe_cast() const { | |||
if (get_type<T>() != m_type) { | |||
type_missmatch(m_type, get_type<T>()); | |||
} | |||
return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value; | |||
} | |||
template <class T> | |||
bool try_cast() const { | |||
if (get_type<T>() == m_type) { | |||
return true; | |||
} else { | |||
return false; | |||
} | |||
} | |||
//! only check the storage type and the visit type length, so it's not safe | |||
void* cast_void_ptr() const { | |||
return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value; | |||
@@ -203,7 +223,7 @@ public: | |||
private: | |||
std::shared_ptr<HolderBase> m_holder; | |||
bool m_is_string = false; | |||
Type m_type = NONE_SUPPORT; | |||
}; | |||
/*********************** special tensor function ***************/ | |||
@@ -127,7 +127,8 @@ int LITE_register_parse_info_func( | |||
separate_config_map["device_id"] = device_id; | |||
} | |||
if (nr_threads != 1) { | |||
separate_config_map["nr_threads"] = nr_threads; | |||
separate_config_map["nr_threads"] = | |||
static_cast<uint32_t>(nr_threads); | |||
} | |||
if (is_cpu_inplace_mode != false) { | |||
separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode; | |||
@@ -352,19 +352,19 @@ void NetworkImplDft::load_model( | |||
//! config some flag get from json config file | |||
if (separate_config_map.find("device_id") != separate_config_map.end()) { | |||
set_device_id(separate_config_map["device_id"].unsafe_cast<int>()); | |||
set_device_id(separate_config_map["device_id"].safe_cast<int>()); | |||
} | |||
if (separate_config_map.find("number_threads") != separate_config_map.end() && | |||
separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) { | |||
separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) { | |||
set_cpu_threads_number( | |||
separate_config_map["number_threads"].unsafe_cast<size_t>()); | |||
separate_config_map["number_threads"].safe_cast<uint32_t>()); | |||
} | |||
if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() && | |||
separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) { | |||
separate_config_map["enable_inplace_model"].safe_cast<bool>()) { | |||
set_cpu_inplace_mode(); | |||
} | |||
if (separate_config_map.find("use_tensorrt") != separate_config_map.end() && | |||
separate_config_map["use_tensorrt"].unsafe_cast<bool>()) { | |||
separate_config_map["use_tensorrt"].safe_cast<bool>()) { | |||
use_tensorrt(); | |||
} | |||
@@ -84,7 +84,7 @@ bool default_parse_info( | |||
} | |||
if (device_json.contains("number_threads")) { | |||
separate_config_map["number_threads"] = | |||
static_cast<size_t>(device_json["number_threads"]); | |||
static_cast<uint32_t>(device_json["number_threads"]); | |||
} | |||
if (device_json.contains("enable_inplace_model")) { | |||
separate_config_map["enable_inplace_model"] = | |||
@@ -277,10 +277,28 @@ void Tensor::update_from_implement() { | |||
void LiteAny::type_missmatch(size_t expect, size_t get) const { | |||
LITE_THROW(ssprintf( | |||
"The type store in LiteAny is not match the visit type, type of " | |||
"storage length is %zu, type of visit length is %zu.", | |||
"storage enum is %zu, type of visit enum is %zu.", | |||
expect, get)); | |||
} | |||
namespace lite { | |||
#define GET_TYPE(ctype, ENUM) \ | |||
template <> \ | |||
LiteAny::Type LiteAny::get_type<ctype>() const { \ | |||
return ENUM; \ | |||
} | |||
GET_TYPE(std::string, STRING) | |||
GET_TYPE(int32_t, INT32) | |||
GET_TYPE(uint32_t, UINT32) | |||
GET_TYPE(int8_t, INT8) | |||
GET_TYPE(uint8_t, UINT8) | |||
GET_TYPE(int64_t, INT64) | |||
GET_TYPE(uint64_t, UINT64) | |||
GET_TYPE(float, FLOAT) | |||
GET_TYPE(bool, BOOL) | |||
GET_TYPE(void*, VOID_PTR) | |||
} // namespace lite | |||
std::shared_ptr<Tensor> TensorUtils::concat( | |||
const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device, | |||
int dst_device_id) { | |||