You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pack_model_and_info.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import argparse
  10. import struct
  11. import os
  12. import subprocess
  13. import flatbuffers
  14. def generate_flatbuffer():
  15. status, path = subprocess.getstatusoutput('which flatc')
  16. if not status:
  17. cwd = os.path.dirname(os.path.dirname(__file__))
  18. fbs_file = os.path.abspath(os.path.join(cwd,
  19. "../../src/parse_model/pack_model.fbs"))
  20. cmd = path + ' -p -b '+fbs_file
  21. ret, _ = subprocess.getstatusoutput(str(cmd))
  22. if ret:
  23. raise Exception("flatc generate error!")
  24. else:
  25. raise Exception('no flatc in current environment, please build flatc '
  26. 'and put in the system PATH!')
  27. def main():
  28. parser = argparse.ArgumentParser(
  29. description='load a encrypted or not encrypted model and a '
  30. 'json format of the infomation of the model, pack them to a file '
  31. 'which can be loaded by lite.')
  32. parser.add_argument('--input-model', help='input a encrypted or not encrypted model')
  33. parser.add_argument('--input-info', help='input a encrypted or not encrypted '
  34. 'json format file.')
  35. parser.add_argument('--model-name', help='the model name, this must match '
  36. 'with the model name in model info', default = 'NONE')
  37. parser.add_argument('--model-cryption', help='the model encryption method '
  38. 'name, this is used to find the right decryption method. e.g. '
  39. '--model_cryption = "AES_default", default is NONE.', default =
  40. 'NONE')
  41. parser.add_argument('--info-cryption', help='the info encryption method '
  42. 'name, this is used to find the right decryption method. e.g. '
  43. '--model_cryption = "AES_default", default is NONE.', default =
  44. 'NONE')
  45. parser.add_argument('--info-parser', help='The information parse method name '
  46. 'default is "LITE_default". ', default = 'LITE_default')
  47. parser.add_argument('--append', '-a', help='append another model to a '
  48. 'packed model.')
  49. parser.add_argument('--output', '-o', help='output file of packed model.')
  50. args = parser.parse_args()
  51. generate_flatbuffer()
  52. assert not args.append, ('--append is not support yet')
  53. assert args.input_model, ('--input_model must be given')
  54. with open(args.input_model, 'rb') as fin:
  55. raw_model = fin.read()
  56. model_length = len(raw_model)
  57. if args.input_info:
  58. with open(args.input_info, 'rb') as fin:
  59. raw_info = fin.read()
  60. info_length = len(raw_info)
  61. else:
  62. raw_info = None
  63. info_length = 0
  64. # Generated by `flatc`.
  65. from model_parse import Model, ModelData, ModelHeader, ModelInfo, PackModel
  66. builder = flatbuffers.Builder(1024)
  67. model_name = builder.CreateString(args.model_name)
  68. model_cryption = builder.CreateString(args.model_cryption)
  69. info_cryption = builder.CreateString(args.info_cryption)
  70. info_parser = builder.CreateString(args.info_parser)
  71. info_data = builder.CreateByteVector(raw_info)
  72. arr_data = builder.CreateByteVector(raw_model)
  73. #model header
  74. ModelHeader.ModelHeaderStart(builder)
  75. ModelHeader.ModelHeaderAddName(builder, model_name)
  76. ModelHeader.ModelHeaderAddModelDecryptionMethod(builder, model_cryption)
  77. ModelHeader.ModelHeaderAddInfoDecryptionMethod(builder, info_cryption)
  78. ModelHeader.ModelHeaderAddInfoParseMethod(builder, info_parser)
  79. model_header = ModelHeader.ModelHeaderEnd(builder)
  80. #model info
  81. ModelInfo.ModelInfoStart(builder)
  82. ModelInfo.ModelInfoAddData(builder, info_data)
  83. model_info = ModelInfo.ModelInfoEnd(builder)
  84. #model data
  85. ModelData.ModelDataStart(builder)
  86. ModelData.ModelDataAddData(builder, arr_data)
  87. model_data = ModelData.ModelDataEnd(builder)
  88. Model.ModelStart(builder)
  89. Model.ModelAddHeader(builder, model_header)
  90. Model.ModelAddData(builder, model_data)
  91. Model.ModelAddInfo(builder, model_info)
  92. model = Model.ModelEnd(builder)
  93. PackModel.PackModelStartModelsVector(builder, 1)
  94. builder.PrependUOffsetTRelative(model)
  95. models = builder.EndVector(1)
  96. PackModel.PackModelStart(builder)
  97. PackModel.PackModelAddModels(builder, models)
  98. packed_model = PackModel.PackModelEnd(builder)
  99. builder.Finish(packed_model)
  100. buff = builder.Output()
  101. result = struct.pack(str(len("packed_model")) + 's', "packed_model".encode('ascii'))
  102. result += buff
  103. assert args.output, ('--output must be given')
  104. with open(args.output, 'wb') as fin:
  105. fin.write(result)
  106. print("Model packaged successfully!!!")
  107. print("model name is: {}.".format(args.model_name))
  108. print("model encryption method is: {}. ".format(args.model_cryption))
  109. print("model json infomation encryption method is: {}. ".format(args.info_cryption))
  110. print("model json infomation parse method is: {}. ".format(args.info_parser))
  111. print("packed model is write to {} ".format(args.output))
  112. if __name__ == '__main__':
  113. main()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。