You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

plugin.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import struct
  10. import numpy as np
  11. def load_tensor_binary(fobj):
  12. """
  13. Load a tensor dumped by the :class:`BinaryOprIODump` plugin; the actual
  14. tensor value dump is implemented by ``mgb::debug::dump_tensor``.
  15. Multiple values can be compared by ``tools/compare_binary_iodump.py``.
  16. :param fobj: file object, or a string that contains the file name.
  17. :return: tuple ``(tensor_value, tensor_name)``.
  18. """
  19. if isinstance(fobj, str):
  20. with open(fobj, "rb") as fin:
  21. return load_tensor_binary(fin)
  22. DTYPE_LIST = {
  23. 0: np.float32,
  24. 1: np.uint8,
  25. 2: np.int8,
  26. 3: np.int16,
  27. 4: np.int32,
  28. # 5: _mgb.intb1,
  29. # 6: _mgb.intb2,
  30. # 7: _mgb.intb4,
  31. 8: None,
  32. 9: np.float16,
  33. # quantized dtype start from 100000
  34. # see MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE in
  35. # dnn/include/megdnn/dtype.h
  36. 100000: np.uint8,
  37. 100001: np.int32,
  38. 100002: np.int8,
  39. }
  40. header_fmt = struct.Struct("III")
  41. name_len, dtype, max_ndim = header_fmt.unpack(fobj.read(header_fmt.size))
  42. assert (
  43. DTYPE_LIST[dtype] is not None
  44. ), "Cannot load this tensor: dtype Byte is unsupported."
  45. shape = list(struct.unpack("I" * max_ndim, fobj.read(max_ndim * 4)))
  46. while shape[-1] == 0:
  47. shape.pop(-1)
  48. name = fobj.read(name_len).decode("ascii")
  49. return np.fromfile(fobj, dtype=DTYPE_LIST[dtype]).reshape(shape), name

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。