You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dump_model_mgb.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # -*- coding: utf-8 -*-
  2. from megskull.graph import NodeFilter, FpropEnv
  3. from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization
  4. from megskull.utils.logconf import get_logger
  5. from meghair.utils import io
  6. import megbrain as mgb
  7. import argparse
  8. import struct
  9. import re
  10. import os
  11. import numpy as np
  12. import cv2
  13. logger = get_logger(__name__)
  14. def optimize_for_inference(args, outputs):
  15. args_map = {
  16. 'enable_io16xc32': 'f16_io_f32_comp',
  17. 'enable_ioc16': 'f16_io_comp',
  18. 'enable_hwcd4': 'use_nhwcd4',
  19. 'enable_nchw4': 'use_nchw4',
  20. 'enable_nchw88': 'use_nchw88',
  21. 'enable_nchw44': 'use_nchw44',
  22. 'enable_nchw44_dot': 'use_nchw44_dot',
  23. 'enable_nchw32': 'use_nchw32',
  24. 'enable_chwn4': 'use_chwn4',
  25. 'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
  26. 'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
  27. }
  28. kwargs = {}
  29. for k, v in args_map.items():
  30. if getattr(args, k):
  31. assert args.optimize_for_inference, (
  32. 'optimize_for_inference should be set when {} is given'.format(
  33. k))
  34. kwargs[v] = True
  35. if args.optimize_for_inference:
  36. return mgb.optimize_for_inference(outputs, **kwargs)
  37. return outputs
  38. def main():
  39. parser = argparse.ArgumentParser(
  40. description='Dump the Python Megbrain model to C++ model, by the way '
  41. 'optimizing for inference',
  42. formatter_class=argparse.ArgumentDefaultsHelpFormatter
  43. )
  44. parser.add_argument('input', help='input pkl model file ')
  45. parser.add_argument('-o', '--output', help='output file', required=True)
  46. parser.add_argument('--init-bn', action='store_true',
  47. help='initialize untrained batch-normalization, to '
  48. 'avoid NaN or Inf results')
  49. parser.add_argument('--silent', action='store_true',
  50. help='set verbose to False in AssertEqual opr')
  51. parser.add_argument('--optimize-for-inference', action='store_true',
  52. help='enbale optimization for inference')
  53. parser.add_argument('--discard-var-name', action='store_true',
  54. help='discard variable and param names in the '
  55. 'generated output')
  56. parser.add_argument('--output-strip-info', action='store_true',
  57. help='output code strip information')
  58. parser.add_argument('--enable-io16xc32', action='store_true',
  59. help='transform the mode to float16 io float32 compute')
  60. parser.add_argument('--enable-ioc16', action='store_true',
  61. help='transform the dtype of the model to float16 io '
  62. 'and compute')
  63. parser.add_argument('--enable-fuse-conv-bias-nonlinearity',
  64. action='store_true',
  65. help='fuse convolution bias and nonlinearity opr to a '
  66. 'conv_bias opr and compute')
  67. parser.add_argument('--enable-hwcd4', action='store_true',
  68. help='transform the model format from NCHW to NHWCD4 '
  69. 'for inference; you may need to disable CUDA and set '
  70. 'MGB_USE_MEGDNN_DBG=2')
  71. parser.add_argument('--enable-nchw4', action='store_true',
  72. help='transform the model format from NCHW to NCHW4 '
  73. 'for inference')
  74. parser.add_argument('--enable-nchw88', action='store_true',
  75. help='transform the model format from NCHW to NCHW88 '
  76. 'for inference')
  77. parser.add_argument('--enable-nchw44', action='store_true',
  78. help='transform the model format from NCHW to NCHW44 '
  79. 'for inference')
  80. parser.add_argument('--enable-nchw44-dot', action='store_true',
  81. help='transform the model format from NCHW to NCHW44_DOT '
  82. 'for optimizing armv8.2 dot in inference')
  83. parser.add_argument('--enable-chwn4', action='store_true',
  84. help='transform the model format to CHWN4 '
  85. 'for inference, mainly used for nvidia tensorcore')
  86. parser.add_argument('--enable-nchw32', action='store_true',
  87. help='transform the model format from NCHW4 to NCHW32 '
  88. 'for inference on nvidia TensoCore')
  89. parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true',
  90. help='fuse conv_bias with z input for inference on '
  91. 'nvidia GPU (this optimization pass will result in mismatch '
  92. 'of the precision of output of training and inference)')
  93. args = parser.parse_args()
  94. env = FpropEnv(verbose_fprop=False)
  95. outputs = io.load_network(args.input).outputs
  96. output_mgbvars = list(map(env.get_mgbvar, outputs))
  97. output_mgbvars = optimize_for_inference(args, output_mgbvars)
  98. if args.discard_var_name:
  99. sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
  100. else:
  101. sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
  102. stat = mgb.serialize_comp_graph_to_file(
  103. args.output, output_mgbvars, append=False,
  104. output_strip_info=args.output_strip_info,
  105. **sereg_kwargs)
  106. logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'.
  107. format(stat.tot_bytes / 1024,
  108. (stat.tot_bytes - stat.tensor_value_bytes) / 1024))
  109. if __name__ == '__main__':
  110. main()