- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-
- from megskull.graph import NodeFilter, FpropEnv
- from megskull.opr.all import AssertEqual, DataProvider, BatchNormalization
- from megskull.utils.logconf import get_logger
- from meghair.utils import io
- import megbrain as mgb
-
- import argparse
- import struct
- import re
- import os
-
- import numpy as np
- import cv2
-
- logger = get_logger(__name__)
-
- def optimize_for_inference(args, outputs):
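-     # Map each argparse flag (keyed by its dest name) to the matching
-     # keyword argument of mgb.optimize_for_inference.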
-     args_map = {
-         'enable_io16xc32': 'f16_io_f32_comp',
-         'enable_ioc16': 'f16_io_comp',
-         'enable_hwcd4': 'use_nhwcd4',
-         'enable_nchw4': 'use_nchw4',
-         'enable_nchw88': 'use_nchw88',
-         'enable_nchw44': 'use_nchw44',
-         'enable_nchw44_dot': 'use_nchw44_dot',
-         'enable_nchw32': 'use_nchw32',
-         'enable_chwn4': 'use_chwn4',
-         'enable_fuse_conv_bias_nonlinearity': 'fuse_conv_bias_nonlinearity',
-         'enable_fuse_conv_bias_with_z': 'fuse_conv_bias_with_z',
-     }
-     kwargs = {}
-     for k, v in args_map.items():
-         if getattr(args, k):
-             assert args.optimize_for_inference, (
-                 'optimize_for_inference should be set when {} is given'.format(
-                     k))
-             kwargs[v] = True
-
-     if args.optimize_for_inference:
-         return mgb.optimize_for_inference(outputs, **kwargs)
-
-     return outputs
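-
- # For example, running with `--optimize-for-inference --enable-io16xc32`
- # makes the function above call
- # mgb.optimize_for_inference(outputs, f16_io_f32_comp=True).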
-
- def main():
-     parser = argparse.ArgumentParser(
-         description='Dump a Python MegBrain model as a C++-loadable model, '
-         'optionally optimizing it for inference',
-         formatter_class=argparse.ArgumentDefaultsHelpFormatter
-     )
-     parser.add_argument('input', help='input pkl model file')
-     parser.add_argument('-o', '--output', help='output file', required=True)
-     parser.add_argument('--init-bn', action='store_true',
-                         help='initialize untrained batch-normalization, to '
-                         'avoid NaN or Inf results')
-     parser.add_argument('--silent', action='store_true',
-                         help='set verbose to False in AssertEqual opr')
-     parser.add_argument('--optimize-for-inference', action='store_true',
-                         help='enable optimization for inference')
-     parser.add_argument('--discard-var-name', action='store_true',
-                         help='discard variable and param names in the '
-                         'generated output')
-     parser.add_argument('--output-strip-info', action='store_true',
-                         help='output code strip information')
-     parser.add_argument('--enable-io16xc32', action='store_true',
-                         help='transform the model to use float16 io and '
-                         'float32 compute')
-     parser.add_argument('--enable-ioc16', action='store_true',
-                         help='transform the dtype of the model to float16 '
-                         'for both io and compute')
-     parser.add_argument('--enable-fuse-conv-bias-nonlinearity',
-                         action='store_true',
-                         help='fuse convolution, bias and nonlinearity oprs '
-                         'into a single conv_bias opr')
-     parser.add_argument('--enable-hwcd4', action='store_true',
-                         help='transform the model format from NCHW to NHWCD4 '
-                         'for inference; you may need to disable CUDA and set '
-                         'MGB_USE_MEGDNN_DBG=2')
-     parser.add_argument('--enable-nchw4', action='store_true',
-                         help='transform the model format from NCHW to NCHW4 '
-                         'for inference')
-     parser.add_argument('--enable-nchw88', action='store_true',
-                         help='transform the model format from NCHW to NCHW88 '
-                         'for inference')
-     parser.add_argument('--enable-nchw44', action='store_true',
-                         help='transform the model format from NCHW to NCHW44 '
-                         'for inference')
-     parser.add_argument('--enable-nchw44-dot', action='store_true',
-                         help='transform the model format from NCHW to '
-                         'NCHW44_DOT, exploiting the armv8.2 dot-product '
-                         'instructions during inference')
-     parser.add_argument('--enable-chwn4', action='store_true',
-                         help='transform the model format to CHWN4 for '
-                         'inference, mainly for NVIDIA TensorCore')
-     parser.add_argument('--enable-nchw32', action='store_true',
-                         help='transform the model format from NCHW4 to NCHW32 '
-                         'for inference on NVIDIA TensorCore')
-     parser.add_argument('--enable-fuse-conv-bias-with-z', action='store_true',
-                         help='fuse conv_bias with its z input for inference on '
-                         'NVIDIA GPUs (this pass makes the outputs of training '
-                         'and inference differ slightly in precision)')
-     args = parser.parse_args()
-
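-     # FpropEnv builds the megbrain forward-prop graph for the loaded
-     # megskull network; env.get_mgbvar then maps each megskull output
-     # node to its underlying megbrain symbol var.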
-     env = FpropEnv(verbose_fprop=False)
-
-     outputs = io.load_network(args.input).outputs
-
-     output_mgbvars = list(map(env.get_mgbvar, outputs))
-
-     output_mgbvars = optimize_for_inference(args, output_mgbvars)
-
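-     # keep_var_name=2 keeps the names of all vars while 0 discards them
-     # (assumed semantics of the megbrain serializer); param names are
-     # toggled separately via keep_param_name.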
-     if args.discard_var_name:
-         sereg_kwargs = dict(keep_var_name=0, keep_param_name=False)
-     else:
-         sereg_kwargs = dict(keep_var_name=2, keep_param_name=True)
-
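-     # Serialize the (possibly optimized) graph to args.output and report
-     # the dump size; the overhead is everything besides raw tensor values.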
-     stat = mgb.serialize_comp_graph_to_file(
-         args.output, output_mgbvars, append=False,
-         output_strip_info=args.output_strip_info,
-         **sereg_kwargs)
-     logger.info('graph dump sizes: tot_size={:.3f}KiB overhead={:.3f}KiB'.
-                 format(stat.tot_bytes / 1024,
-                        (stat.tot_bytes - stat.tensor_value_bytes) / 1024))
-
- if __name__ == '__main__':
-     main()
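-
- # Example invocation (hypothetical file names):
- #   python3 dump.py model.pkl -o model.mgb --optimize-for-inference \
- #       --enable-io16xc32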