You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mace_loader.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. /**
  2. * \file sdk/c-opr-loaders/mace/mace_loader.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
#include <sys/stat.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <vector>

#include "mace/public/mace.h"
#include "extern_c_opr.h"
  16. #if defined(__APPLE__) || defined(__MACOSX)
  17. static const char* default_so_paths[] = {
  18. "/System/Library/Frameworks/OpenCL.framework/OpenCL", "libOpenCL.so"};
  19. #elif defined(__ANDROID__)
  20. static const char* default_so_paths[] = {
  21. #if defined(__aarch64__)
  22. "/system/lib64/libOpenCL.so",
  23. "/system/lib64/libOpenCL_system.so",
  24. "/system/lib64/egl/libGLES_mali.so",
  25. "/system/vendor/lib64/libOpenCL.so",
  26. "/system/vendor/lib64/egl/libGLES_mali.so",
  27. "/system/vendor/lib64/libPVROCL.so",
  28. "/vendor/lib64/libOpenCL.so",
  29. "/data/data/org.pocl.libs/files/lib64/libpocl.so",
  30. #else
  31. "/system/lib/libOpenCL.so",
  32. "/system/lib/libOpenCL_system.so",
  33. "/system/lib/egl/libGLES_mali.so",
  34. "/system/vendor/lib/libOpenCL.so",
  35. "/system/vendor/lib/egl/libGLES_mali.so",
  36. "/system/vendor/lib/libPVROCL.so",
  37. "/vendor/lib/libOpenCL.so",
  38. "/data/data/org.pocl.libs/files/lib/libpocl.so",
  39. #endif
  40. "libOpenCL.so"};
  41. #elif defined(_WIN32)
  42. static const char* default_so_paths[] = {"OpenCL.dll"};
  43. #elif defined(__linux__)
  44. static const char* default_so_paths[] = {
  45. #if defined(__x86_64__) || defined(__amd64__)
  46. "/usr/lib64/libOpenCL.so", "/usr/local/lib64/libOpenCL.so",
  47. "/usr/local/cuda/lib64/libOpenCL.so",
  48. "/opt/intel/opencl/libOpenCL.so",
  49. //! As in some system like apex, the driver exists here
  50. "/usr/lib/libOpenCL.so",
  51. #else
  52. "/usr/lib/libOpenCL.so",
  53. "/usr/lib32/libOpenCL.so",
  54. "/usr/local/lib/libOpenCL.so",
  55. "/usr/local/lib/libpocl.so",
  56. "/usr/local/cuda/lib/libOpenCL.so",
  57. #endif
  58. "libOpenCL.so"};
  59. #endif
  60. #define ASSERT(x, msg) \
  61. do { \
  62. if (!(x)) { \
  63. printf("error at %s:%d %s\n", __FILE__, __LINE__, __FUNCTION__); \
  64. printf(msg); \
  65. __builtin_trap(); \
  66. } \
  67. } while (0)
  68. inline bool file_exists (const char* name) {
  69. struct stat buffer;
  70. return (stat (name, &buffer) == 0);
  71. }
  72. class MGBOprDescImpl {
  73. struct UserData {
  74. std::shared_ptr<mace::MaceEngine> engine;
  75. size_t nr_inputs, nr_outputs;
  76. std::vector<std::vector<int64_t>> output_shapes;
  77. std::vector<std::string> input_names, output_names;
  78. };
  79. static UserData* user_data(const MGBOprDesc* self) {
  80. return static_cast<UserData*>(self->user_data);
  81. }
  82. static void release(MGBOprDesc* self) {
  83. // free all data buffers
  84. delete user_data(self);
  85. delete self;
  86. }
  87. static size_t hash(const MGBOprDesc* self) {
  88. return reinterpret_cast<size_t>(self);
  89. }
  90. static int is_same(const MGBOprDesc* self, const MGBOprDesc* rhs) {
  91. return self == rhs;
  92. }
  93. static void infer_shape(const MGBOprDesc* self, const MGBTensorShape* input,
  94. MGBTensorShape* output) {
  95. auto ud = user_data(self);
  96. // infer output shape from user data
  97. for (size_t i = 0; i < ud->nr_outputs; i++) {
  98. output[i].ndim = ud->output_shapes[i].size();
  99. for (size_t j = 0; j < output[i].ndim; j++) {
  100. output[i].shape[j] = ud->output_shapes[i][j];
  101. }
  102. }
  103. }
  104. static void infer_dtype(const MGBOprDesc*, const MGBDType* input, MGBDType* output) {
  105. ASSERT(input[0] == MGB_DTYPE_FLOAT32, "Input dtype is not float32");
  106. output[0] = MGB_DTYPE_FLOAT32;
  107. }
  108. static void execute(const MGBOprDesc* self, const MGBTensor* input,
  109. const MGBTensor* output) {
  110. auto ud = user_data(self);
  111. // create input and output tensor buffers
  112. std::map<std::string, mace::MaceTensor> mace_inputs;
  113. std::map<std::string, mace::MaceTensor> mace_outputs;
  114. auto mace_data_format = mace::DataFormat::NCHW;
  115. char *data_format = getenv("MGB_MACE_LOADER_FORMAT");
  116. if (data_format != nullptr && !strcmp(data_format, "NHWC")) {
  117. mace_data_format = mace::DataFormat::NHWC;
  118. }
  119. for (size_t i = 0; i < ud->nr_inputs; ++i) {
  120. // allocate input
  121. uint32_t ndim = input[i].layout.shape.ndim;
  122. auto input_shape = std::vector<int64_t>(input[i].layout.shape.shape,
  123. input[i].layout.shape.shape + ndim);
  124. int64_t input_size =
  125. std::accumulate(input_shape.begin(), input_shape.end(), 1,
  126. std::multiplies<uint64_t>());
  127. auto buffer_in = std::shared_ptr<float>(new float[input_size],
  128. std::default_delete<float[]>());
  129. memcpy(buffer_in.get(), input[i].data, input_size * sizeof(float));
  130. mace_inputs[ud->input_names[i]] =
  131. mace::MaceTensor(input_shape, buffer_in, mace_data_format);
  132. }
  133. for (size_t i = 0; i < ud->nr_outputs; ++i) {
  134. // allocate output
  135. uint32_t ndim = output[i].layout.shape.ndim;
  136. auto output_shape = std::vector<int64_t>(output[i].layout.shape.shape,
  137. output[i].layout.shape.shape + ndim);
  138. int64_t output_size =
  139. std::accumulate(output_shape.begin(), output_shape.end(), 1,
  140. std::multiplies<int64_t>());
  141. auto buffer_out = std::shared_ptr<float>(new float[output_size],
  142. std::default_delete<float[]>());
  143. mace_outputs[ud->output_names[i]] =
  144. mace::MaceTensor(output_shape, buffer_out, mace_data_format);
  145. }
  146. // run the model
  147. auto status = (ud->engine)->Run(mace_inputs, &mace_outputs);
  148. ASSERT(status == mace::MaceStatus::MACE_SUCCESS,
  149. "Error in running mace engine");
  150. // send computed output to MGB
  151. int idx = 0;
  152. for (auto it = mace_outputs.begin(); it != mace_outputs.end(); it++) {
  153. float* to = &((float *)output[idx++].data)[0];
  154. to = (it->second).data().get();
  155. }
  156. }
  157. public:
  158. static MGBOprDesc* make(size_t nr_input, const void *buf, size_t buf_len) {
  159. auto ud = std::make_unique<UserData>();
  160. std::shared_ptr<mace::MaceEngine> engine;
  161. mace::DeviceType device_type = mace::DeviceType::CPU;
  162. char *runtime_mode = getenv("MGB_MACE_RUNTIME");
  163. if (runtime_mode != nullptr && !strcmp(runtime_mode, "GPU")) {
  164. device_type = mace::DeviceType::GPU;
  165. }
  166. mace::MaceEngineConfig config(device_type);
  167. // set number of threads for cpu, default 1
  168. if (device_type == mace::DeviceType::CPU) {
  169. int nthread = 1;
  170. char *str_nthread = getenv("MGB_MACE_NR_THREADS");
  171. if (str_nthread != nullptr) {
  172. nthread = atoi(str_nthread);
  173. }
  174. config.SetCPUThreadPolicy(nthread, mace::CPUAffinityPolicy::AFFINITY_NONE);
  175. }
  176. // set gpu context, mainly opencl path
  177. if (device_type == mace::DeviceType::GPU) {
  178. std::shared_ptr<mace::GPUContext> gpu_context;
  179. char *cache_path = getenv("MGB_MACE_OPENCL_CACHE_PATH");
  180. ASSERT(cache_path, "there must be an opencl cache file path");
  181. char *param_path = getenv("MGB_MACE_TUNING_PARAM_PATH");
  182. std::string opencl_param_path("");
  183. if (param_path != nullptr) {
  184. opencl_param_path = std::string(param_path);
  185. }
  186. std::string storage_path(cache_path);
  187. gpu_context = mace::GPUContextBuilder()
  188. .SetStoragePath(storage_path)
  189. .SetOpenCLParameterPath(opencl_param_path)
  190. .Finalize();
  191. config.SetGPUContext(gpu_context);
  192. config.SetGPUHints(
  193. static_cast<mace::GPUPerfHint>(mace::GPUPerfHint::PERF_HIGH),
  194. static_cast<mace::GPUPriorityHint>(mace::GPUPriorityHint::PRIORITY_HIGH));
  195. }
  196. std::vector<std::string> input_names, output_names;
  197. // extract all information from buf
  198. void *buffer = const_cast<void *>(buf);
  199. ud->nr_inputs = *reinterpret_cast<uint32_t*>(buffer);
  200. ud->nr_outputs = *(reinterpret_cast<uint32_t*>(buffer) + 1);
  201. // interpret input names
  202. char *name_buf = reinterpret_cast<char*>(buffer) + 8;
  203. for (size_t i = 0; i < ud->nr_inputs; i++) {
  204. size_t ilen = *reinterpret_cast<uint32_t*>(name_buf);
  205. input_names.push_back(std::string(name_buf + 4, ilen));
  206. name_buf += (ilen + 4);
  207. }
  208. // interpret output names
  209. buffer = name_buf;
  210. name_buf = reinterpret_cast<char*>(buffer);
  211. for (size_t i = 0; i < ud->nr_outputs; i++) {
  212. size_t olen = *reinterpret_cast<uint32_t*>(name_buf);
  213. output_names.push_back(std::string(name_buf + 4, olen));
  214. name_buf += (olen + 4);
  215. }
  216. ud->input_names = input_names;
  217. ud->output_names = output_names;
  218. // interpret output shapes
  219. buffer = name_buf;
  220. uint32_t *shape_buf = reinterpret_cast<uint32_t*>(buffer) + 1;
  221. for (size_t i = 0; i < ud->nr_outputs; i++) {
  222. size_t olen = *reinterpret_cast<int*>(shape_buf);
  223. ud->output_shapes.push_back(
  224. std::vector<int64_t>(shape_buf + 1, shape_buf + olen + 1)
  225. );
  226. shape_buf += (olen + 1);
  227. }
  228. buffer = shape_buf;
  229. const size_t model_buf_len = *reinterpret_cast<int*>(buffer);
  230. unsigned char *model_buf = reinterpret_cast<unsigned char*>(buffer) + 4;
  231. const size_t param_buf_len = *reinterpret_cast<int*>(model_buf + model_buf_len);
  232. unsigned char *param_buf = model_buf + model_buf_len + 4;
  233. // create mace engine
  234. auto create_engine_status = mace::CreateMaceEngineFromProto(
  235. model_buf,
  236. model_buf_len,
  237. param_buf,
  238. param_buf_len,
  239. input_names,
  240. output_names,
  241. config,
  242. &engine
  243. );
  244. ASSERT(create_engine_status == mace::MaceStatus::MACE_SUCCESS,
  245. "Error in creating mace engine");
  246. ud->engine = engine;
  247. auto ret = std::make_unique<MGBOprDesc>();
  248. mgb_init_opr_desc(ret.get(), ud->nr_outputs, "mace");
  249. #define a(n) ret->n = &n;
  250. MGB_OPR_DESC_FOREACH_MEM_FN(a);
  251. a(infer_dtype);
  252. #undef a
  253. ret->user_data = ud.release();
  254. return ret.release();
  255. }
  256. };
  257. class MGBOprLoaderImpl {
  258. static MGBOprDesc* create_desc(size_t nr_input, const void *buf,
  259. size_t buf_len)
  260. {
  261. return MGBOprDescImpl::make(nr_input, buf, buf_len);
  262. }
  263. public:
  264. static MGBOprLoader make() {
  265. return {"mace", create_desc};
  266. }
  267. };
  268. extern "C" {
  269. // public interface
  270. __attribute__((visibility("default")))
  271. void MGB_C_OPR_INIT_FUNC(const MGBExternCOprApi* (*get_api)(int))
  272. {
  273. const MGBExternCOprApi* api = get_api(MGB_EXTERN_C_OPR_VERSION);
  274. ASSERT(api, "Create api failed");
  275. MGBOprLoader loader = MGBOprLoaderImpl::make();
  276. api->register_loader(&loader);
  277. }
  278. } // extern "C"