Browse Source

fix(lite): fix rknn error in lite

GitOrigin-RevId: b66aa1bf73
release-1.7
Megvii Engine Team 3 years ago
parent
commit
a2a46b56ef
6 changed files with 59 additions and 20 deletions
  1. +1
    -1
      lite/include/lite/global.h
  2. +31
    -11
      lite/include/lite/tensor.h
  3. +2
    -1
      lite/lite-c/src/global.cpp
  4. +5
    -5
      lite/src/mge/network_impl.cpp
  5. +1
    -1
      lite/src/parse_info/default_parse.h
  6. +19
    -1
      lite/src/tensor.cpp

+ 1
- 1
lite/include/lite/global.h View File

@@ -77,7 +77,7 @@ LITE_API bool update_decryption_or_key(
* other config not inclue in config and networkIO, ParseInfoFunc can fill it
* with the information in json, now support:
* "device_id" : int, default 0
* "number_threads" : size_t, default 1
* "number_threads" : uint32_t, default 1
* "is_inplace_model" : bool, default false
* "use_tensorrt" : bool, default false
*/


+ 31
- 11
lite/include/lite/tensor.h View File

@@ -149,28 +149,42 @@ private:
*/
class LITE_API LiteAny {
public:
enum Type {
STRING = 0,
INT32 = 1,
UINT32 = 2,
UINT8 = 3,
INT8 = 4,
INT64 = 5,
UINT64 = 6,
BOOL = 7,
VOID_PTR = 8,
FLOAT = 9,
NONE_SUPPORT = 10,
};
LiteAny() = default;
template <class T>
LiteAny(T value) : m_holder(new AnyHolder<T>(value)) {
m_is_string = std::is_same<std::string, T>();
m_type = get_type<T>();
}

LiteAny(const LiteAny& any) {
m_holder = any.m_holder->clone();
m_is_string = any.is_string();
m_type = any.m_type;
}
LiteAny& operator=(const LiteAny& any) {
m_holder = any.m_holder->clone();
m_is_string = any.is_string();
m_type = any.m_type;
return *this;
}
bool is_string() const { return m_is_string; }

template <class T>
Type get_type() const;

class HolderBase {
public:
virtual ~HolderBase() = default;
virtual std::shared_ptr<HolderBase> clone() = 0;
virtual size_t type_length() const = 0;
};

template <class T>
@@ -180,7 +194,6 @@ public:
virtual std::shared_ptr<HolderBase> clone() override {
return std::make_shared<AnyHolder>(m_value);
}
virtual size_t type_length() const override { return sizeof(T); }

public:
T m_value;
@@ -188,14 +201,21 @@ public:
//! if type is miss matching, it will throw
void type_missmatch(size_t expect, size_t get) const;

//! only check the storage type and the visit type length, so it's not safe
template <class T>
T unsafe_cast() const {
if (sizeof(T) != m_holder->type_length()) {
type_missmatch(m_holder->type_length(), sizeof(T));
T safe_cast() const {
if (get_type<T>() != m_type) {
type_missmatch(m_type, get_type<T>());
}
return static_cast<LiteAny::AnyHolder<T>*>(m_holder.get())->m_value;
}
template <class T>
bool try_cast() const {
if (get_type<T>() == m_type) {
return true;
} else {
return false;
}
}
//! only check the storage type and the visit type length, so it's not safe
void* cast_void_ptr() const {
return &static_cast<LiteAny::AnyHolder<char>*>(m_holder.get())->m_value;
@@ -203,7 +223,7 @@ public:

private:
std::shared_ptr<HolderBase> m_holder;
bool m_is_string = false;
Type m_type = NONE_SUPPORT;
};

/*********************** special tensor function ***************/


+ 2
- 1
lite/lite-c/src/global.cpp View File

@@ -127,7 +127,8 @@ int LITE_register_parse_info_func(
separate_config_map["device_id"] = device_id;
}
if (nr_threads != 1) {
separate_config_map["nr_threads"] = nr_threads;
separate_config_map["nr_threads"] =
static_cast<uint32_t>(nr_threads);
}
if (is_cpu_inplace_mode != false) {
separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;


+ 5
- 5
lite/src/mge/network_impl.cpp View File

@@ -352,19 +352,19 @@ void NetworkImplDft::load_model(

//! config some flag get from json config file
if (separate_config_map.find("device_id") != separate_config_map.end()) {
set_device_id(separate_config_map["device_id"].unsafe_cast<int>());
set_device_id(separate_config_map["device_id"].safe_cast<int>());
}
if (separate_config_map.find("number_threads") != separate_config_map.end() &&
separate_config_map["number_threads"].unsafe_cast<size_t>() > 1) {
separate_config_map["number_threads"].safe_cast<uint32_t>() > 1) {
set_cpu_threads_number(
separate_config_map["number_threads"].unsafe_cast<size_t>());
separate_config_map["number_threads"].safe_cast<uint32_t>());
}
if (separate_config_map.find("enable_inplace_model") != separate_config_map.end() &&
separate_config_map["enable_inplace_model"].unsafe_cast<bool>()) {
separate_config_map["enable_inplace_model"].safe_cast<bool>()) {
set_cpu_inplace_mode();
}
if (separate_config_map.find("use_tensorrt") != separate_config_map.end() &&
separate_config_map["use_tensorrt"].unsafe_cast<bool>()) {
separate_config_map["use_tensorrt"].safe_cast<bool>()) {
use_tensorrt();
}



+ 1
- 1
lite/src/parse_info/default_parse.h View File

@@ -84,7 +84,7 @@ bool default_parse_info(
}
if (device_json.contains("number_threads")) {
separate_config_map["number_threads"] =
static_cast<size_t>(device_json["number_threads"]);
static_cast<uint32_t>(device_json["number_threads"]);
}
if (device_json.contains("enable_inplace_model")) {
separate_config_map["enable_inplace_model"] =


+ 19
- 1
lite/src/tensor.cpp View File

@@ -277,10 +277,28 @@ void Tensor::update_from_implement() {
void LiteAny::type_missmatch(size_t expect, size_t get) const {
LITE_THROW(ssprintf(
"The type store in LiteAny is not match the visit type, type of "
"storage length is %zu, type of visit length is %zu.",
"storage enum is %zu, type of visit enum is %zu.",
expect, get));
}

namespace lite {
#define GET_TYPE(ctype, ENUM) \
template <> \
LiteAny::Type LiteAny::get_type<ctype>() const { \
return ENUM; \
}
GET_TYPE(std::string, STRING)
GET_TYPE(int32_t, INT32)
GET_TYPE(uint32_t, UINT32)
GET_TYPE(int8_t, INT8)
GET_TYPE(uint8_t, UINT8)
GET_TYPE(int64_t, INT64)
GET_TYPE(uint64_t, UINT64)
GET_TYPE(float, FLOAT)
GET_TYPE(bool, BOOL)
GET_TYPE(void*, VOID_PTR)
} // namespace lite

std::shared_ptr<Tensor> TensorUtils::concat(
const std::vector<Tensor>& tensors, int dim, LiteDeviceType dst_device,
int dst_device_id) {


Loading…
Cancel
Save