You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

resnet50_mgb.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. #!/usr/bin/env python3
  2. import megbrain as mgb
  3. from megskull.graph import FpropEnv
  4. import megskull as mgsk
  5. from megskull.opr.compatible.caffepool import CaffePooling2D
  6. from megskull.opr.arith import ReLU
  7. from megskull.opr.all import (
  8. DataProvider, Conv2D, Pooling2D, FullyConnected,
  9. Softmax, Dropout, BatchNormalization, CrossEntropyLoss,
  10. ElementwiseAffine, WarpPerspective, WarpPerspectiveWeightProducer,
  11. WeightDecay, ParamProvider, ConvBiasActivation, ElemwiseMultiType)
  12. from megskull.network import RawNetworkBuilder
  13. from megskull.utils.debug import CallbackInjector
  14. import megskull.opr.helper.param_init as pinit
  15. from megskull.opr.helper.elemwise_trans import Identity
  16. from megskull.opr.netsrc import DataProvider
  17. from megskull.opr.cnn import Conv2D, Pooling2D, FullyConnected, Softmax, Conv2DImplHelper
  18. from megskull.opr.loss import CrossEntropyLoss
  19. from megskull.opr.regularizer import Dropout, BatchNormalization
  20. from megskull.opr.arith import Add, ReLU
  21. from megskull.opr.netsrc import ConstProvider
  22. from megskull.network import RawNetworkBuilder
  23. import numpy as np
  24. from megskull.network import RawNetworkBuilder, NetworkVisitor
  25. from megskull.graph import iter_dep_opr
  26. from megskull.utils.misc import get_2dshape
  27. import functools
  28. import re
  29. import fnmatch
  30. import argparse
  31. import sys
  32. def create_bn_relu_float(conv_name, f_in, ksize, stride, pad, num_outputs,
  33. has_relu, args):
  34. f = Conv2D(conv_name, f_in, kernel_shape=ksize, stride=stride,
  35. padding=pad, output_nr_channel=num_outputs,
  36. nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
  37. if has_relu:
  38. f = ReLU(f)
  39. return f
  40. def get_num_inputs(feature, format):
  41. if format == 'NCHW':
  42. return feature.partial_shape[1]
  43. else:
  44. assert format == 'NCHW4'
  45. return feature.partial_shape[1] * 4
  46. def create_bn_relu(prefix, f_in, ksize, stride, pad, num_outputs,
  47. has_relu, conv_name_fun, args):
  48. if conv_name_fun:
  49. conv_name = conv_name_fun(prefix)
  50. else:
  51. conv_name = prefix
  52. return create_bn_relu_float(conv_name, f_in, ksize, stride, pad,
  53. num_outputs, has_relu, args)
  54. def create_bottleneck(prefix, f_in, stride, num_outputs1, num_outputs2, args,
  55. has_proj=False):
  56. proj = f_in
  57. if has_proj:
  58. proj = create_bn_relu(prefix, f_in, ksize=1, stride=stride, pad=0,
  59. num_outputs=num_outputs2, has_relu=False,
  60. conv_name_fun=lambda p: "interstellar{}_branch1".format(
  61. p), args=args)
  62. f = create_bn_relu(prefix, f_in, ksize=1, stride=1, pad=0,
  63. num_outputs=num_outputs1, has_relu=True,
  64. conv_name_fun=lambda p: "interstellar{}_branch2a".format(
  65. p), args=args)
  66. f = create_bn_relu(prefix, f, ksize=3, stride=stride, pad=1,
  67. num_outputs=num_outputs1, has_relu=True,
  68. conv_name_fun=lambda p: "interstellar{}_branch2b".format(
  69. p), args=args)
  70. f = create_bn_relu(prefix, f, ksize=1, stride=1, pad=0,
  71. num_outputs=num_outputs2, has_relu=False,
  72. conv_name_fun=lambda p: "interstellar{}_branch2c".format(
  73. p), args=args)
  74. f = ReLU(f + proj)
  75. return f
  76. def get(args):
  77. img_size = 224
  78. num_inputs = 3
  79. data = DataProvider('data', shape=(args.batch_size, num_inputs,
  80. img_size, img_size))
  81. inp = data
  82. f = create_bn_relu("conv1", inp, ksize=7, stride=2, pad=3, num_outputs=64,
  83. has_relu=True, conv_name_fun=None,
  84. args=args)
  85. f = Pooling2D("pool1", f, window=3, stride=2, padding=1, mode="MAX",
  86. format=args.format)
  87. pre = [2, 3, 4, 5]
  88. stages = [3, 4, 6, 3]
  89. mid_outputs = [64, 128, 256, 512]
  90. enable_stride = [False, True, True, True]
  91. for p, s, o, es in zip(pre, stages, mid_outputs, enable_stride):
  92. for i in range(s):
  93. has_proj = False if i > 0 else True
  94. stride = 1 if not es or i > 0 else 2
  95. prefix = "{}{}".format(p, chr(ord("a") + i))
  96. f = create_bottleneck(prefix, f, stride, o, o * 4, args, has_proj)
  97. print("{}\t{}".format(prefix, f.partial_shape))
  98. f = Pooling2D("pool5", f, window=7, stride=7, padding=0, mode="AVERAGE",
  99. format=args.format)
  100. f = FullyConnected("fc1000", f, output_dim=1000,
  101. nonlinearity=mgsk.opr.helper.elemwise_trans.Identity())
  102. f = Softmax("cls_softmax", f)
  103. f.init_weights()
  104. net = RawNetworkBuilder(inputs=[data], outputs=[f])
  105. return net
  106. if __name__ == '__main__':
  107. parser = argparse.ArgumentParser(
  108. description='dump pkl model for resnet50',
  109. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  110. parser.add_argument('-b', '--batch-size',
  111. help='batch size of the model', default=1)
  112. parser.add_argument('-f', '--format', choices=['NCHW', 'NCHW4'],
  113. help='format of conv',
  114. default='NCHW')
  115. parser.add_argument('-o', '--output',
  116. help='output pkl path', required=True)
  117. args = parser.parse_args()
  118. if args.format != 'NCHW':
  119. print('Only suppprt NCHW for float model')
  120. parser.print_help()
  121. sys.exit(1)
  122. from meghair.utils import io
  123. io.dump(get(args), args.output)