You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import io
  2. import numpy as np
  3. import megengine.core.tensor.megbrain_graph as G
  4. import megengine.utils.comp_graph_tools as cgtools
  5. from megengine import tensor
  6. from megengine.core.tensor.megbrain_graph import OutputNode
  7. from megengine.jit import trace
  8. from megengine.utils.network_node import VarNode
  9. def _default_compare_fn(x, y):
  10. if isinstance(x, np.ndarray):
  11. np.testing.assert_allclose(x, y, rtol=1e-6)
  12. elif isinstance(x, tensor):
  13. np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
  14. else:
  15. np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)
  16. def make_tensor(x, network=None, device=None):
  17. if network is not None:
  18. if isinstance(x, VarNode):
  19. return VarNode(x.var)
  20. return network.make_const(x, device=device)
  21. else:
  22. return tensor(x, device=device)
  23. def get_var_value(x):
  24. try:
  25. o = OutputNode(x.var)
  26. o.graph.compile(o.outputs).execute()
  27. return o.get_value().numpy()
  28. except RuntimeError:
  29. raise ValueError("value invalid!")
  30. def opr_test(
  31. cases,
  32. func,
  33. compare_fn=_default_compare_fn,
  34. ref_fn=None,
  35. test_trace=True,
  36. network=None,
  37. **kwargs
  38. ):
  39. """
  40. :param cases: the list which have dict element, the list length should be 2 for dynamic shape test.
  41. and the dict should have input,
  42. and should have output if ref_fn is None.
  43. should use list for multiple inputs and outputs for each case.
  44. :param func: the function to run opr.
  45. :param compare_fn: the function to compare the result and expected, use
  46. ``np.testing.assert_allclose`` if None.
  47. :param ref_fn: the function to generate expected data, should assign output if None.
  48. Examples:
  49. .. code-block::
  50. dtype = np.float32
  51. cases = [{"input": [10, 20]}, {"input": [20, 30]}]
  52. opr_test(cases,
  53. F.eye,
  54. ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
  55. dtype=dtype)
  56. """
  57. def check_results(results, expected):
  58. if not isinstance(results, (tuple, list)):
  59. results = (results,)
  60. for r, e in zip(results, expected):
  61. if not isinstance(r, (tensor, VarNode)):
  62. r = tensor(r)
  63. compare_fn(r, e)
  64. def get_param(cases, idx):
  65. case = cases[idx]
  66. inp = case.get("input", None)
  67. outp = case.get("output", None)
  68. if inp is None:
  69. raise ValueError("the test case should have input")
  70. if not isinstance(inp, (tuple, list)):
  71. inp = (inp,)
  72. if ref_fn is not None and callable(ref_fn):
  73. outp = ref_fn(*inp)
  74. if outp is None:
  75. raise ValueError("the test case should have output or reference function")
  76. if not isinstance(outp, (tuple, list)):
  77. outp = (outp,)
  78. return inp, outp
  79. def run_index(index):
  80. inp, outp = get_param(cases, index)
  81. inp_tensor = [make_tensor(inpi, network) for inpi in inp]
  82. if test_trace and not network:
  83. copied_inp = inp_tensor.copy()
  84. for symbolic in [False, True]:
  85. traced_func = trace(symbolic=symbolic)(func)
  86. for _ in range(3):
  87. traced_results = traced_func(*copied_inp, **kwargs)
  88. check_results(traced_results, outp)
  89. dumped_func = trace(symbolic=True, capture_as_const=True)(func)
  90. dumped_results = dumped_func(*copied_inp, **kwargs)
  91. check_results(dumped_results, outp)
  92. file = io.BytesIO()
  93. dump_info = dumped_func.dump(file)
  94. file.seek(0)
  95. # arg_name has pattern arg_xxx, xxx is int value
  96. def take_number(arg_name):
  97. return int(arg_name.split("_")[-1])
  98. input_names = dump_info[4]
  99. inps_np = [i.numpy() for i in copied_inp]
  100. input_names.sort(key=take_number)
  101. inp_dict = dict(zip(input_names, inps_np))
  102. infer_cg = cgtools.GraphInference(file)
  103. # assume #outputs == 1
  104. loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
  105. check_results(loaded_results, outp)
  106. results = func(*inp_tensor, **kwargs)
  107. check_results(results, outp)
  108. if len(cases) == 0:
  109. raise ValueError("should give one case at least")
  110. if not callable(func):
  111. raise ValueError("the input func should be callable")
  112. for index in range(len(cases)):
  113. run_index(index)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。