You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

global.cpp 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. /**
  2. * \file lite-c/src/global.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/global.h"
  12. #include "common.h"
  13. #include "lite-c/global_c.h"
  14. namespace {
  15. class ErrorMsg {
  16. public:
  17. std::string& get_error_msg() { return error_msg; }
  18. void set_error_msg(const std::string& msg) { error_msg = msg; }
  19. private:
  20. std::string error_msg;
  21. };
  22. static LITE_MUTEX mtx_error;
  23. ErrorMsg& get_global_error() {
  24. static ErrorMsg error_msg;
  25. return error_msg;
  26. }
  27. } // namespace
  28. int LiteHandleException(const std::exception& e) {
  29. LITE_LOCK_GUARD(mtx_error);
  30. get_global_error().set_error_msg(e.what());
  31. return -1;
  32. }
  33. const char* LITE_get_last_error() {
  34. LITE_LOCK_GUARD(mtx_error);
  35. return get_global_error().get_error_msg().c_str();
  36. }
  37. int LITE_get_version(int* major, int* minor, int* patch) {
  38. LITE_ASSERT(major && minor && patch, "The ptr pass to LITE api is null");
  39. lite::get_version(*major, *minor, *patch);
  40. return 0;
  41. }
  42. int LITE_get_device_count(LiteDeviceType device_type, size_t* count) {
  43. LITE_CAPI_BEGIN();
  44. LITE_ASSERT(count, "The ptr pass to LITE api is null");
  45. *count = lite::get_device_count(device_type);
  46. LITE_CAPI_END();
  47. }
  48. int LITE_try_coalesce_all_free_memory() {
  49. LITE_CAPI_BEGIN();
  50. lite::try_coalesce_all_free_memory();
  51. LITE_CAPI_END();
  52. }
  53. int LITE_register_decryption_and_key(
  54. const char* decrypt_name, const LiteDecryptionFunc func,
  55. const uint8_t* key_data, size_t key_size) {
  56. LITE_CAPI_BEGIN();
  57. LITE_ASSERT(decrypt_name && key_data && func, "The ptr pass to LITE api is null");
  58. std::vector<uint8_t> key;
  59. for (size_t i = 0; i < key_size; i++) {
  60. key.push_back(key_data[i]);
  61. }
  62. auto decrypt_func = [func](const void* input_data, size_t input_size,
  63. const std::vector<uint8_t>& key) {
  64. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  65. std::vector<uint8_t> output(size, 0);
  66. func(input_data, input_size, key.data(), key.size(), output.data());
  67. return output;
  68. };
  69. lite::register_decryption_and_key(decrypt_name, decrypt_func, key);
  70. LITE_CAPI_END();
  71. }
  72. int LITE_update_decryption_or_key(
  73. const char* decrypt_name, const LiteDecryptionFunc func,
  74. const uint8_t* key_data, size_t key_size) {
  75. LITE_CAPI_BEGIN();
  76. std::vector<uint8_t> key;
  77. for (size_t i = 0; i < key_size; i++) {
  78. key.push_back(key_data[i]);
  79. }
  80. lite::DecryptionFunc decrypt_func = nullptr;
  81. if (func) {
  82. decrypt_func = [func](const void* input_data, size_t input_size,
  83. const std::vector<uint8_t>& key) {
  84. auto size = func(input_data, input_size, key.data(), key.size(), nullptr);
  85. std::vector<uint8_t> output(size, 0);
  86. func(input_data, input_size, key.data(), key.size(), output.data());
  87. return output;
  88. };
  89. }
  90. lite::update_decryption_or_key(decrypt_name, decrypt_func, key);
  91. LITE_CAPI_END();
  92. }
  93. int LITE_register_parse_info_func(
  94. const char* info_type, const LiteParseInfoFunc parse_func) {
  95. LITE_CAPI_BEGIN();
  96. LITE_ASSERT(info_type && parse_func, "The ptr pass to LITE api is null");
  97. auto lite_func =
  98. [parse_func](
  99. const void* info_data, size_t info_size,
  100. const std::string model_name, lite::Config& config,
  101. lite::NetworkIO& network_io,
  102. std::unordered_map<std::string, lite::LiteAny>& separate_config_map,
  103. std::string& extra_info) {
  104. LITE_MARK_USED_VAR(extra_info);
  105. size_t nr_threads = 1;
  106. int device_id = 0, is_cpu_inplace_mode = false, use_tensorrt = false;
  107. LiteNetworkIO c_io;
  108. LiteConfig c_config;
  109. auto ret = parse_func(
  110. info_data, info_size, model_name.c_str(), &c_config, &c_io,
  111. &device_id, &nr_threads, &is_cpu_inplace_mode, &use_tensorrt);
  112. config = convert_to_lite_config(c_config);
  113. network_io = convert_to_lite_io(c_io);
  114. if (device_id != 0) {
  115. separate_config_map["device_id"] = device_id;
  116. }
  117. if (nr_threads != 1) {
  118. separate_config_map["nr_threads"] =
  119. static_cast<uint32_t>(nr_threads);
  120. }
  121. if (is_cpu_inplace_mode != false) {
  122. separate_config_map["is_inplace_mode"] = is_cpu_inplace_mode;
  123. }
  124. if (use_tensorrt != false) {
  125. separate_config_map["use_tensorrt"] = use_tensorrt;
  126. }
  127. return ret;
  128. };
  129. lite::register_parse_info_func(info_type, lite_func);
  130. LITE_CAPI_END();
  131. }
  132. int LITE_set_loader_lib_path(const char* loader_path) {
  133. LITE_CAPI_BEGIN();
  134. LITE_ASSERT(loader_path, "The ptr pass to LITE api is null");
  135. lite::set_loader_lib_path(loader_path);
  136. LITE_CAPI_END();
  137. }
  138. int LITE_set_persistent_cache(const char* cache_path, int always_sync) {
  139. LITE_CAPI_BEGIN();
  140. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  141. lite::set_persistent_cache(cache_path, always_sync);
  142. LITE_CAPI_END();
  143. }
  144. int LITE_set_tensor_rt_cache(const char* cache_path) {
  145. LITE_CAPI_BEGIN();
  146. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  147. lite::set_tensor_rt_cache(cache_path);
  148. LITE_CAPI_END();
  149. }
  150. int LITE_set_log_level(LiteLogLevel level) {
  151. LITE_CAPI_BEGIN();
  152. lite::set_log_level(level);
  153. LITE_CAPI_END();
  154. }
  155. int LITE_get_log_level(LiteLogLevel* level) {
  156. LITE_CAPI_BEGIN();
  157. LITE_ASSERT(level, "The ptr pass to LITE api is null");
  158. *level = lite::get_log_level();
  159. LITE_CAPI_END();
  160. }
  161. int LITE_dump_persistent_cache(const char* cache_path) {
  162. LITE_CAPI_BEGIN();
  163. LITE_ASSERT(cache_path, "The ptr pass to LITE api is null");
  164. lite::dump_persistent_cache(cache_path);
  165. LITE_CAPI_END();
  166. }
  167. int LITE_dump_tensor_rt_cache() {
  168. LITE_CAPI_BEGIN();
  169. lite::dump_tensor_rt_cache();
  170. LITE_CAPI_END();
  171. }
  172. int LITE_register_memory_pair(
  173. void* vir_ptr, void* phy_ptr, size_t length, LiteDeviceType device,
  174. LiteBackend backend) {
  175. LITE_CAPI_BEGIN();
  176. lite::register_memory_pair(vir_ptr, phy_ptr, length, device, backend);
  177. LITE_CAPI_END();
  178. }
  179. int LITE_clear_memory_pair(
  180. void* phy_ptr, void* vir_ptr, LiteDeviceType device, LiteBackend backend) {
  181. LITE_CAPI_BEGIN();
  182. lite::clear_memory_pair(vir_ptr, phy_ptr, device, backend);
  183. LITE_CAPI_END();
  184. }
  185. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}