You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_load_dump.cpp 2.7 kB

Line numbers 1–90
  1. /**
  2. * \file src/serialization/impl/opr_load_dump.cpp
  3. *
  4. * This file is part of MegBrain, a deep learning framework developed by Megvii.
  5. *
  6. * \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  7. *
  8. */
  9. #include "megbrain/serialization/opr_load_dump.h"
  10. #include "megbrain/opr/param_defs.h"
  11. #include "megbrain/serialization/file.h"
  12. #include "megbrain/serialization/helper.h"
  13. using namespace mgb;
  14. using namespace serialization;
  // Instantiates MGB_TYPEINFO_OBJ_IMPL for OprLoadContext. The macro is
  // defined elsewhere; presumably it provides the type-info implementation
  // for this class (NOTE(review): confirm against the macro definition).
  15. MGB_TYPEINFO_OBJ_IMPL(OprLoadContext);
  16. OprLoader OprLoadContext::make_opr_loader(const std::string &id) {
  17. auto &&maker = config().opr_loader_maker;
  18. mgb_throw_if(!maker, SerializationError,
  19. "opr_loader_maker not set in LoadConfig; but opr loader with "
  20. "id %s is needed", id.c_str());
  21. return maker(id);
  22. }
  23. template <>
  24. void OprDumpContextRawPOD::write_param(const DType& param) {
  25. if (m_check_param_tag) {
  26. uint32_t tag = megdnn::param::FakeSerializedDType::TAG;
  27. write_raw(&tag, sizeof(tag));
  28. }
  29. serialization::serialize_dtype(param, [this](const void* data, size_t len) {
  30. write_raw(data, len);
  31. });
  32. }
  33. template <>
  34. DType OprLoadContextRawPOD::read_param() {
  35. if (m_check_param_tag) {
  36. uint32_t tag;
  37. read_raw(&tag, sizeof(tag));
  38. mgb_throw_if(tag != megdnn::param::FakeSerializedDType::TAG,
  39. MegBrainError, "ERROR tag");
  40. }
  41. return serialization::deserialize_dtype(
  42. [this](void* data, size_t len) { read_raw(data, len); });
  43. }
  44. std::string OprLoadContextRawPOD::load_buf_with_len() {
  45. std::string ret;
  46. uint32_t size;
  47. read_raw(&size, sizeof(size));
  48. ret.resize(size);
  49. read_raw(&ret[0], size);
  50. return ret;
  51. }
  52. SharedBuffer OprLoadContextRawPOD::load_shared_buf_with_len() {
  53. uint32_t size;
  54. read_raw(&size, sizeof(size));
  55. return load_shared_buf(size);
  56. }
  57. void GraphDumpConfig::default_tensor_value_dumper(
  58. OutputFile &fout, const cg::OperatorNodeBase &/*opr*/,
  59. const HostTensorND &tensor) {
  60. auto size = tensor.layout().span().high_byte;
  61. fout.write(tensor.raw_ptr(), size);
  62. }
  63. void GraphLoadConfig::default_tensor_value_loader(
  64. void *ptr, const TensorLayout &layout, InputFile &fin) {
  65. auto sz = layout.span().high_byte;
  66. if (ptr) {
  67. fin.read(ptr, sz);
  68. } else {
  69. fin.skip(sz);
  70. }
  71. }
  72. SharedBuffer OprLoadContextRawPOD::load_shared_buf(size_t size) {
  73. std::shared_ptr<uint8_t> shptr{new uint8_t[size],
  74. [](uint8_t* p) { delete[] p; }};
  75. read_raw(shptr.get(), size);
  76. return {std::move(shptr), size};
  77. }
  78. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。