You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dump_model.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # -*- coding: utf-8 -*-
  2. import argparse
  3. import numpy as np
  4. import yaml
  5. from megengine import jit, tensor
  6. from megengine.module.external import ExternOprSubgraph
  7. # "1,3,224,224" -> (1,3,224,224)
  8. def str2tuple(x):
  9. x = x.split(",")
  10. x = [int(a) for a in x]
  11. x = tuple(x)
  12. return x
  13. def main():
  14. parser = argparse.ArgumentParser(
  15. description="load a .pb model and convert to corresponding "
  16. "load-and-run model"
  17. )
  18. parser.add_argument("--input", help="mace model file")
  19. parser.add_argument("--param", help="mace param file")
  20. parser.add_argument(
  21. "--output", help="converted mge model"
  22. )
  23. parser.add_argument("--config", help="config file with yaml format")
  24. args = parser.parse_args()
  25. with open(args.config, "r") as f:
  26. configs = yaml.load(f)
  27. for model_name in configs["models"]:
  28. # ignore several sub models currently
  29. sub_model = configs["models"][model_name]["subgraphs"][0]
  30. # input/output shapes
  31. isizes = [str2tuple(x) for x in sub_model["input_shapes"]]
  32. # input/output names
  33. input_names = sub_model["input_tensors"]
  34. if "check_tensors" in sub_model:
  35. output_names = sub_model["check_tensors"]
  36. osizes = [str2tuple(x) for x in sub_model["check_shapes"]]
  37. else:
  38. output_names = sub_model["output_tensors"]
  39. osizes = [str2tuple(x) for x in sub_model["output_shapes"]]
  40. with open(args.input, "rb") as fin:
  41. raw_model = fin.read()
  42. with open(args.param, "rb") as fin:
  43. raw_param = fin.read()
  44. model_size = (len(raw_model)).to_bytes(4, byteorder="little")
  45. param_size = (len(raw_param)).to_bytes(4, byteorder="little")
  46. n_inputs = (len(input_names)).to_bytes(4, byteorder="little")
  47. n_outputs = (len(output_names)).to_bytes(4, byteorder="little")
  48. names_buffer = n_inputs + n_outputs
  49. for iname in input_names:
  50. names_buffer += (len(iname)).to_bytes(4, byteorder="little")
  51. names_buffer += str.encode(iname)
  52. for oname in output_names:
  53. names_buffer += (len(oname)).to_bytes(4, byteorder="little")
  54. names_buffer += str.encode(oname)
  55. shapes_buffer = n_outputs
  56. for oshape in osizes:
  57. shapes_buffer += (len(oshape)).to_bytes(4, byteorder="little")
  58. for oi in oshape:
  59. shapes_buffer += oi.to_bytes(4, byteorder="little")
  60. # raw content contains:
  61. # input/output names + output shapes + model buffer + param buffer
  62. wk_raw_content = (
  63. names_buffer
  64. + shapes_buffer
  65. + model_size
  66. + raw_model
  67. + param_size
  68. + raw_param
  69. )
  70. net = ExternOprSubgraph(osizes, "mace", wk_raw_content)
  71. net.eval()
  72. @jit.trace(record_only=True)
  73. def inference(inputs):
  74. return net(inputs)
  75. inputs = [
  76. tensor(np.random.random(isizes[i]).astype(np.float32)) for i in range(len(isizes))
  77. ]
  78. inference(*inputs)
  79. inference.dump(args.output)
# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()