You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pack_model_and_info.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. #
  4. # This file is part of MegEngine, a deep learning framework developed by
  5. # Megvii.
  6. #
  7. # copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
  8. import argparse
  9. import struct
  10. import os
  11. import subprocess
  12. import flatbuffers
  13. def generate_flatbuffer():
  14. status, path = subprocess.getstatusoutput('which flatc')
  15. if not status:
  16. cwd = os.path.dirname(os.path.dirname(__file__))
  17. fbs_file = os.path.abspath(os.path.join(cwd,
  18. "../../src/parse_model/pack_model.fbs"))
  19. cmd = path + ' -p -b '+fbs_file
  20. ret, _ = subprocess.getstatusoutput(str(cmd))
  21. if ret:
  22. raise Exception("flatc generate error!")
  23. else:
  24. raise Exception('no flatc in current environment, please build flatc '
  25. 'and put in the system PATH!')
  26. def main():
  27. parser = argparse.ArgumentParser(
  28. description='load a encrypted or not encrypted model and a '
  29. 'json format of the infomation of the model, pack them to a file '
  30. 'which can be loaded by lite.')
  31. parser.add_argument('--input-model', help='input a encrypted or not encrypted model')
  32. parser.add_argument('--input-info', help='input a encrypted or not encrypted '
  33. 'json format file.')
  34. parser.add_argument('--model-name', help='the model name, this must match '
  35. 'with the model name in model info', default = 'NONE')
  36. parser.add_argument('--model-cryption', help='the model encryption method '
  37. 'name, this is used to find the right decryption method. e.g. '
  38. '--model_cryption = "AES_default", default is NONE.', default =
  39. 'NONE')
  40. parser.add_argument('--info-cryption', help='the info encryption method '
  41. 'name, this is used to find the right decryption method. e.g. '
  42. '--model_cryption = "AES_default", default is NONE.', default =
  43. 'NONE')
  44. parser.add_argument('--info-parser', help='The information parse method name '
  45. 'default is "LITE_default". ', default = 'LITE_default')
  46. parser.add_argument('--append', '-a', help='append another model to a '
  47. 'packed model.')
  48. parser.add_argument('--output', '-o', help='output file of packed model.')
  49. args = parser.parse_args()
  50. generate_flatbuffer()
  51. assert not args.append, ('--append is not support yet')
  52. assert args.input_model, ('--input_model must be given')
  53. with open(args.input_model, 'rb') as fin:
  54. raw_model = fin.read()
  55. model_length = len(raw_model)
  56. if args.input_info:
  57. with open(args.input_info, 'rb') as fin:
  58. raw_info = fin.read()
  59. info_length = len(raw_info)
  60. else:
  61. raw_info = None
  62. info_length = 0
  63. # Generated by `flatc`.
  64. from model_parse import Model, ModelData, ModelHeader, ModelInfo, PackModel
  65. builder = flatbuffers.Builder(1024)
  66. model_name = builder.CreateString(args.model_name)
  67. model_cryption = builder.CreateString(args.model_cryption)
  68. info_cryption = builder.CreateString(args.info_cryption)
  69. info_parser = builder.CreateString(args.info_parser)
  70. info_data = builder.CreateByteVector(raw_info)
  71. arr_data = builder.CreateByteVector(raw_model)
  72. #model header
  73. ModelHeader.ModelHeaderStart(builder)
  74. ModelHeader.ModelHeaderAddName(builder, model_name)
  75. ModelHeader.ModelHeaderAddModelDecryptionMethod(builder, model_cryption)
  76. ModelHeader.ModelHeaderAddInfoDecryptionMethod(builder, info_cryption)
  77. ModelHeader.ModelHeaderAddInfoParseMethod(builder, info_parser)
  78. model_header = ModelHeader.ModelHeaderEnd(builder)
  79. #model info
  80. ModelInfo.ModelInfoStart(builder)
  81. ModelInfo.ModelInfoAddData(builder, info_data)
  82. model_info = ModelInfo.ModelInfoEnd(builder)
  83. #model data
  84. ModelData.ModelDataStart(builder)
  85. ModelData.ModelDataAddData(builder, arr_data)
  86. model_data = ModelData.ModelDataEnd(builder)
  87. Model.ModelStart(builder)
  88. Model.ModelAddHeader(builder, model_header)
  89. Model.ModelAddData(builder, model_data)
  90. Model.ModelAddInfo(builder, model_info)
  91. model = Model.ModelEnd(builder)
  92. PackModel.PackModelStartModelsVector(builder, 1)
  93. builder.PrependUOffsetTRelative(model)
  94. models = builder.EndVector(1)
  95. PackModel.PackModelStart(builder)
  96. PackModel.PackModelAddModels(builder, models)
  97. packed_model = PackModel.PackModelEnd(builder)
  98. builder.Finish(packed_model)
  99. buff = builder.Output()
  100. result = struct.pack(str(len("packed_model")) + 's', "packed_model".encode('ascii'))
  101. result += buff
  102. assert args.output, ('--output must be given')
  103. with open(args.output, 'wb') as fin:
  104. fin.write(result)
  105. print("Model packaged successfully!!!")
  106. print("model name is: {}.".format(args.model_name))
  107. print("model encryption method is: {}. ".format(args.model_cryption))
  108. print("model json infomation encryption method is: {}. ".format(args.info_cryption))
  109. print("model json infomation parse method is: {}. ".format(args.info_parser))
  110. print("packed model is write to {} ".format(args.output))
  111. if __name__ == '__main__':
  112. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。