You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cpp 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. #include "lite/global.h"
  2. #include "lite/network.h"
  3. #include "lite/tensor.h"
  4. #include "example.h"
  5. #include "npy.h"
  6. #include <string.h>
  7. #include <map>
  8. #include <memory>
  9. #include <vector>
  10. using namespace lite;
  11. using namespace example;
  12. Args Args::from_argv(int argc, char** argv) {
  13. Args ret;
  14. if (argc < 4) {
  15. printf("usage: lite_examples <example_name> <model file> <input "
  16. "file> <output file>.\n");
  17. printf("*********The output file is optional.*************\n");
  18. printf("The registered examples include:\n");
  19. size_t index = 0;
  20. for (auto it : *get_example_function_map()) {
  21. printf("%zu : %s\n", index, it.first.c_str());
  22. index++;
  23. }
  24. ret.args_parse_ret = -1;
  25. return ret;
  26. }
  27. ret.example_name = argv[1];
  28. ret.model_path = argv[2];
  29. ret.input_path = argv[3];
  30. if (argc > 4) {
  31. ret.output_path = argv[4];
  32. }
  33. if (argc > 5) {
  34. ret.loader_path = argv[5];
  35. }
  36. return ret;
  37. }
  38. ExampleFuncMap* lite::example::get_example_function_map() {
  39. static ExampleFuncMap static_map;
  40. return &static_map;
  41. }
  42. bool lite::example::register_example(
  43. std::string example_name, const ExampleFunc& fuction) {
  44. auto map = get_example_function_map();
  45. if (map->find(example_name) != map->end()) {
  46. printf("example_name: %s Error!!! This example is registed yet\n",
  47. example_name.c_str());
  48. return false;
  49. }
  50. (*map)[example_name] = fuction;
  51. return true;
  52. }
  53. std::shared_ptr<Tensor> lite::example::parse_npy(
  54. const std::string& path, LiteBackend backend) {
  55. std::string type_str;
  56. std::vector<npy::ndarray_len_t> stl_shape;
  57. std::vector<int8_t> raw;
  58. npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw);
  59. auto lite_tensor = std::make_shared<Tensor>(backend, LiteDeviceType::LITE_CPU);
  60. Layout layout;
  61. layout.ndim = stl_shape.size();
  62. const std::map<std::string, LiteDataType> type_map = {
  63. {"f4", LiteDataType::LITE_FLOAT},
  64. {"i4", LiteDataType::LITE_INT},
  65. {"i1", LiteDataType::LITE_INT8},
  66. {"u1", LiteDataType::LITE_UINT8}};
  67. layout.shapes[0] = 1;
  68. for (size_t i = 0; i < layout.ndim; i++) {
  69. layout.shapes[i] = static_cast<size_t>(stl_shape[i]);
  70. }
  71. for (auto& item : type_map) {
  72. if (type_str.find(item.first) != std::string::npos) {
  73. layout.data_type = item.second;
  74. break;
  75. }
  76. }
  77. lite_tensor->set_layout(layout);
  78. size_t length = lite_tensor->get_tensor_total_size_in_byte();
  79. void* dest = lite_tensor->get_memory_ptr();
  80. memcpy(dest, raw.data(), length);
  81. //! rknn not support reshape now
  82. if (layout.ndim == 3) {
  83. lite_tensor->reshape(
  84. {1, static_cast<int>(layout.shapes[0]),
  85. static_cast<int>(layout.shapes[1]),
  86. static_cast<int>(layout.shapes[2])});
  87. }
  88. return lite_tensor;
  89. }
  90. void lite::example::set_cpu_affinity(const std::vector<int>& cpuset) {
  91. #if defined(__APPLE__) || defined(WIN32) || defined(_WIN32)
  92. #pragma message("set_cpu_affinity not enabled on apple and windows platform")
  93. #else
  94. cpu_set_t mask;
  95. CPU_ZERO(&mask);
  96. for (auto i : cpuset) {
  97. CPU_SET(i, &mask);
  98. }
  99. auto err = sched_setaffinity(0, sizeof(mask), &mask);
  100. if (err) {
  101. printf("failed to sched_setaffinity: %s (error ignored)", strerror(errno));
  102. }
  103. #endif
  104. }
  105. int main(int argc, char** argv) {
  106. set_log_level(LiteLogLevel::WARN);
  107. auto&& args = Args::from_argv(argc, argv);
  108. if (args.args_parse_ret)
  109. return -1;
  110. auto map = get_example_function_map();
  111. auto example = (*map)[args.example_name];
  112. if (example) {
  113. printf("Begin to run %s example.\n", args.example_name.c_str());
  114. if (example(args)) {
  115. return 0;
  116. } else {
  117. return -1;
  118. }
  119. } else {
  120. printf("The example of %s is not registed.", args.example_name.c_str());
  121. return -1;
  122. }
  123. }
  124. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}