You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

_wrap.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import numpy as np
  10. from ._imperative_rt import CompNode
  11. from ._imperative_rt.core2 import set_py_device_type
  12. class Device:
  13. def __init__(self, device=None):
  14. if device is None:
  15. self._cn = CompNode()
  16. elif isinstance(device, Device):
  17. self._cn = device._cn
  18. elif isinstance(device, CompNode):
  19. self._cn = device
  20. else:
  21. self._cn = CompNode(device)
  22. self._logical_name = None
  23. @property
  24. def logical_name(self):
  25. if self._logical_name:
  26. return self._logical_name
  27. self._logical_name = self._cn.logical_name
  28. return self._logical_name
  29. def to_c(self):
  30. return self._cn
  31. def __repr__(self):
  32. return "{}({})".format(type(self).__qualname__, repr(self._cn))
  33. def __str__(self):
  34. return str(self._cn)
  35. def __hash__(self):
  36. return hash(str(self._cn))
  37. def __eq__(self, rhs):
  38. if not isinstance(rhs, Device):
  39. rhs = Device(rhs)
  40. return self._cn == rhs._cn
  41. def as_device(obj):
  42. if isinstance(obj, Device):
  43. return obj
  44. return Device(obj)
# Module-level side effect: pass the Python ``Device`` class to the native
# ``core2`` extension at import time. Presumably this lets native code create
# or recognize Python-side Device objects — confirm in core2's bindings.
set_py_device_type(Device)