You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

openblassettings.py 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Wed Dec 19 15:31:01 2018
  5. A script to set the number of threads used by OpenBLAS (if it is in use).
  6. Some modules (such as NumPy, SciPy, sklearn) that use OpenBLAS perform parallel
  7. computation automatically, which conflicts with other parallelization modules
  8. such as multiprocessing.Pool and greatly increases the computing time. By setting
  9. the thread number to 1, OpenBLAS is forced to use a single thread/CPU, so this
  10. conflict can be avoided.
  11. -e.g:
  12. with num_threads(8):
  13. np.dot(x, y)
  14. @author: ali_m
  15. @Reference: ali_m, https://stackoverflow.com/a/29582987, 2018.12
  16. """
  17. import contextlib
  18. import ctypes
  19. from ctypes.util import find_library
  20. import os
  21. # Prioritize hand-compiled OpenBLAS library over version in /usr/lib/
  22. # from Ubuntu repos
  23. try_paths = ['/opt/OpenBLAS/lib/libopenblas.so',
  24. '/lib/libopenblas.so',
  25. '/usr/lib/libopenblas.so.0',
  26. find_library('openblas')]
  27. openblas_lib = None
  28. for libpath in try_paths:
  29. try:
  30. openblas_lib = ctypes.cdll.LoadLibrary(libpath)
  31. break
  32. except OSError:
  33. continue
  34. if openblas_lib is None:
  35. raise EnvironmentError('Could not locate an OpenBLAS shared library', 2)
  36. def set_num_threads(n):
  37. """Set the current number of threads used by the OpenBLAS server."""
  38. openblas_lib.openblas_set_num_threads(int(n))
  39. # At the time of writing these symbols were very new:
  40. # https://github.com/xianyi/OpenBLAS/commit/65a847c
  41. try:
  42. openblas_lib.openblas_get_num_threads()
  43. def get_num_threads():
  44. """Get the current number of threads used by the OpenBLAS server."""
  45. return openblas_lib.openblas_get_num_threads()
  46. except AttributeError:
  47. def get_num_threads():
  48. """Dummy function (symbol not present in %s), returns -1."""
  49. return -1
  50. pass
  51. try:
  52. len(os.sched_getaffinity(0))
  53. def get_num_procs():
  54. """Get the total number of physical processors"""
  55. return len(os.sched_getaffinity(0))
  56. except AttributeError:
  57. def get_num_procs():
  58. """Dummy function (symbol not present), returns -1."""
  59. return -1
  60. pass
  61. @contextlib.contextmanager
  62. def num_threads(n):
  63. """Temporarily changes the number of OpenBLAS threads.
  64. Example usage:
  65. print("Before: {}".format(get_num_threads()))
  66. with num_threads(n):
  67. print("In thread context: {}".format(get_num_threads()))
  68. print("After: {}".format(get_num_threads()))
  69. """
  70. old_n = get_num_threads()
  71. set_num_threads(n)
  72. try:
  73. yield
  74. finally:
  75. set_num_threads(old_n)

A Python package for graph kernels, graph edit distances and the graph pre-image problem.