You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

main.cpp 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. /**
  2. * \file example/example.cpp
  3. *
  4. * This file is part of MegEngine, a deep learning framework developed by
  5. * Megvii.
  6. *
  7. * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  8. */
  9. #include "lite/global.h"
  10. #include "lite/network.h"
  11. #include "lite/tensor.h"
  12. #include "example.h"
  13. #include "npy.h"
  14. #include <string.h>
  15. #include <map>
  16. #include <memory>
  17. #include <vector>
  18. using namespace lite;
  19. using namespace example;
  20. Args Args::from_argv(int argc, char** argv) {
  21. Args ret;
  22. if (argc < 4) {
  23. printf("usage: lite_examples <example_name> <model file> <input "
  24. "file> <output file>.\n");
  25. printf("*********The output file is optional.*************\n");
  26. printf("The registered examples include:\n");
  27. size_t index = 0;
  28. for (auto it : *get_example_function_map()) {
  29. printf("%zu : %s\n", index, it.first.c_str());
  30. index++;
  31. }
  32. ret.args_parse_ret = -1;
  33. return ret;
  34. }
  35. ret.example_name = argv[1];
  36. ret.model_path = argv[2];
  37. ret.input_path = argv[3];
  38. if (argc > 4) {
  39. ret.output_path = argv[4];
  40. }
  41. if (argc > 5) {
  42. ret.loader_path = argv[5];
  43. }
  44. return ret;
  45. }
  46. ExampleFuncMap* lite::example::get_example_function_map() {
  47. static ExampleFuncMap static_map;
  48. return &static_map;
  49. }
  50. bool lite::example::register_example(std::string example_name,
  51. const ExampleFunc& fuction) {
  52. auto map = get_example_function_map();
  53. if (map->find(example_name) != map->end()) {
  54. printf("Error!!! This example is registed yet\n");
  55. return false;
  56. }
  57. (*map)[example_name] = fuction;
  58. return true;
  59. }
/**
 * Load a .npy file from `path` into a freshly allocated lite Tensor.
 *
 * The numpy dtype string is mapped onto the lite data types listed in
 * `type_map` below; shapes are copied verbatim into the tensor layout.
 * The tensor is always created on LITE_CPU for the given backend.
 */
std::shared_ptr<Tensor> lite::example::parse_npy(const std::string& path,
                                                 LiteBackend backend) {
    std::string type_str;
    std::vector<npy::ndarray_len_t> stl_shape;
    std::vector<int8_t> raw;
    // Fills type_str/stl_shape/raw from the file; raw holds the payload bytes.
    npy::LoadArrayFromNumpy(path, type_str, stl_shape, raw);
    auto lite_tensor =
            std::make_shared<Tensor>(backend, LiteDeviceType::LITE_CPU);
    Layout layout;
    layout.ndim = stl_shape.size();
    // npy dtype substring -> lite dtype (e.g. "<f4" contains "f4").
    const std::map<std::string, LiteDataType> type_map = {
            {"f4", LiteDataType::LITE_FLOAT},
            {"i4", LiteDataType::LITE_INT},
            {"i1", LiteDataType::LITE_INT8},
            {"u1", LiteDataType::LITE_UINT8}};
    // Default the first dim to 1; overwritten below whenever ndim > 0.
    layout.shapes[0] = 1;
    for (size_t i = 0; i < layout.ndim; i++) {
        layout.shapes[i] = static_cast<size_t>(stl_shape[i]);
    }
    // NOTE(review): if no entry of type_map matches, layout.data_type keeps
    // its default value — unsupported dtypes are not reported. Confirm this
    // is acceptable for the example code.
    for (auto& item : type_map) {
        if (type_str.find(item.first) != std::string::npos) {
            layout.data_type = item.second;
            break;
        }
    }
    lite_tensor->set_layout(layout);
    // Assumes raw.size() >= length, i.e. the npy payload matches the layout
    // computed above — TODO confirm (a short file would over-read here).
    size_t length = lite_tensor->get_tensor_total_size_in_byte();
    void* dest = lite_tensor->get_memory_ptr();
    memcpy(dest, raw.data(), length);
    //! rknn not support reshape now
    // Promote 3-D inputs to NCHW-like 4-D by prepending a batch dim of 1.
    if (layout.ndim == 3) {
        lite_tensor->reshape({1, static_cast<int>(layout.shapes[0]),
                              static_cast<int>(layout.shapes[1]),
                              static_cast<int>(layout.shapes[2])});
    }
    return lite_tensor;
}
/**
 * Pin the calling thread to the CPU cores listed in `cpuset`.
 *
 * No-op (with a compile-time notice) on Apple and Windows, where the Linux
 * sched_setaffinity API is unavailable. Errors are logged and ignored.
 */
