You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

dump_model.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import megengine._internal as mgb
  11. import numpy as np
  12. import yaml
  13. # "1,3,224,224" -> (1,3,224,224)
  14. def str2tuple(x):
  15. x = x.split(",")
  16. x = [int(a) for a in x]
  17. x = tuple(x)
  18. return x
  19. def main():
  20. parser = argparse.ArgumentParser(
  21. description="load a .pb model and convert to corresponding "
  22. "load-and-run model"
  23. )
  24. parser.add_argument("input", help="mace model file")
  25. parser.add_argument("param", help="mace param file")
  26. parser.add_argument(
  27. "output", help="converted model that can be fed to dump_with_testcase_mge.py"
  28. )
  29. parser.add_argument("config", help="config file with yaml format")
  30. args = parser.parse_args()
  31. with open(args.config, "r") as f:
  32. configs = yaml.load(f)
  33. for model_name in configs["models"]:
  34. # ignore several sub models currently
  35. sub_model = configs["models"][model_name]["subgraphs"][0]
  36. # input/output shapes
  37. isizes = [str2tuple(x) for x in sub_model["input_shapes"]]
  38. # input/output names
  39. input_names = sub_model["input_tensors"]
  40. if "check_tensors" in sub_model:
  41. output_names = sub_model["check_tensors"]
  42. osizes = [str2tuple(x) for x in sub_model["check_shapes"]]
  43. else:
  44. output_names = sub_model["output_tensors"]
  45. osizes = [str2tuple(x) for x in sub_model["output_shapes"]]
  46. with open(args.input, "rb") as fin:
  47. raw_model = fin.read()
  48. with open(args.param, "rb") as fin:
  49. raw_param = fin.read()
  50. model_size = (len(raw_model)).to_bytes(4, byteorder="little")
  51. param_size = (len(raw_param)).to_bytes(4, byteorder="little")
  52. n_inputs = (len(input_names)).to_bytes(4, byteorder="little")
  53. n_outputs = (len(output_names)).to_bytes(4, byteorder="little")
  54. names_buffer = n_inputs + n_outputs
  55. for iname in input_names:
  56. names_buffer += (len(iname)).to_bytes(4, byteorder="little")
  57. names_buffer += str.encode(iname)
  58. for oname in output_names:
  59. names_buffer += (len(oname)).to_bytes(4, byteorder="little")
  60. names_buffer += str.encode(oname)
  61. shapes_buffer = n_outputs
  62. for oshape in osizes:
  63. shapes_buffer += (len(oshape)).to_bytes(4, byteorder="little")
  64. for oi in oshape:
  65. shapes_buffer += oi.to_bytes(4, byteorder="little")
  66. # raw content contains:
  67. # input/output names + output shapes + model buffer + param buffer
  68. wk_raw_content = (
  69. names_buffer
  70. + shapes_buffer
  71. + model_size
  72. + raw_model
  73. + param_size
  74. + raw_param
  75. )
  76. # cn not ensured
  77. cn = mgb.comp_node("xpux")
  78. cg = mgb.comp_graph()
  79. inp = [
  80. mgb.make_shared(
  81. comp_node=cn,
  82. comp_graph=cg,
  83. shape=isizes[i],
  84. name=input_names[i],
  85. dtype=np.float32,
  86. )
  87. for i in range(len(isizes))
  88. ]
  89. oup = mgb.opr.extern_c_opr_placeholder(
  90. inp, osizes, dump_name="mace", dump_data=wk_raw_content,
  91. )
  92. mgb.serialize_comp_graph_to_file(args.output, oup)
  93. if __name__ == "__main__":
  94. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。