
network_share_weights.cpp 2.7 kB
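This example loads a model into one Network, creates a second Network, and shares the first network's weights with it via Runtime::shared_weight_with_network. Both networks are fed the same input data parsed from an .npy file, run forward, and their outputs are compared element by element.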

#include "example.h"
#if LITE_BUILD_WITH_MGE

using namespace lite;
using namespace example;

namespace {

bool network_share_same_weights(const Args& args) {
    std::string network_path = args.model_path;
    std::string input_path = args.input_path;

    //! create and load the network
    std::shared_ptr<Network> network = std::make_shared<Network>();
    network->load_model(network_path);

    //! create a new network that shares the same weights with the loaded network
    Config config_new;
    config_new.options.const_shape = true;
    NetworkIO network_io_new;
    std::shared_ptr<Network> weight_shared_network =
            std::make_shared<Network>(config_new, network_io_new);
    Runtime::shared_weight_with_network(weight_shared_network, network);

    //! get the input tensors of both networks
    std::shared_ptr<Tensor> input_tensor = network->get_input_tensor(0);
    void* dst_ptr = input_tensor->get_memory_ptr();
    std::shared_ptr<Tensor> input_tensor2 =
            weight_shared_network->get_input_tensor(0);
    void* dst_ptr2 = input_tensor2->get_memory_ptr();

    //! copy the same input data into both networks
    size_t length = input_tensor->get_tensor_total_size_in_byte();
    auto src_tensor = parse_npy(input_path);
    void* src = src_tensor->get_memory_ptr();
    memcpy(dst_ptr, src, length);
    memcpy(dst_ptr2, src, length);

    //! forward
    network->forward();
    network->wait();
    weight_shared_network->forward();
    weight_shared_network->wait();

    //! get the output data of both networks
    std::shared_ptr<Tensor> output_tensor = network->get_output_tensor(0);
    std::shared_ptr<Tensor> output_tensor2 =
            weight_shared_network->get_output_tensor(0);
    void* out_data = output_tensor->get_memory_ptr();
    void* out_data2 = output_tensor2->get_memory_ptr();
    size_t out_length = output_tensor->get_tensor_total_size_in_byte() /
                        output_tensor->get_layout().get_elem_size();
    printf("length=%zu\n", length);

    //! compare the two outputs element by element
    float max = -1.0f;
    float sum = 0.0f;
    for (size_t i = 0; i < out_length; i++) {
        float data = static_cast<float*>(out_data)[i];
        float data2 = static_cast<float*>(out_data2)[i];
        if (data != data2) {
            printf("the results of the origin network and the weight shared "
                   "network are different.\n");
        }
        sum += data;
        if (max < data)
            max = data;
    }
    printf("max=%e, sum=%e\n", max, sum);
    return true;
}

}  // namespace

REGIST_EXAMPLE("network_share_same_weights", network_share_same_weights);

#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}