You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global.cpp 8.0 kB


  1. #include "lite/global.h"
  2. #include "common.h"
  3. #include "ipc_helper.h"
  4. #include "lite-c/global_c.h"
  5. namespace {
  6. class ErrorMsg {
  7. public:
  8. std::string& get_error_msg() { return error_msg; }
  9. ErrorCode get_error_code() { return error_code; }
  10. void set_error_msg(const std::string& msg, ErrorCode code) {
  11. error_msg = msg + ", Error Code: " + std::to_string(code);
  12. error_code = code;
  13. }
  14. void clear_error() {
  15. error_code = ErrorCode::OK;
  16. error_msg.clear();
  17. }
  18. private:
  19. std::string error_msg;
  20. ErrorCode error_code;
  21. };
  22. static LITE_MUTEX mtx_error;
  23. ErrorMsg& get_global_error() {
  24. static ErrorMsg error_msg;
  25. return error_msg;
  26. }
  27. } // namespace
  28. int LiteHandleException(const std::exception& e) {
  29. LITE_LOCK_GUARD(mtx_error);
  30. get_global_error().set_error_msg(e.what(), ErrorCode::LITE_INTERNAL_ERROR);
  31. return -1;
  32. }
  33. ErrorCode LITE_get_last_error_code() {
  34. LITE_LOCK_GUARD(mtx_error);
  35. return get_global_error().get_error_code();
  36. }
  37. void LITE_clear_last_error() {
  38. LITE_LOCK_GUARD(mtx_error);
  39. get_global_error().clear_error();
  40. }
  41. const char* LITE_get_last_error() {
  42. LITE_LOCK_GUARD(mtx_error);
  43. if (ipc_imp::is_server()) {
  44. return get_global_error().get_error_msg().c_str();
  45. } else {
  46. void* raw_shm_ptr = IPC_INSTACE().get_shm_ptr(nullptr);
  47. IPC_HELP_REMOTE_CALL(raw_shm_ptr, ipc::RemoteFuncId::LITE_GET_LAST_ERROR);
  48. char* ret_ptr = static_cast<char*>(raw_shm_ptr);
  49. return ret_ptr;
  50. }
  51. }
  52. int LITE_is_enable_ipc_debug_mode() {
  53. return ipc::IpcHelper::is_enable_fork_debug_mode();
  54. }
  55. void LITE_enable_lite_ipc_debug() {
  56. ipc_imp::enable_lite_ipc_debug();
  57. }
  58. int LITE_get_version(int* major, int* minor, int* patch) {
  59. LITE_ASSERT(major && minor && patch, "The ptr pass to LITE api is null");
  60. lite::get_version(*major, *minor, *patch);
  61. return 0;
  62. }
  63. int LITE_get_device_count(LiteDeviceType device_type, size_t* count) {
  64. LITE_CAPI_BEGIN();
  65. LITE_ASSERT(count, "The ptr pass to LITE api is null");
  66. *count = lite::get_device_count(device_type);
  67. LITE_CAPI_END();
  68. }
  69. int LITE_try_coalesce_all_free_memory() {
  70. LITE_CAPI_BEGIN();
  71. lite::try_coalesce_all_free_memory();
  72. LITE_CAPI_END();
  73. }
  74. int LITE_register_decryption_and_key(
  75. const char* decrypt_name, const LiteDecryptionFunc func,
  76. const uint8_t* key_data, size_t key_size) {
  77. LITE_CAPI_BEGIN();
  78. LITE_ASSERT(decrypt_name && key_data && func, "The ptr pass to LITE api is null");
  79. std::vector<uint8_t> key;
  80. for (size_t i = 0; i < key_size; i++) {
  81. key.push_back(key_data[i]);
  82. }
  83. auto decrypt_func = [func](const void* input_data, size_t input_size,
  84. const std::vector<uint8_t>& key) {
  85. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  86. std::vector<uint8_t> output(size, 0);
  87. func(input_data, input_size, key.data(), key.size(), output.data());
  88. return output;
  89. };
  90. lite::register_decryption_and_key(decrypt_name, decrypt_func, key);
  91. LITE_CAPI_END();
  92. }
  93. int LITE_update_decryption_or_key(
  94. const char* decrypt_name, const LiteDecryptionFunc func,
  95. const uint8_t* key_data, size_t key_size) {
  96. LITE_CAPI_BEGIN();
  97. std::vector<uint8_t> key;
  98. for (size_t i = 0; i < key_size; i++) {
  99. key.push_back(key_data[i]);
  100. }
  101. lite::DecryptionFunc decrypt_func = nullptr;
  102. if (func) {
  103. decrypt_func = [func](const void* input_data, size_t input_size,
  104. const std::vector<uint8_t>& key) {
  105. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  106. std::vector<uint8_t> output(size, 0);
  107. func(input_data, input_size, key.data(), key.size(), output.data());
  108. return output;
  109. };
  110. }
  111. lite::update_decryption_or_key(decrypt_name, decrypt_func, key);
  112. LITE_CAPI_END();
  113. }
  114. int LITE_register_parse_info_func(
  115. const char* info_type, const LiteParseInfoFunc parse_func) {
  116. LITE_CAPI_BEGIN();
  117. LITE_ASSERT(info_type && parse_func, "The ptr pass to LITE api is null");
  118. auto lite_func =
  119. [parse_func](
  120. const void* info_data, size_t info_size,
  121. const std::string model_name, lite::Config& config,
  122. lite::NetworkIO& network_io,
  123. std::unordered_map<std::string, lite::LiteAny>& separate_config_map,
  124. std::string& extra_info) {
  125. LITE_MARK_USED_VAR(extra_info);
  126. size_t nr_threads = 1;
  127. int device_id = 0, is_cpu_inplace_mode = false, use_tensorrt = false;
  128. LiteNetworkIO c_io;
  129. LiteConfig c_config;
  130. auto ret = parse_func(
  131. info_data, info_size, model_name.c_str(), &c_config, &c_io,
  132. &device_id, &nr_threads, &is_cpu_inplace_mode, &use_tensorrt);
  133. config = convert_to_lite_config(c_config);
  134. network_io = convert_to_lite_io(c_io);
  135. if (device_id != 0) {
  136. separate_config_map["device_id"] = device_id;
  137. }
  138. if (nr_threads != 1) {
  139. separate_config_map["nr_threads"] =
  140. static_cast<uint32_t>(nr_threads);
  141. }
  142. if (is_cpu_inplace_mode != false) {
  143. separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
  144. }
  145. if (use_tensorrt != false) {
  146. separate_config_map["use_tensorrt"] = use_tensorrt;
  147. }
  148. return ret;
  149. };
  150. lite::register_parse_info_func(info_type, lite_func);
  151. LITE_CAPI_END();
  152. }
  153. int LITE_set_loader_lib_path(const char* loader_path) {
  154. LITE_CAPI_BEGIN();
  155. LITE_ASSERT(loader_path, "The ptr pass to LITE api is null");
  156. lite::set_loader_lib_path(loader_path);
  157. LITE_CAPI_END();
  158. }
  159. int LITE_set_persistent_cache(const char* cache_path, int always_sync) {
  160. LITE_CAPI_BEGIN();
  161. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  162. lite::set_persistent_cache(cache_path, always_sync);
  163. LITE_CAPI_END();
  164. }
  165. int LITE_set_tensor_rt_cache(const char* cache_path) {
  166. LITE_CAPI_BEGIN();
  167. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  168. lite::set_tensor_rt_cache(cache_path);
  169. LITE_CAPI_END();
  170. }
  171. int LITE_set_log_level(LiteLogLevel level) {
  172. LITE_CAPI_BEGIN();
  173. lite::set_log_level(level);
  174. LITE_CAPI_END();
  175. }
  176. int LITE_get_log_level(LiteLogLevel* level) {
  177. LITE_CAPI_BEGIN();
  178. LITE_ASSERT(level, "The ptr pass to LITE api is null");
  179. *level = lite::get_log_level();
  180. LITE_CAPI_END();
  181. }
  182. int LITE_dump_persistent_cache(const char* cache_path) {
  183. LITE_CAPI_BEGIN();
  184. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  185. lite::dump_persistent_cache(cache_path);
  186. LITE_CAPI_END();
  187. }
  188. int LITE_dump_tensor_rt_cache() {
  189. LITE_CAPI_BEGIN();
  190. lite::dump_tensor_rt_cache();
  191. LITE_CAPI_END();
  192. }
  193. int LITE_register_memory_pair(
  194. void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
  195. LiteBackend backend) {
  196. LITE_CAPI_BEGIN();
  197. lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
  198. LITE_CAPI_END();
  199. }
  200. int LITE_clear_memory_pair(
  201. void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
  202. LITE_CAPI_BEGIN();
  203. lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
  204. LITE_CAPI_END();
  205. }
  206. int LITE_lookup_physic_ptr(
  207. void* vir_ptr, void** phy_ptr, LiteDeviceType device, LiteBackend backend) {
  208. LITE_CAPI_BEGIN();
  209. LITE_ASSERT(vir_ptr && phy_ptr, "The ptr pass to vir and phy is nullptr");
  210. *phy_ptr = lite::lookup_physic_ptr(vir_ptr, device, backend);
  211. LITE_CAPI_END();
  212. }
  213. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}