void lite::example::set_cpu_affinity(const std::vector<int>& cpuset) {
#if defined(__APPLE__) || defined(WIN32)
#pragma message("set_cpu_affinity not enabled on apple and windows platform")
#else
    // Build the affinity mask from the requested core ids.
    cpu_set_t mask;
    CPU_ZERO(&mask);
    for (auto i : cpuset) {
        CPU_SET(i, &mask);
    }
    // pid 0 == current thread/process. Relies on <sched.h>/<errno.h> being
    // available transitively — NOTE(review): not included directly here.
    auto err = sched_setaffinity(0, sizeof(mask), &mask);
    if (err) {
        // Best-effort only: report the failure but keep running.
        printf("failed to sched_setaffinity: %s (error ignored)",
               strerror(errno));
    }
#endif
}
  113. int main(int argc, char** argv) {
  114. set_log_level(LiteLogLevel::WARN);
  115. auto&& args = Args::from_argv(argc, argv);
  116. if (args.args_parse_ret)
  117. return -1;
  118. auto map = get_example_function_map();
  119. auto example = (*map)[args.example_name];
  120. if (example) {
  121. printf("Begin to run %s example.\n", args.example_name.c_str());
  122. return example(args);
  123. } else {
  124. printf("The example of %s is not registed.", args.example_name.c_str());
  125. return -1;
  126. }
  127. }
// Static registration of all examples. Each REGIST_EXAMPLE invocation adds
// an entry to the registry returned by get_example_function_map() during
// static initialization, keyed by the CLI name.
namespace lite {
namespace example {
#if LITE_BUILD_WITH_MGE
// Examples below require the MegEngine backend.
#if LITE_WITH_CUDA
REGIST_EXAMPLE("load_from_path_run_cuda", load_from_path_run_cuda);
#endif
REGIST_EXAMPLE("basic_load_from_path", basic_load_from_path);
REGIST_EXAMPLE("basic_load_from_path_with_loader", basic_load_from_path_with_loader);
REGIST_EXAMPLE("basic_load_from_memory", basic_load_from_memory);
REGIST_EXAMPLE("cpu_affinity", cpu_affinity);
REGIST_EXAMPLE("register_cryption_method", register_cryption_method);
REGIST_EXAMPLE("update_cryption_key", update_cryption_key);
REGIST_EXAMPLE("network_share_same_weights", network_share_same_weights);
REGIST_EXAMPLE("reset_input", reset_input);
REGIST_EXAMPLE("reset_input_output", reset_input_output);
REGIST_EXAMPLE("config_user_allocator", config_user_allocator);
REGIST_EXAMPLE("async_forward", async_forward);
REGIST_EXAMPLE("basic_c_interface", basic_c_interface);
REGIST_EXAMPLE("device_io_c_interface", device_io_c_interface);
REGIST_EXAMPLE("async_c_interface", async_c_interface);
// Device-side I/O examples only make sense with CUDA available.
#if LITE_WITH_CUDA
REGIST_EXAMPLE("device_input", device_input);
REGIST_EXAMPLE("device_input_output", device_input_output);
REGIST_EXAMPLE("pinned_host_input", pinned_host_input);
#endif
#endif
}  // namespace example
}  // namespace lite
  156. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台