You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

xor_deploy.cpp 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. #include <stdlib.h>
  2. #include <iostream>
  3. #include "megbrain/serialization/serializer.h"
  4. using namespace mgb;
  5. cg::ComputingGraph::OutputSpecItem make_callback_copy(SymbolVar dev,
  6. HostTensorND& host) {
  7. auto cb = [&host](DeviceTensorND& d) { host.copy_from(d); };
  8. return {dev, cb};
  9. }
  10. int main(int argc, char* argv[]) {
  11. std::cout << " Usage: ./xornet_deploy model_name x_value y_value"
  12. << std::endl;
  13. if (argc != 4) {
  14. std::cout << " Wrong argument" << std::endl;
  15. return 0;
  16. }
  17. std::unique_ptr<serialization::InputFile> inp_file =
  18. serialization::InputFile::make_fs(argv[1]);
  19. float x = atof(argv[2]);
  20. float y = atof(argv[3]);
  21. auto loader = serialization::GraphLoader::make(std::move(inp_file));
  22. serialization::GraphLoadConfig config;
  23. serialization::GraphLoader::LoadResult network =
  24. loader->load(config, false);
  25. auto data = network.tensor_map["data"];
  26. float* data_ptr = data->resize({1, 2}).ptr<float>();
  27. data_ptr[0] = x;
  28. data_ptr[1] = y;
  29. HostTensorND predict;
  30. std::unique_ptr<cg::AsyncExecutable> func =
  31. network.graph->compile({make_callback_copy(
  32. network.output_var_map.begin()->second, predict)});
  33. func->execute();
  34. func->wait();
  35. float* predict_ptr = predict.ptr<float>();
  36. std::cout << " Predicted: " << predict_ptr[0] << " " << predict_ptr[1]
  37. << std::endl;
  38. }