You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cpp 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. /**
  2. * \file example/cpp_example/cpp_example/example.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  #include "lite/global.h"
  #include "lite/network.h"
  #include "lite/tensor.h"
  #include "example.h"
  #include "npy.h"
  #include <string.h>
  #include <map>
  #include <memory>
  #include <stdexcept>
  #include <string>
  #include <vector>
  using namespace lite;
  using namespace example;
  22. Args Args::from_argv(int argc, char** argv) {
  23. Args ret;
  24. if (argc < 4) {
  25. printf("usage: lite_examples <example_name> <model file> <input "
  26. "file> <output file>.\n");
  27. printf("*********The output file is optional.*************\n");
  28. printf("The registered examples include:\n");
  29. size_t index = 0;
  30. for (auto it : *get_example_function_map()) {
  31. printf("%zu : %s\n", index, it.first.c_str());
  32. index++;
  33. }
  34. ret.args_parse_ret = -1;
  35. return ret;
  36. }
  37. ret.example_name = argv[1];
  38. ret.model_path = argv[2];
  39. ret.input_path = argv[3];
  40. if (argc > 4) {
  41. ret.output_path = argv[4];
  42. }
  43. if (argc > 5) {
  44. ret.loader_path = argv[5];
  45. }
  46. return ret;
  47. }
  48. ExampleFuncMap* lite::example::get_example_function_map() {
  49. static ExampleFuncMap static_map;
  50. return &static_map;
  51. }
  52. bool lite::example::register_example(
  53. std::string example_name, const ExampleFunc& fuction) {
  54. auto map = get_example_function_map();
  55. if (map->find(example_name) != map->end()) {
  56. printf("Error!!! This example is registed yet\n");
  57. return false;
  58. }
  59. (*map)[example_name] = fuction;
  60. return true;
  61. }
  62. std::shared_ptr<Tensor> lite::example::parse_npy(
  63. const std::string& path, LiteBackend backend) {
  64. std::string type_str;
  65. std::vector<npy::ndarray_len_t> stl_shape;
  66. std::vector<int8_t> raw;
  67. npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw);
  68. auto lite_tensor = std::make_shared<Tensor>(backend, LiteDeviceType::LITE_CPU);
  69. Layout layout;
  70. layout.ndim = stl_shape.size();
  71. const std::map<std::string, LiteDataType> type_map = {
  72. {"f4", LiteDataType::LITE_FLOAT},
  73. {"i4", LiteDataType::LITE_INT},
  74. {"i1", LiteDataType::LITE_INT8},
  75. {"u1", LiteDataType::LITE_UINT8}};
  76. layout.shapes[0] = 1;
  77. for (size_t i = 0; i < layout.ndim; i++) {
  78. layout.shapes[i] = static_cast<size_t>(stl_shape[i]);
  79. }
  80. for (auto& item : type_map) {
  81. if (type_str.find(item.first) != std::string::npos) {
  82. layout.data_type = item.second;
  83. break;
  84. }
  85. }
  86. lite_tensor->set_layout(layout);
  87. size_t length = lite_tensor->get_tensor_total_size_in_byte();
  88. void* dest = lite_tensor->get_memory_ptr();
  89. memcpy(dest, raw.data(), length);
  90. //! rknn not support reshape now
  91. if (layout.ndim == 3) {
  92. lite_tensor->reshape(
  93. {1, static_cast<int>(layout.shapes[0]),
  94. static_cast<int>(layout.shapes[1]),
  95. static_cast<int>(layout.shapes[2])});
  96. }
  97. return lite_tensor;
  98. }
  99. void lite::example::set_cpu_affinity(const std::vector<int>& cpuset) {
  100. #if defined(__APPLE__) || defined(WIN32)
  101. #pragma message("set_cpu_affinity not enabled on apple and windows platform")
  102. #else
  103. cpu_set_t mask;
  104. CPU_ZERO(&mask);
  105. for (auto i : cpuset) {
  106. CPU_SET(i, &mask);
  107. }
  108. auto err = sched_setaffinity(0, sizeof(mask), &mask);
  109. if (err) {
  110. printf("failed to sched_setaffinity: %s (error ignored)", strerror(errno));
  111. }
  112. #endif
  113. }
  114. int main(int argc, char** argv) {
  115. set_log_level(LiteLogLevel::WARN);
  116. auto&& args = Args::from_argv(argc, argv);
  117. if (args.args_parse_ret)
  118. return -1;
  119. auto map = get_example_function_map();
  120. auto example = (*map)[args.example_name];
  121. if (example) {
  122. printf("Begin to run %s example.\n", args.example_name.c_str());
  123. if (example(args)) {
  124. return 0;
  125. } else {
  126. return -1;
  127. }
  128. } else {
  129. printf("The example of %s is not registed.", args.example_name.c_str());
  130. return -1;
  131. }
  132. }
  namespace lite {
  namespace example {
  //! Registration of every example under the name that selects it on the
  //! command line (argv[1] of lite_examples).
  //! NOTE(review): REGIST_EXAMPLE is defined outside this file (presumably in
  //! example.h); confirm how it reaches register_example().
  //! All examples below are only built when LITE_BUILD_WITH_MGE is enabled.
  #if LITE_BUILD_WITH_MGE
  #if LITE_WITH_CUDA
  REGIST_EXAMPLE("load_from_path_run_cuda", load_from_path_run_cuda);
  #endif
  REGIST_EXAMPLE("basic_load_from_path", basic_load_from_path);
  REGIST_EXAMPLE("basic_load_from_path_with_loader", basic_load_from_path_with_loader);
  REGIST_EXAMPLE("basic_load_from_memory", basic_load_from_memory);
  REGIST_EXAMPLE("cpu_affinity", cpu_affinity);
  REGIST_EXAMPLE("register_cryption_method", register_cryption_method);
  REGIST_EXAMPLE("update_cryption_key", update_cryption_key);
  REGIST_EXAMPLE("network_share_same_weights", network_share_same_weights);
  REGIST_EXAMPLE("reset_input", reset_input);
  REGIST_EXAMPLE("reset_input_output", reset_input_output);
  REGIST_EXAMPLE("config_user_allocator", config_user_allocator);
  REGIST_EXAMPLE("async_forward", async_forward);
  REGIST_EXAMPLE("set_input_callback", set_input_callback);
  REGIST_EXAMPLE("set_output_callback", set_output_callback);
  //! examples exercising the C interface
  REGIST_EXAMPLE("basic_c_interface", basic_c_interface);
  REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface);
  REGIST_EXAMPLE("async_c_interface", async_c_interface);
  REGIST_EXAMPLE("picture_classification", picture_classification);
  REGIST_EXAMPLE("detect_yolox", detect_yolox);
  //! examples that are only built with CUDA support (LITE_WITH_CUDA)
  #if LITE_WITH_CUDA
  REGIST_EXAMPLE("device_input", device_input);
  REGIST_EXAMPLE("device_input_output", device_input_output);
  REGIST_EXAMPLE("pinned_host_input", pinned_host_input);
  #endif
  #endif
  } // namespace example
  } // namespace lite
  165. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}