You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

user_allocator.cpp 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #include "example.h"
  2. #if LITE_BUILD_WITH_MGE
  3. using namespace lite;
  4. using namespace example;
  5. namespace {
  6. class CheckAllocator : public lite::Allocator {
  7. public:
  8. //! allocate memory of size in the given device with the given align
  9. void* allocate(LiteDeviceType, int, size_t size, size_t align) override {
  10. #if defined(WIN32) || defined(_WIN32)
  11. return _aligned_malloc(size, align);
  12. #elif defined(__ANDROID__) || defined(ANDROID)
  13. return memalign(align, size);
  14. #else
  15. void* ptr = nullptr;
  16. auto err = posix_memalign(&ptr, align, size);
  17. if (!err) {
  18. printf("failed to malloc %zu bytes with align %zu", size, align);
  19. }
  20. return ptr;
  21. #endif
  22. };
  23. //! free the memory pointed by ptr in the given device
  24. void free(LiteDeviceType, int, void* ptr) override {
  25. #if defined(WIN32) || defined(_WIN32)
  26. _aligned_free(ptr);
  27. #else
  28. ::free(ptr);
  29. #endif
  30. };
  31. };
  32. bool config_user_allocator(const Args& args) {
  33. std::string network_path = args.model_path;
  34. std::string input_path = args.input_path;
  35. auto allocator = std::make_shared<CheckAllocator>();
  36. //! create and load the network
  37. std::shared_ptr<Network> network = std::make_shared<Network>();
  38. Runtime::set_memory_allocator(network, allocator);
  39. network->load_model(network_path);
  40. //! set input data to input tensor
  41. std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
  42. //! copy or forward data to network
  43. size_t length = input_tensor->get_tensor_total_size_in_byte();
  44. void* dst_ptr = input_tensor->get_memory_ptr();
  45. auto src_tensor = parse_npy(input_path);
  46. void* src = src_tensor->get_memory_ptr();
  47. memcpy(dst_ptr, src, length);
  48. //! forward
  49. network->forward();
  50. network->wait();
  51. //! get the output data or read tensor set in network_in
  52. std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
  53. void* out_data = output_tensor->get_memory_ptr();
  54. size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
  55. output_tensor->get_layout().get_elem_size();
  56. printf("length=%zu\n", length);
  57. float max = -1.0f;
  58. float sum = 0.0f;
  59. for (size_t i = 0; i < out_length; i++) {
  60. float data = static_cast<float*>(out_data)[i];
  61. sum += data;
  62. if (max < data)
  63. max = data;
  64. }
  65. printf("max=%e, sum=%e\n", max, sum);
  66. return true;
  67. }
  68. } // namespace
  69. REGIST_EXAMPLE("config_user_allocator", config_user_allocator);
  70. #endif
  71. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}