You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cpp 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. /**
  2. * \file example/cpp_example/cpp_example/example.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "lite/global.h"
  12. #include "lite/network.h"
  13. #include "lite/tensor.h"
  14. #include "example.h"
  15. #include "npy.h"
  16. #include <string.h>
  17. #include <map>
  18. #include <memory>
  19. #include <vector>
  20. using namespace lite;
  21. using namespace example;
  22. Args Args::from_argv(int argc, char** argv) {
  23. Args ret;
  24. if (argc < 4) {
  25. printf("usage: lite_examples <example_name> <model file> <input "
  26. "file> <output file>.\n");
  27. printf("*********The output file is optional.*************\n");
  28. printf("The registered examples include:\n");
  29. size_t index = 0;
  30. for (auto it : *get_example_function_map()) {
  31. printf("%zu : %s\n", index, it.first.c_str());
  32. index++;
  33. }
  34. ret.args_parse_ret = -1;
  35. return ret;
  36. }
  37. ret.example_name = argv[1];
  38. ret.model_path = argv[2];
  39. ret.input_path = argv[3];
  40. if (argc > 4) {
  41. ret.output_path = argv[4];
  42. }
  43. if (argc > 5) {
  44. ret.loader_path = argv[5];
  45. }
  46. return ret;
  47. }
  48. ExampleFuncMap* lite::example::get_example_function_map() {
  49. static ExampleFuncMap static_map;
  50. return &static_map;
  51. }
  52. bool lite::example::register_example(
  53. std::string example_name, const ExampleFunc& fuction) {
  54. auto map = get_example_function_map();
  55. if (map->find(example_name) != map->end()) {
  56. printf("example_name: %s Error!!! This example is registed yet\n",
  57. example_name.c_str());
  58. return false;
  59. }
  60. (*map)[example_name] = fuction;
  61. return true;
  62. }
  63. std::shared_ptr<Tensor> lite::example::parse_npy(
  64. const std::string& path, LiteBackend backend) {
  65. std::string type_str;
  66. std::vector<npy::ndarray_len_t> stl_shape;
  67. std::vector<int8_t> raw;
  68. npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw);
  69. auto lite_tensor = std::make_shared<Tensor>(backend, LiteDeviceType::LITE_CPU);
  70. Layout layout;
  71. layout.ndim = stl_shape.size();
  72. const std::map<std::string, LiteDataType> type_map = {
  73. {"f4", LiteDataType::LITE_FLOAT},
  74. {"i4", LiteDataType::LITE_INT},
  75. {"i1", LiteDataType::LITE_INT8},
  76. {"u1", LiteDataType::LITE_UINT8}};
  77. layout.shapes[0] = 1;
  78. for (size_t i = 0; i < layout.ndim; i++) {
  79. layout.shapes[i] = static_cast<size_t>(stl_shape[i]);
  80. }
  81. for (auto& item : type_map) {
  82. if (type_str.find(item.first) != std::string::npos) {
  83. layout.data_type = item.second;
  84. break;
  85. }
  86. }
  87. lite_tensor->set_layout(layout);
  88. size_t length = lite_tensor->get_tensor_total_size_in_byte();
  89. void* dest = lite_tensor->get_memory_ptr();
  90. memcpy(dest, raw.data(), length);
  91. //! rknn not support reshape now
  92. if (layout.ndim == 3) {
  93. lite_tensor->reshape(
  94. {1, static_cast<int>(layout.shapes[0]),
  95. static_cast<int>(layout.shapes[1]),
  96. static_cast<int>(layout.shapes[2])});
  97. }
  98. return lite_tensor;
  99. }
  100. void lite::example::set_cpu_affinity(const std::vector<int>& cpuset) {
  101. #if defined(__APPLE__) || defined(WIN32) || defined(_WIN32)
  102. #pragma message("set_cpu_affinity not enabled on apple and windows platform")
  103. #else
  104. cpu_set_t mask;
  105. CPU_ZERO(&mask);
  106. for (auto i : cpuset) {
  107. CPU_SET(i, &mask);
  108. }
  109. auto err = sched_setaffinity(0, sizeof(mask), &mask);
  110. if (err) {
  111. printf("failed to sched_setaffinity: %s (error ignored)", strerror(errno));
  112. }
  113. #endif
  114. }
  115. int main(int argc, char** argv) {
  116. set_log_level(LiteLogLevel::WARN);
  117. auto&& args = Args::from_argv(argc, argv);
  118. if (args.args_parse_ret)
  119. return -1;
  120. auto map = get_example_function_map();
  121. auto example = (*map)[args.example_name];
  122. if (example) {
  123. printf("Begin to run %s example.\n", args.example_name.c_str());
  124. if (example(args)) {
  125. return 0;
  126. } else {
  127. return -1;
  128. }
  129. } else {
  130. printf("The example of %s is not registed.", args.example_name.c_str());
  131. return -1;
  132. }
  133. }
  134. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}