You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

pack_model_and_info.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # -*- coding: utf-8 -*-
  2. import argparse
  3. import struct
  4. import os
  5. import subprocess
  6. import flatbuffers
  7. def generate_flatbuffer():
  8. status, path = subprocess.getstatusoutput('which flatc')
  9. if not status:
  10. cwd = os.path.dirname(os.path.dirname(__file__))
  11. fbs_file = os.path.abspath(os.path.join(cwd,
  12. "../../src/parse_model/pack_model.fbs"))
  13. cmd = path + ' -p -b '+fbs_file
  14. ret, _ = subprocess.getstatusoutput(str(cmd))
  15. if ret:
  16. raise Exception("flatc generate error!")
  17. else:
  18. raise Exception('no flatc in current environment, please build flatc '
  19. 'and put in the system PATH!')
  20. def main():
  21. parser = argparse.ArgumentParser(
  22. description='load a encrypted or not encrypted model and a '
  23. 'json format of the infomation of the model, pack them to a file '
  24. 'which can be loaded by lite.')
  25. parser.add_argument('--input-model', help='input a encrypted or not encrypted model')
  26. parser.add_argument('--input-info', help='input a encrypted or not encrypted '
  27. 'json format file.')
  28. parser.add_argument('--model-name', help='the model name, this must match '
  29. 'with the model name in model info', default = 'NONE')
  30. parser.add_argument('--model-cryption', help='the model encryption method '
  31. 'name, this is used to find the right decryption method. e.g. '
  32. '--model_cryption = "AES_default", default is NONE.', default =
  33. 'NONE')
  34. parser.add_argument('--info-cryption', help='the info encryption method '
  35. 'name, this is used to find the right decryption method. e.g. '
  36. '--model_cryption = "AES_default", default is NONE.', default =
  37. 'NONE')
  38. parser.add_argument('--info-parser', help='The information parse method name '
  39. 'default is "LITE_default". ', default = 'LITE_default')
  40. parser.add_argument('--append', '-a', help='append another model to a '
  41. 'packed model.')
  42. parser.add_argument('--output', '-o', help='output file of packed model.')
  43. args = parser.parse_args()
  44. generate_flatbuffer()
  45. assert not args.append, ('--append is not support yet')
  46. assert args.input_model, ('--input_model must be given')
  47. with open(args.input_model, 'rb') as fin:
  48. raw_model = fin.read()
  49. model_length = len(raw_model)
  50. if args.input_info:
  51. with open(args.input_info, 'rb') as fin:
  52. raw_info = fin.read()
  53. info_length = len(raw_info)
  54. else:
  55. raw_info = None
  56. info_length = 0
  57. # Generated by `flatc`.
  58. from model_parse import Model, ModelData, ModelHeader, ModelInfo, PackModel
  59. builder = flatbuffers.Builder(1024)
  60. model_name = builder.CreateString(args.model_name)
  61. model_cryption = builder.CreateString(args.model_cryption)
  62. info_cryption = builder.CreateString(args.info_cryption)
  63. info_parser = builder.CreateString(args.info_parser)
  64. info_data = builder.CreateByteVector(raw_info)
  65. arr_data = builder.CreateByteVector(raw_model)
  66. #model header
  67. ModelHeader.ModelHeaderStart(builder)
  68. ModelHeader.ModelHeaderAddName(builder, model_name)
  69. ModelHeader.ModelHeaderAddModelDecryptionMethod(builder, model_cryption)
  70. ModelHeader.ModelHeaderAddInfoDecryptionMethod(builder, info_cryption)
  71. ModelHeader.ModelHeaderAddInfoParseMethod(builder, info_parser)
  72. model_header = ModelHeader.ModelHeaderEnd(builder)
  73. #model info
  74. ModelInfo.ModelInfoStart(builder)
  75. ModelInfo.ModelInfoAddData(builder, info_data)
  76. model_info = ModelInfo.ModelInfoEnd(builder)
  77. #model data
  78. ModelData.ModelDataStart(builder)
  79. ModelData.ModelDataAddData(builder, arr_data)
  80. model_data = ModelData.ModelDataEnd(builder)
  81. Model.ModelStart(builder)
  82. Model.ModelAddHeader(builder, model_header)
  83. Model.ModelAddData(builder, model_data)
  84. Model.ModelAddInfo(builder, model_info)
  85. model = Model.ModelEnd(builder)
  86. PackModel.PackModelStartModelsVector(builder, 1)
  87. builder.PrependUOffsetTRelative(model)
  88. models = builder.EndVector()
  89. PackModel.PackModelStart(builder)
  90. PackModel.PackModelAddModels(builder, models)
  91. packed_model = PackModel.PackModelEnd(builder)
  92. builder.Finish(packed_model)
  93. buff = builder.Output()
  94. result = struct.pack(str(len("packed_model")) + 's', "packed_model".encode('ascii'))
  95. result += buff
  96. assert args.output, ('--output must be given')
  97. with open(args.output, 'wb') as fin:
  98. fin.write(result)
  99. print("Model packaged successfully!!!")
  100. print("model name is: {}.".format(args.model_name))
  101. print("model encryption method is: {}. ".format(args.model_cryption))
  102. print("model json infomation encryption method is: {}. ".format(args.info_cryption))
  103. print("model json infomation parse method is: {}. ".format(args.info_parser))
  104. print("packed model is write to {} ".format(args.output))
  105. if __name__ == '__main__':
  106. main()