You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dump_trt.py 1.1 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. #!/usr/bin/env python3
  2. from megskull.network import RawNetworkBuilder
  3. import megskull.opr.all as O
  4. from megskull.opr.external import TensorRTRuntimeOpr
  5. from meghair.utils.io import dump
  6. import argparse
  7. def str2tuple(x):
  8. x = x.split(',')
  9. x = [int(a) for a in x]
  10. x = tuple(x)
  11. return x
  12. def make_network(model, isize):
  13. data = [O.DataProvider('input{}'.format(i), shape=isizes[i])
  14. for i in range(len(isizes))]
  15. f = open(model, 'rb')
  16. engine = f.read()
  17. opr = TensorRTRuntimeOpr(data, engine, 1)
  18. net = RawNetworkBuilder(inputs=[data], outputs=opr.outputs)
  19. return net
  20. if __name__ == "__main__":
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument(dest = 'model')
  23. parser.add_argument(dest = 'output')
  24. parser.add_argument('--isize', help='input sizes. '
  25. 'e.g. for models with two (1,3,224,224) inputs, '
  26. 'the option --isize="1,3,224,224;1,3,224,224" should be used')
  27. args = parser.parse_args()
  28. isizes = [str2tuple(x) for x in args.isize.split(';')]
  29. net = make_network(args.model, isizes)
  30. dump(net, args.output)