
compat.py 6.4 kB

import numpy as np

from megengine.functional.tensor import zeros

from ..core.ops.builtin import BatchNorm
from .expr import CallMethod, Constant
from .node import TensorNode
from .serialization import (
    register_functional_loader,
    register_module_loader,
    register_opdef_loader,
    register_tensor_method_loader,
)

  12. """
  13. # Expr loaders examples
  14. from ..core.ops.builtin import Elemwise
  15. @register_opdef_loader(Elemwise)
  16. def add_opdef_loader(expr):
  17. if expr.opdef_state["mode"] == "ADD":
  18. expr.opdef_state["mode"] == "MUL"
  19. node = expr.inputs[1]
  20. astype_expr = CallMethod(node, "astype")
  21. oup = TensorNode(
  22. astype_expr,
  23. shape=node.shape,
  24. dtype=expr.inputs[0].dtype,
  25. qparams=node.qparams,
  26. )
  27. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  28. astype_expr.return_val = (oup,)
  29. expr.inputs[1] = oup
  30. @register_functional_loader(("megengine.functional.nn", "conv2d"))
  31. def conv2df_loader(expr):
  32. # expr.func = ("megengine.functional.nn","conv2d")
  33. kwargs = expr.kwargs
  34. orig_weight = expr.named_args["weight"]
  35. astype_expr = CallMethod(orig_weight, "astype")
  36. oup = TensorNode(
  37. astype_expr,
  38. shape=orig_weight.shape,
  39. dtype=orig_weight.dtype,
  40. qparams=orig_weight.qparams,
  41. )
  42. astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
  43. astype_expr.return_val = (oup,)
  44. expr.set_arg("weight", oup)
  45. @register_module_loader(("megengine.module.conv", "Conv2d"))
  46. def conv2dm_loader(expr):
  47. module = expr.inputs[0].owner
  48. args = list(expr.args)
  49. orig_inp = args[1]
  50. astype_expr = CallMethod(orig_inp, "astype")
  51. oup = TensorNode(
  52. astype_expr,
  53. shape=orig_inp.shape,
  54. dtype=orig_inp.dtype,
  55. qparams=orig_inp.qparams,
  56. )
  57. astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
  58. astype_expr.return_val = (oup,)
  59. args[1] = oup
  60. expr.set_args_kwargs(*args)
  61. @register_tensor_method_loader("__add__")
  62. def add_loader(expr):
  63. args = list(expr.args)
  64. if not isinstance(args[1], TensorNode):
  65. args[1] = tensor(args[1])
  66. node = Constant(args[1], "const").outputs[0]
  67. astype_expr = CallMethod(node, "astype")
  68. oup = TensorNode(
  69. astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
  70. )
  71. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  72. astype_expr.return_val = (oup,)
  73. args[1] = oup
  74. expr.set_args_kwargs(*args)
  75. """
@register_module_loader(
    ("megengine.module.batchnorm", "BatchNorm1d"),
    ("megengine.module.batchnorm", "BatchNorm2d"),
    ("megengine.module.batchnorm", "SyncBatchNorm"),
)
def bn2d_module_loader(expr):
    # older serialized modules carry a ``param_dim`` attribute; drop it
    module = expr.inputs[0].owner
    if hasattr(module, "param_dim"):
        assert module.param_dim == "dim_1c11"
        delattr(module, "param_dim")


@register_module_loader(
    ("megengine.module.conv_bn", "ConvBn2d"),
    ("megengine.module.conv_bn", "ConvBnRelu2d"),
    ("megengine.module.qat.conv_bn", "ConvBn2d"),
    ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
)
def convbn2d_module_loader(expr):
    module = expr.inputs[0].owner
    if hasattr(module.bn, "param_dim"):
        assert module.bn.param_dim == "dim_1c11"
        delattr(module.bn, "param_dim")
    if not hasattr(module.conv, "padding_mode"):
        module.conv.padding_mode = "zeros"


@register_opdef_loader(BatchNorm)
def bn_opdef_loader(expr):
    # mge 1.6: BatchNorm exprs were serialized with 5 outputs; insert a
    # placeholder so the loaded expr has the 6 outputs expected now
    if not hasattr(expr, "version") and len(expr.outputs) != 6:
        assert len(expr.outputs) == 5
        output = expr.outputs[-1]
        oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams)
        expr.outputs.insert(4, oup)


@register_functional_loader(
    ("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
)
def tensor_gen_func_loader(expr):
    if hasattr(expr, "version") and expr.version == "1.7.0":
        expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
    if not hasattr(expr, "version"):
        # compatible with version 1.6: args may be positional or keyword
        shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]
        if len(expr.args) > 1:
            dtype = expr.args[1]
        elif "dtype" in expr.kwargs:
            dtype = expr.kwargs["dtype"]
        else:
            dtype = "float32"
        if len(expr.args) > 2:
            device = expr.args[2]
        elif "device" in expr.kwargs:
            device = expr.kwargs["device"]
        else:
            device = None
        expr.set_args_kwargs(shape, dtype=dtype, device=device)


@register_functional_loader(("megengine.functional.nn", "pad"))
def pad_func_loader(expr):
    # older graphs recorded the keyword under the misspelled name "pad_witdth"
    if "pad_witdth" in expr.kwargs:
        kwargs = expr.kwargs
        kwargs["pad_width"] = kwargs.pop("pad_witdth")
        expr.set_args_kwargs(*expr.args, **kwargs)


@register_functional_loader(("megengine.functional.nn", "batch_norm"))
def bn_func_loader(expr):
    kwargs = expr.kwargs
    if "compute_mode" in kwargs:
        assert kwargs["compute_mode"] == "default"
        kwargs.pop("compute_mode")
    if "param_dim" in kwargs:
        assert kwargs["param_dim"] == "dim_1c11"
        kwargs.pop("param_dim")
    expr.set_args_kwargs(*expr.args, **kwargs)


@register_functional_loader(("megengine.functional.math", "matmul"))
def matmul_func_loader(expr):
    # older graphs recorded a sixth positional argument; drop the trailing "default"
    args = expr.args
    if len(args) == 6:
        assert args[5] == "default"
        expr.set_args_kwargs(*args[0:5])


@register_module_loader(
    ("megengine.module.conv", "Conv1d"),
    ("megengine.module.conv", "Conv2d"),
    ("megengine.module.conv", "ConvRelu2d"),
    ("megengine.module.qat.conv", "Conv2d"),
    ("megengine.module.qat.conv", "ConvRelu2d"),
    ("megengine.module.quantized.conv", "Conv2d"),
    ("megengine.module.quantized.conv", "ConvRelu2d"),
)
def conv2d_module_loader(expr):
    module = expr.inputs[0].owner
    if not hasattr(module, "padding_mode"):
        module.padding_mode = "zeros"


@register_module_loader(
    ("megengine.module.quantized.conv_bn", "ConvBn2d"),
    ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
)
def quantized_convbn2d_module_loader(expr):
    module = expr.inputs[0].owner
    if not hasattr(module, "padding_mode"):
        module.padding_mode = "zeros"
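
Every loader in this file follows the same pattern: register a (module, name) pair or an opdef with one of the decorators from .serialization, then rewrite expr.args / expr.kwargs in place. As a minimal sketch of how a downstream project could add its own compatibility hook in the same style — "some_op", "old_kw", and "new_kw" below are hypothetical placeholder names, not real MegEngine symbols, and the decorator is the same register_functional_loader imported at the top of this file:

@register_functional_loader(("megengine.functional.nn", "some_op"))
def some_op_loader(expr):
    # hypothetical example: rename a keyword argument that an older
    # serialized graph recorded under a different name
    kwargs = expr.kwargs
    if "old_kw" in kwargs:
        kwargs["new_kw"] = kwargs.pop("old_kw")
        expr.set_args_kwargs(*expr.args, **kwargs)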