You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

lite_c_interface.cpp 7.7 kB

(line-number gutter 1–216 from the repository web viewer — not part of the source file)
#include "example.h"
#include "helper.h"
#if LITE_BUILD_WITH_MGE
#include "lite-c/global_c.h"
#include "lite-c/network_c.h"
#include "lite-c/tensor_c.h"

#include <atomic>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <thread>
  8. #define LITE_CAPI_CHECK(_expr) \
  9. do { \
  10. int _ret = (_expr); \
  11. if (_ret) { \
  12. LITE_EXAMPLE_THROW(LITE_get_last_error()); \
  13. } \
  14. } while (0)
  15. bool basic_c_interface(const lite::example::Args& args) {
  16. std::string network_path = args.model_path;
  17. std::string input_path = args.input_path;
  18. //! read input data to lite::tensor
  19. auto src_tensor = lite::example::parse_npy(input_path);
  20. void* src_ptr = src_tensor->get_memory_ptr();
  21. //! create and load the network
  22. LiteNetwork c_network;
  23. LITE_CAPI_CHECK(
  24. LITE_make_network(&c_network, *default_config(), *default_network_io()));
  25. LITE_CAPI_CHECK(LITE_load_model_from_path(c_network, network_path.c_str()));
  26. //! set input data to input tensor
  27. LiteTensor c_input_tensor;
  28. LITE_CAPI_CHECK(LITE_get_io_tensor(c_network, "data", LITE_IO, &c_input_tensor));
  29. void* dst_ptr;
  30. size_t length_in_byte;
  31. LITE_CAPI_CHECK(
  32. LITE_get_tensor_total_size_in_byte(c_input_tensor, &length_in_byte));
  33. LITE_CAPI_CHECK(LITE_get_tensor_memory(c_input_tensor, &dst_ptr));
  34. //! copy or forward data to network
  35. memcpy(dst_ptr, src_ptr, length_in_byte);
  36. //! forward
  37. LITE_CAPI_CHECK(LITE_forward(c_network));
  38. LITE_CAPI_CHECK(LITE_wait(c_network));
  39. //! get the output data or read tensor data
  40. const char* output_name;
  41. LiteTensor c_output_tensor;
  42. //! get the first output tensor name
  43. LITE_CAPI_CHECK(LITE_get_output_name(c_network, 0, &output_name));
  44. LITE_CAPI_CHECK(
  45. LITE_get_io_tensor(c_network, output_name, LITE_IO, &c_output_tensor));
  46. void* output_ptr;
  47. size_t length_output_in_byte;
  48. LITE_CAPI_CHECK(LITE_get_tensor_memory(c_output_tensor, &output_ptr));
  49. LITE_CAPI_CHECK(LITE_get_tensor_total_size_in_byte(
  50. c_output_tensor, &length_output_in_byte));
  51. size_t out_length = length_output_in_byte / sizeof(float);
  52. printf("length=%zu\n", out_length);
  53. float max = -1.0f;
  54. float sum = 0.0f;
  55. for (size_t i = 0; i < out_length; i++) {
  56. float data = static_cast<float*>(output_ptr)[i];
  57. sum += data;
  58. if (max < data)
  59. max = data;
  60. }
  61. printf("max=%e, sum=%e\n", max, sum);
  62. return true;
  63. }
  64. bool device_io_c_interface(const lite::example::Args& args) {
  65. std::string network_path = args.model_path;
  66. std::string input_path = args.input_path;
  67. //! read input data to lite::tensor
  68. auto src_tensor = lite::example::parse_npy(input_path);
  69. void* src_ptr = src_tensor->get_memory_ptr();
  70. size_t length_read_in = src_tensor->get_tensor_total_size_in_byte();
  71. //! create and load the network
  72. LiteNetwork c_network;
  73. LITE_CAPI_CHECK(
  74. LITE_make_network(&c_network, *default_config(), *default_network_io()));
  75. LITE_CAPI_CHECK(LITE_load_model_from_path(c_network, network_path.c_str()));
  76. //! set input data to input tensor
  77. LiteTensor c_input_tensor;
  78. size_t length_tensor_in;
  79. LITE_CAPI_CHECK(LITE_get_io_tensor(c_network, "data", LITE_IO, &c_input_tensor));
  80. LITE_CAPI_CHECK(
  81. LITE_get_tensor_total_size_in_byte(c_input_tensor, &length_tensor_in));
  82. if (length_read_in != length_tensor_in) {
  83. LITE_EXAMPLE_THROW(
  84. "The input data size is not match the network input tensro "
  85. "size,\n");
  86. }
  87. LITE_CAPI_CHECK(
  88. LITE_reset_tensor_memory(c_input_tensor, src_ptr, length_tensor_in));
  89. //! reset the output tensor memory with user allocated memory
  90. size_t out_length = 1000;
  91. LiteLayout output_layout{{1, 1000}, 2, LiteDataType::LITE_FLOAT};
  92. std::shared_ptr<float> ptr(new float[out_length], [](float* ptr) { delete[] ptr; });
  93. const char* output_name;
  94. LiteTensor c_output_tensor;
  95. LITE_CAPI_CHECK(LITE_get_output_name(c_network, 0, &output_name));
  96. LITE_CAPI_CHECK(
  97. LITE_get_io_tensor(c_network, output_name, LITE_IO, &c_output_tensor));
  98. LITE_CAPI_CHECK(LITE_reset_tensor(c_output_tensor, output_layout, ptr.get()));
  99. //! forward
  100. LITE_CAPI_CHECK(LITE_forward(c_network));
  101. LITE_CAPI_CHECK(LITE_wait(c_network));
  102. printf("length=%zu\n", out_length);
  103. float max = -1.0f;
  104. float sum = 0.0f;
  105. void* out_data = ptr.get();
  106. for (size_t i = 0; i < out_length; i++) {
  107. float data = static_cast<float*>(out_data)[i];
  108. sum += data;
  109. if (max < data)
  110. max = data;
  111. }
  112. printf("max=%e, sum=%e\n", max, sum);
  113. return true;
  114. }
  115. namespace {
  116. volatile bool finished = false;
  117. int async_callback(void) {
  118. #if !__DEPLOY_ON_XP_SP2__
  119. std::cout << "worker thread_id:" << std::this_thread::get_id() << std::endl;
  120. #endif
  121. finished = true;
  122. return 0;
  123. }
  124. } // namespace
  125. bool async_c_interface(const lite::example::Args& args) {
  126. std::string network_path = args.model_path;
  127. std::string input_path = args.input_path;
  128. //! read input data to lite::tensor
  129. auto src_tensor = lite::example::parse_npy(input_path);
  130. void* src_ptr = src_tensor->get_memory_ptr();
  131. LiteNetwork c_network;
  132. LiteConfig config = *default_config();
  133. config.options.var_sanity_check_first_run = false;
  134. LITE_CAPI_CHECK(LITE_make_network(&c_network, config, *default_network_io()));
  135. LITE_CAPI_CHECK(LITE_load_model_from_path(c_network, network_path.c_str()));
  136. //! set input data to input tensor
  137. LiteTensor c_input_tensor;
  138. size_t length_tensor_in;
  139. LITE_CAPI_CHECK(LITE_get_io_tensor(c_network, "data", LITE_IO, &c_input_tensor));
  140. LITE_CAPI_CHECK(
  141. LITE_get_tensor_total_size_in_byte(c_input_tensor, &length_tensor_in));
  142. LITE_CAPI_CHECK(
  143. LITE_reset_tensor_memory(c_input_tensor, src_ptr, length_tensor_in));
  144. #if !__DEPLOY_ON_XP_SP2__
  145. std::cout << "user thread_id:" << std::this_thread::get_id() << std::endl;
  146. #endif
  147. LITE_CAPI_CHECK(LITE_set_async_callback(c_network, async_callback));
  148. //! forward
  149. LITE_CAPI_CHECK(LITE_forward(c_network));
  150. size_t count = 0;
  151. while (finished == false) {
  152. count++;
  153. }
  154. printf("The count is %zu\n", count);
  155. finished = false;
  156. //! get the output data or read tensor data
  157. const char* output_name;
  158. LiteTensor c_output_tensor;
  159. //! get the first output tensor name
  160. LITE_CAPI_CHECK(LITE_get_output_name(c_network, 0, &output_name));
  161. LITE_CAPI_CHECK(
  162. LITE_get_io_tensor(c_network, output_name, LITE_IO, &c_output_tensor));
  163. void* output_ptr;
  164. size_t length_output_in_byte;
  165. LITE_CAPI_CHECK(LITE_get_tensor_memory(c_output_tensor, &output_ptr));
  166. LITE_CAPI_CHECK(LITE_get_tensor_total_size_in_byte(
  167. c_output_tensor, &length_output_in_byte));
  168. size_t out_length = length_output_in_byte / sizeof(float);
  169. printf("length=%zu\n", out_length);
  170. float max = -1.0f;
  171. float sum = 0.0f;
  172. for (size_t i = 0; i < out_length; i++) {
  173. float data = static_cast<float*>(output_ptr)[i];
  174. sum += data;
  175. if (max < data)
  176. max = data;
  177. }
  178. printf("max=%e, sum=%e\n", max, sum);
  179. return true;
  180. }
  181. REGIST_EXAMPLE("basic_c_interface", basic_c_interface);
  182. REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface);
  183. REGIST_EXAMPLE("async_c_interface", async_c_interface);
  184. #endif
  185. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}