You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

compat.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. import numpy as np
  2. from megengine.functional.tensor import zeros
  3. from ..core.ops.builtin import BatchNorm
  4. from .expr import CallMethod, Constant
  5. from .node import TensorNode
  6. from .serialization import (
  7. register_functional_loader,
  8. register_module_loader,
  9. register_opdef_loader,
  10. register_tensor_method_loader,
  11. )
  12. """
  13. # Expr loaders examples
  14. from ..core.ops.builtin import Elemwise
  15. @register_opdef_loader(Elemwise)
  16. def add_opdef_loader(expr):
  17. if expr.opdef_state["mode"] == "ADD":
  18. expr.opdef_state["mode"] == "MUL"
  19. node = expr.inputs[1]
  20. astype_expr = CallMethod(node, "astype")
  21. oup = TensorNode(
  22. astype_expr,
  23. shape=node.shape,
  24. dtype=expr.inputs[0].dtype,
  25. qparams=node.qparams,
  26. )
  27. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  28. astype_expr.return_val = (oup,)
  29. expr.inputs[1] = oup
  30. @register_functional_loader(("megengine.functional.nn", "conv2d"))
  31. def conv2df_loader(expr):
  32. # expr.func = ("megengine.functional.nn","conv2d")
  33. kwargs = expr.kwargs
  34. orig_weight = expr.named_args["weight"]
  35. astype_expr = CallMethod(orig_weight, "astype")
  36. oup = TensorNode(
  37. astype_expr,
  38. shape=orig_weight.shape,
  39. dtype=orig_weight.dtype,
  40. qparams=orig_weight.qparams,
  41. )
  42. astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
  43. astype_expr.return_val = (oup,)
  44. expr.set_arg("weight", oup)
  45. @register_module_loader(("megengine.module.conv", "Conv2d"))
  46. def conv2dm_loader(expr):
  47. module = expr.inputs[0].owner
  48. args = list(expr.args)
  49. orig_inp = args[1]
  50. astype_expr = CallMethod(orig_inp, "astype")
  51. oup = TensorNode(
  52. astype_expr,
  53. shape=orig_inp.shape,
  54. dtype=orig_inp.dtype,
  55. qparams=orig_inp.qparams,
  56. )
  57. astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
  58. astype_expr.return_val = (oup,)
  59. args[1] = oup
  60. expr.set_args_kwargs(*args)
  61. @register_tensor_method_loader("__add__")
  62. def add_loader(expr):
  63. args = list(expr.args)
  64. if not isinstance(args[1], TensorNode):
  65. args[1] = tensor(args[1])
  66. node = Constant(args[1], "const").outputs[0]
  67. astype_expr = CallMethod(node, "astype")
  68. oup = TensorNode(
  69. astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
  70. )
  71. astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
  72. astype_expr.return_val = (oup,)
  73. args[1] = oup
  74. expr.set_args_kwargs(*args)
  75. """
  76. @register_module_loader(
  77. ("megengine.module.batchnorm", "BatchNorm1d"),
  78. ("megengine.module.batchnorm", "BatchNorm2d"),
  79. ("megengine.module.batchnorm", "SyncBatchNorm"),
  80. )
  81. def bn2d_module_loader(expr):
  82. # mge 1.6
  83. if not hasattr(expr, "version"):
  84. module = expr.inputs[0].owner
  85. if not hasattr(module, "param_dim"):
  86. module.param_dim = "dim_1c11"
  87. @register_module_loader(
  88. ("megengine.module.conv_bn", "ConvBn2d"),
  89. ("megengine.module.conv_bn", "ConvBnRelu2d"),
  90. ("megengine.module.qat.conv_bn", "ConvBn2d"),
  91. ("megengine.module.qat.conv_bn", "ConvBnRelu2d"),
  92. )
  93. def convbn2d_module_loader(expr):
  94. # mge 1.6
  95. if not hasattr(expr, "version"):
  96. module = expr.inputs[0].owner
  97. if not hasattr(module.bn, "param_dim"):
  98. module.bn.param_dim = "dim_1c11"
  99. module = expr.inputs[0].owner
  100. if not hasattr(module.conv, "padding_mode"):
  101. module.conv.padding_mode = "zeros"
  102. @register_opdef_loader(BatchNorm)
  103. def bn_opdef_loader(expr):
  104. # mge 1.6
  105. if not hasattr(expr, "version") and len(expr.outputs) != 6:
  106. assert len(expr.outputs) == 5
  107. output = expr.outputs[-1]
  108. oup = TensorNode(expr, shape=(0,), dtype=None, qparams=output._qparams,)
  109. expr.outputs.insert(4, oup)
  110. @register_functional_loader(
  111. ("megengine.functional.tensor", "ones"), ("megengine.functional.tensor", "zeros")
  112. )
  113. def tensor_gen_func_loader(expr):
  114. if hasattr(expr, "version") and expr.version == "1.7.0":
  115. expr.set_args_kwargs(expr.args[0], dtype=expr.args[1], device=expr.args[2])
  116. if not hasattr(expr, "version"):
  117. # compatiable for version 1.6
  118. shape = expr.args[0] if len(expr.args) > 0 else expr.kwargs["shape"]
  119. if len(expr.args) > 1:
  120. dtype = expr.args[1]
  121. elif "dtype" in expr.kwargs:
  122. dtype = expr.kwargs["dtype"]
  123. else:
  124. dtype = "float32"
  125. if len(expr.args) > 2:
  126. device = expr.args[2]
  127. elif "device" in expr.kwargs:
  128. device = expr.kwargs["device"]
  129. else:
  130. device = None
  131. expr.set_args_kwargs(shape, dtype=dtype, device=device)
  132. @register_functional_loader(("megengine.functional.nn", "pad"))
  133. def pad_func_loader(expr):
  134. if "pad_witdth" in expr.kwargs:
  135. kwargs = expr.kwargs
  136. kwargs["pad_width"] = kwargs.pop("pad_witdth")
  137. expr.set_args_kwargs(*expr.args, **kwargs)
  138. @register_module_loader(
  139. ("megengine.module.conv", "Conv1d"),
  140. ("megengine.module.conv", "Conv2d"),
  141. ("megengine.module.conv", "ConvRelu2d"),
  142. ("megengine.module.qat.conv", "Conv2d"),
  143. ("megengine.module.qat.conv", "ConvRelu2d"),
  144. ("megengine.module.quantized.conv", "Conv2d"),
  145. ("megengine.module.quantized.conv", "ConvRelu2d"),
  146. )
  147. def conv2d_module_loader(expr):
  148. module = expr.inputs[0].owner
  149. if not hasattr(module, "padding_mode"):
  150. module.padding_mode = "zeros"
  151. @register_module_loader(
  152. ("megengine.module.quantized.conv_bn", "ConvBn2d"),
  153. ("megengine.module.quantized.conv_bn", "ConvBnRelu2d"),
  154. )
  155. def quantized_convbn2d_module_loader(expr):
  156. module = expr.inputs[0].owner
  157. if not hasattr(module, "padding_mode"):
  158. module.padding_mode = "zeros"