You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

__init__.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. from typing import List, Tuple
  10. import numpy as np
  11. import megengine._internal as mgb
  12. import megengine.functional as F
  13. from megengine import Graph, jit
  14. from megengine.module import Linear, Module
  15. from megengine.test import assertTensorClose
  16. from .env import modified_environ
  17. class MLP(Module):
  18. def __init__(self):
  19. super().__init__()
  20. self.dense0 = Linear(28, 50)
  21. self.dense1 = Linear(50, 20)
  22. def forward(self, x):
  23. x = self.dense0(x)
  24. x = F.relu(x)
  25. x = self.dense1(x)
  26. return x
  27. def has_gpu(num=1):
  28. try:
  29. mgb.comp_node("gpu{}".format(num - 1))
  30. except mgb.MegBrainError:
  31. return False
  32. return True
  33. def randomNp(*args):
  34. for arg in args:
  35. assert isinstance(arg, int)
  36. return np.random.random(args)
  37. def randomTorch(*args):
  38. import torch # pylint: disable=import-outside-toplevel
  39. for arg in args:
  40. assert isinstance(arg, int)
  41. return torch.tensor(randomNp(*args), dtype=torch.float32)
  42. def graph_mode(*modes):
  43. if not set(modes).issubset({"eager", "static"}):
  44. raise ValueError("graph mode must be in (eager, static)")
  45. def decorator(func):
  46. def wrapper(*args, **kwargs):
  47. if "eager" in set(modes):
  48. func(*args, **kwargs)
  49. if "static" in set(modes):
  50. with Graph() as cg:
  51. cg.set_option("eager_evaluation", False)
  52. func(*args, **kwargs)
  53. return wrapper
  54. return decorator
  55. def _default_compare_fn(x, y):
  56. assertTensorClose(x.numpy(), y)
  57. def opr_test(
  58. cases,
  59. func,
  60. mode=("eager", "static", "dynamic_shape"),
  61. compare_fn=_default_compare_fn,
  62. ref_fn=None,
  63. **kwargs
  64. ):
  65. """
  66. mode: the list of test mode which are eager, static and dynamic_shape
  67. will test all the cases if None.
  68. func: the function to run opr.
  69. compare_fn: the function to compare the result and expected, use assertTensorClose if None.
  70. ref_fn: the function to generate expected data, should assign output if None.
  71. cases: the list which have dict element, the list length should be 2 for dynamic shape test.
  72. and the dict should have input,
  73. and should have output if ref_fn is None.
  74. should use list for multiple inputs and outputs for each case.
  75. kwargs: The additional kwargs for opr func.
  76. simple examples:
  77. dtype = np.float32
  78. cases = [{"input": [10, 20]}, {"input": [20, 30]}]
  79. opr_test(cases,
  80. F.eye,
  81. ref_fn=lambda n, m: np.eye(n, m).astype(dtype),
  82. dtype=dtype)
  83. """
  84. def check_results(results, expected):
  85. if not isinstance(results, Tuple):
  86. results = (results,)
  87. for r, e in zip(results, expected):
  88. compare_fn(r, e)
  89. def get_trace_fn(func, enabled, symbolic):
  90. jit.trace.enabled = enabled
  91. return jit.trace(func, symbolic=symbolic)
  92. def get_param(cases, idx):
  93. case = cases[idx]
  94. inp = case.get("input", None)
  95. outp = case.get("output", None)
  96. if inp is None:
  97. raise ValueError("the test case should have input")
  98. if not isinstance(inp, List):
  99. inp = (inp,)
  100. else:
  101. inp = tuple(inp)
  102. if ref_fn is not None and callable(ref_fn):
  103. outp = ref_fn(*inp)
  104. if outp is None:
  105. raise ValueError("the test case should have output or reference function")
  106. if not isinstance(outp, List):
  107. outp = (outp,)
  108. else:
  109. outp = tuple(outp)
  110. return inp, outp
  111. if not set(mode).issubset({"eager", "static", "dynamic_shape"}):
  112. raise ValueError("opr test mode must be in (eager, static, dynamic_shape)")
  113. if len(cases) == 0:
  114. raise ValueError("should give one case at least")
  115. if "dynamic_shape" in set(mode):
  116. if len(cases) != 2:
  117. raise ValueError("should give 2 cases for dynamic shape test")
  118. if not callable(func):
  119. raise ValueError("the input func should be callable")
  120. inp, outp = get_param(cases, 0)
  121. def run(*args, **kwargs):
  122. return func(*args, **kwargs)
  123. if "eager" in set(mode):
  124. f = get_trace_fn(run, False, False)
  125. results = f(*inp, **kwargs)
  126. check_results(results, outp)
  127. if "static" in set(mode) or "dynamic_shape" in set(mode):
  128. f = get_trace_fn(run, True, True)
  129. results = f(*inp, **kwargs)
  130. check_results(results, outp)
  131. if "dynamic_shape" in set(mode):
  132. inp, outp = get_param(cases, 1)
  133. results = f(*inp, **kwargs)
  134. check_results(results, outp)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。