You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global.cpp 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /**
  2. * \file lite-c/src/global.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/global.h"
  12. #include "common.h"
  13. #include "lite-c/global_c.h"
  14. namespace {
  15. class ErrorMsg {
  16. public:
  17. std::string& get_error_msg() { return error_msg; }
  18. ErrorCode get_error_code() { return error_code; }
  19. void set_error_msg(const std::string& msg, ErrorCode code) {
  20. error_msg = msg + ", Error Code: " + std::to_string(code);
  21. error_code = code;
  22. }
  23. void clear_error() {
  24. error_code = ErrorCode::OK;
  25. error_msg.clear();
  26. }
  27. private:
  28. std::string error_msg;
  29. ErrorCode error_code;
  30. };
  31. static LITE_MUTEX mtx_error;
  32. ErrorMsg& get_global_error() {
  33. static ErrorMsg error_msg;
  34. return error_msg;
  35. }
  36. } // namespace
  37. int LiteHandleException(const std::exception& e) {
  38. LITE_LOCK_GUARD(mtx_error);
  39. get_global_error().set_error_msg(e.what(), ErrorCode::LITE_INTERNAL_ERROR);
  40. return -1;
  41. }
  42. ErrorCode LITE_get_last_error_code() {
  43. LITE_LOCK_GUARD(mtx_error);
  44. return get_global_error().get_error_code();
  45. }
  46. void LITE_clear_last_error() {
  47. LITE_LOCK_GUARD(mtx_error);
  48. get_global_error().clear_error();
  49. }
  50. const char* LITE_get_last_error() {
  51. LITE_LOCK_GUARD(mtx_error);
  52. return get_global_error().get_error_msg().c_str();
  53. }
  54. int LITE_get_version(int* major, int* minor, int* patch) {
  55. LITE_ASSERT(major && minor && patch, "The ptr pass to LITE api is null");
  56. lite::get_version(*major, *minor, *patch);
  57. return 0;
  58. }
  59. int LITE_get_device_count(LiteDeviceType device_type, size_t* count) {
  60. LITE_CAPI_BEGIN();
  61. LITE_ASSERT(count, "The ptr pass to LITE api is null");
  62. *count = lite::get_device_count(device_type);
  63. LITE_CAPI_END();
  64. }
  65. int LITE_try_coalesce_all_free_memory() {
  66. LITE_CAPI_BEGIN();
  67. lite::try_coalesce_all_free_memory();
  68. LITE_CAPI_END();
  69. }
  70. int LITE_register_decryption_and_key(
  71. const char* decrypt_name, const LiteDecryptionFunc func,
  72. const uint8_t* key_data, size_t key_size) {
  73. LITE_CAPI_BEGIN();
  74. LITE_ASSERT(decrypt_name && key_data && func, "The ptr pass to LITE api is null");
  75. std::vector<uint8_t> key;
  76. for (size_t i = 0; i < key_size; i++) {
  77. key.push_back(key_data[i]);
  78. }
  79. auto decrypt_func = [func](const void* input_data, size_t input_size,
  80. const std::vector<uint8_t>& key) {
  81. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  82. std::vector<uint8_t> output(size, 0);
  83. func(input_data, input_size, key.data(), key.size(), output.data());
  84. return output;
  85. };
  86. lite::register_decryption_and_key(decrypt_name, decrypt_func, key);
  87. LITE_CAPI_END();
  88. }
  89. int LITE_update_decryption_or_key(
  90. const char* decrypt_name, const LiteDecryptionFunc func,
  91. const uint8_t* key_data, size_t key_size) {
  92. LITE_CAPI_BEGIN();
  93. std::vector<uint8_t> key;
  94. for (size_t i = 0; i < key_size; i++) {
  95. key.push_back(key_data[i]);
  96. }
  97. lite::DecryptionFunc decrypt_func = nullptr;
  98. if (func) {
  99. decrypt_func = [func](const void* input_data, size_t input_size,
  100. const std::vector<uint8_t>& key) {
  101. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  102. std::vector<uint8_t> output(size, 0);
  103. func(input_data, input_size, key.data(), key.size(), output.data());
  104. return output;
  105. };
  106. }
  107. lite::update_decryption_or_key(decrypt_name, decrypt_func, key);
  108. LITE_CAPI_END();
  109. }
  110. int LITE_register_parse_info_func(
  111. const char* info_type, const LiteParseInfoFunc parse_func) {
  112. LITE_CAPI_BEGIN();
  113. LITE_ASSERT(info_type && parse_func, "The ptr pass to LITE api is null");
  114. auto lite_func =
  115. [parse_func](
  116. const void* info_data, size_t info_size,
  117. const std::string model_name, lite::Config& config,
  118. lite::NetworkIO& network_io,
  119. std::unordered_map<std::string, lite::LiteAny>& separate_config_map,
  120. std::string& extra_info) {
  121. LITE_MARK_USED_VAR(extra_info);
  122. size_t nr_threads = 1;
  123. int device_id = 0, is_cpu_inplace_mode = false, use_tensorrt = false;
  124. LiteNetworkIO c_io;
  125. LiteConfig c_config;
  126. auto ret = parse_func(
  127. info_data, info_size, model_name.c_str(), &c_config, &c_io,
  128. &device_id, &nr_threads, &is_cpu_inplace_mode, &use_tensorrt);
  129. config = convert_to_lite_config(c_config);
  130. network_io = convert_to_lite_io(c_io);
  131. if (device_id != 0) {
  132. separate_config_map["device_id"] = device_id;
  133. }
  134. if (nr_threads != 1) {
  135. separate_config_map["nr_threads"] =
  136. static_cast<uint32_t>(nr_threads);
  137. }
  138. if (is_cpu_inplace_mode != false) {
  139. separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
  140. }
  141. if (use_tensorrt != false) {
  142. separate_config_map["use_tensorrt"] = use_tensorrt;
  143. }
  144. return ret;
  145. };
  146. lite::register_parse_info_func(info_type, lite_func);
  147. LITE_CAPI_END();
  148. }
  149. int LITE_set_loader_lib_path(const char* loader_path) {
  150. LITE_CAPI_BEGIN();
  151. LITE_ASSERT(loader_path, "The ptr pass to LITE api is null");
  152. lite::set_loader_lib_path(loader_path);
  153. LITE_CAPI_END();
  154. }
  155. int LITE_set_persistent_cache(const char* cache_path, int always_sync) {
  156. LITE_CAPI_BEGIN();
  157. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  158. lite::set_persistent_cache(cache_path, always_sync);
  159. LITE_CAPI_END();
  160. }
  161. int LITE_set_tensor_rt_cache(const char* cache_path) {
  162. LITE_CAPI_BEGIN();
  163. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  164. lite::set_tensor_rt_cache(cache_path);
  165. LITE_CAPI_END();
  166. }
  167. int LITE_set_log_level(LiteLogLevel level) {
  168. LITE_CAPI_BEGIN();
  169. lite::set_log_level(level);
  170. LITE_CAPI_END();
  171. }
  172. int LITE_get_log_level(LiteLogLevel* level) {
  173. LITE_CAPI_BEGIN();
  174. LITE_ASSERT(level, "The ptr pass to LITE api is null");
  175. *level = lite::get_log_level();
  176. LITE_CAPI_END();
  177. }
  178. int LITE_dump_persistent_cache(const char* cache_path) {
  179. LITE_CAPI_BEGIN();
  180. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  181. lite::dump_persistent_cache(cache_path);
  182. LITE_CAPI_END();
  183. }
  184. int LITE_dump_tensor_rt_cache() {
  185. LITE_CAPI_BEGIN();
  186. lite::dump_tensor_rt_cache();
  187. LITE_CAPI_END();
  188. }
  189. int LITE_register_memory_pair(
  190. void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
  191. LiteBackend backend) {
  192. LITE_CAPI_BEGIN();
  193. lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
  194. LITE_CAPI_END();
  195. }
  196. int LITE_clear_memory_pair(
  197. void* vir_ptr, void* phy_ptr, LiteDeviceType device, LiteBackend backend) {
  198. LITE_CAPI_BEGIN();
  199. lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
  200. LITE_CAPI_END();
  201. }
  202. int LITE_lookup_physic_ptr(
  203. void* vir_ptr, void** phy_ptr, LiteDeviceType device, LiteBackend backend) {
  204. LITE_CAPI_BEGIN();
  205. LITE_ASSERT(vir_ptr && phy_ptr, "The ptr pass to vir and phy is nullptr");
  206. *phy_ptr = lite::lookup_physic_ptr(vir_ptr, device, backend);
  207. LITE_CAPI_END();
  208. }
  209. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}