
user_allocator.cpp 2.7 kB
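
This example shows how to plug a user-defined memory allocator into a MegEngine Lite network: a CheckAllocator that performs platform-specific aligned allocation is registered via Runtime::set_memory_allocator before the model is loaded, so all tensor memory of the network is obtained through it.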

/**
 * \file example/user_allocator.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */
#include "../example.h"
#if LITE_BUILD_WITH_MGE

#include <cstdio>
#include <cstdlib>
#include <cstring>
#if defined(WIN32) || defined(__ANDROID__) || defined(ANDROID)
#include <malloc.h>
#endif

using namespace lite;
using namespace example;

namespace {
class CheckAllocator : public lite::Allocator {
public:
    //! allocate memory of the given size on the given device with the given
    //! alignment
    void* allocate(LiteDeviceType, int, size_t size, size_t align) override {
#ifdef WIN32
        return _aligned_malloc(size, align);
#elif defined(__ANDROID__) || defined(ANDROID)
        return memalign(align, size);
#else
        void* ptr = nullptr;
        //! posix_memalign returns 0 on success, an error code on failure
        auto err = posix_memalign(&ptr, align, size);
        if (err) {
            printf("failed to malloc %zu bytes with align %zu\n", size, align);
        }
        return ptr;
#endif
    }

    //! free the memory pointed to by ptr on the given device
    void free(LiteDeviceType, int, void* ptr) override {
#ifdef WIN32
        _aligned_free(ptr);
#else
        ::free(ptr);
#endif
    }
};
}  // namespace

bool lite::example::config_user_allocator(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;

    auto allocator = std::make_shared<CheckAllocator>();

    //! create the network and register the user allocator before loading the
    //! model, so that all tensor memory goes through CheckAllocator
    std::shared_ptr<Network> network = std::make_shared<Network>();
    Runtime::set_memory_allocator(network, allocator);
    network->load_model(network_path);

    //! set input data to input tensor
    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);

    //! copy or forward data to network
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    void* dst_ptr = input_tensor->get_memory_ptr();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    memcpy(dst_ptr, src, length);

    //! forward
    network->forward();
    network->wait();

    //! get the output data or read tensor set in network_in
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    printf("length=%zu\n", out_length);

    float max = -1.0f;
    float sum = 0.0f;
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
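
The cross-platform aligned-allocation pattern used by CheckAllocator can be exercised in isolation, without MegEngine. Below is a minimal sketch that assumes only the C standard library; the aligned_alloc_portable/aligned_free_portable helpers are illustrative names, not part of the Lite API, and the Android memalign branch is omitted for brevity.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#ifdef WIN32
#include <malloc.h>
#endif

// Illustrative helper mirroring the platform branches in
// CheckAllocator::allocate above.
static void* aligned_alloc_portable(size_t size, size_t align) {
#ifdef WIN32
    return _aligned_malloc(size, align);
#else
    void* ptr = nullptr;
    if (posix_memalign(&ptr, align, size) != 0)
        return nullptr;
    return ptr;
#endif
}

static void aligned_free_portable(void* ptr) {
#ifdef WIN32
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}

int main() {
    constexpr size_t size = 1024, align = 64;
    void* p = aligned_alloc_portable(size, align);
    if (!p) {
        printf("allocation failed\n");
        return 1;
    }
    // Both _aligned_malloc and posix_memalign guarantee the returned
    // address is a multiple of `align`; verify it explicitly here.
    printf("ptr=%p aligned=%d\n", p,
           reinterpret_cast<uintptr_t>(p) % align == 0);
    aligned_free_portable(p);
    return 0;
}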

The MegEngine installation package bundles the CUDA environment needed to run code on the GPU, so there is no separate CPU and GPU edition. To run a GPU program, make sure the machine has a GPU device with the driver installed. If you want to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
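
With the bundled CUDA environment, selecting the GPU is done when the network is created. A minimal sketch, assuming the lite::Config::device_type field and the LiteDeviceType::LITE_CUDA enumerator from the MegEngine Lite C++ API (verify against your installed headers); make_gpu_network is a hypothetical helper name.

#include <memory>
#include <string>
#include "lite/network.h"

// Hypothetical helper: create a network bound to the CUDA device instead of
// the default CPU. Requires a CUDA-capable GPU with the driver installed,
// per the note above.
std::shared_ptr<lite::Network> make_gpu_network(const std::string& model_path) {
    lite::Config config;
    config.device_type = LiteDeviceType::LITE_CUDA;
    auto network = std::make_shared<lite::Network>(config);
    network->load_model(model_path);
    return network;
}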