You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

picture_classification.cpp 3.9 kB

Line numbers 1–115
  1. /**
  2. * \file example/cpp_example/cv/picture_classification.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <thread>
  12. #include "../../example.h"
  13. #if LITE_BUILD_WITH_MGE
  14. #include <cstdio>
  15. #include "misc.h"
  16. #define STB_IMAGE_STATIC
  17. #define STB_IMAGE_IMPLEMENTATION
  18. #include "stb_image.h"
  19. #define STB_IMAGE_RESIZE_IMPLEMENTATION
  20. #include "stb_image_resize.h"
  21. #define STB_IMAGE_WRITE_IMPLEMENTATION
  22. #include "stb_image_write.h"
  23. using namespace lite;
  24. using namespace example;
  25. namespace {
  26. void preprocess_image(std::string pic_path, std::shared_ptr<Tensor> tensor) {
  27. int width, height, channel;
  28. uint8_t* image = stbi_load(pic_path.c_str(), &width, &height, &channel, 0);
  29. printf("Input image %s with height=%d, width=%d, channel=%d\n", pic_path.c_str(),
  30. width, height, channel);
  31. auto layout = tensor->get_layout();
  32. auto pixels = layout.shapes[2] * layout.shapes[3];
  33. for (size_t i = 0; i < layout.ndim; i++) {
  34. printf("model input shape[%zu]=%zu \n", i, layout.shapes[i]);
  35. }
  36. //! resize to tensor shape
  37. std::shared_ptr<std::vector<uint8_t>> resize_int8 =
  38. std::make_shared<std::vector<uint8_t>>(pixels * channel);
  39. stbir_resize_uint8(
  40. image, width, height, 0, resize_int8->data(), layout.shapes[2],
  41. layout.shapes[3], 0, channel);
  42. stbi_image_free(image);
  43. //! convert form rgba to bgr, relayout from hwc to chw, normalization copy to tensor
  44. float* in_data = static_cast<float*>(tensor->get_memory_ptr());
  45. for (size_t i = 0; i < pixels; i++) {
  46. in_data[i + 2 * pixels] = (resize_int8->at(i * channel + 0) - 123.675) / 58.395;
  47. in_data[i + 1 * pixels] = (resize_int8->at(i * channel + 1) - 116.280) / 57.120;
  48. in_data[i + 0 * pixels] = (resize_int8->at(i * channel + 2) - 103.530) / 57.375;
  49. }
  50. }
  51. void classfication_process(
  52. std::shared_ptr<Tensor> tensor, float& score, size_t& class_id) {
  53. auto layout = tensor->get_layout();
  54. for (size_t i = 0; i < layout.ndim; i++) {
  55. printf("model output shape[%zu]=%zu \n", i, layout.shapes[i]);
  56. }
  57. size_t nr_data = tensor->get_tensor_total_size_in_byte() / layout.get_elem_size();
  58. float* data = static_cast<float*>(tensor->get_memory_ptr());
  59. score = data[0];
  60. class_id = 0;
  61. float sum = data[0];
  62. for (size_t i = 1; i < nr_data; i++) {
  63. if (score < data[i]) {
  64. score = data[i];
  65. class_id = i;
  66. }
  67. sum += data[i];
  68. }
  69. printf("output tensor sum is %f\n", sum);
  70. }
  71. } // namespace
  72. bool lite::example::picture_classification(const Args& args) {
  73. std::string network_path = args.model_path;
  74. std::string input_path = args.input_path;
  75. //! create and load the network
  76. std::shared_ptr<Network> network = std::make_shared<Network>();
  77. network->load_model(network_path);
  78. //! set input data to input tensor
  79. std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
  80. //! copy or forward data to network
  81. preprocess_image(args.input_path, input_tensor);
  82. printf("Begin forward.\n");
  83. network->forward();
  84. network->wait();
  85. printf("End forward.\n");
  86. //! get the output data or read tensor set in network_in
  87. size_t class_id;
  88. float score;
  89. auto output_tensor = network->get_output_tensor(0);
  90. classfication_process(output_tensor, score, class_id);
  91. printf("Picture %s is class_id %zu, with score %f\n", args.input_path.c_str(),
  92. class_id, score);
  93. return 0;
  94. }
  95. #endif
  96. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}