|
@@ -1,86 +0,0 @@ |
|
|
#!/usr/bin/env python3 |
|
|
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
|
|
""" |
|
|
|
|
|
Created on Wed Dec 19 15:31:01 2018 |
|
|
|
|
|
A script to set the thread number of OpenBLAS (if used). |
|
|
|
|
|
Some modules (such as Numpy, Scipy, sklearn) using OpenBLAS perform parallel |
|
|
|
|
|
computation automatically, which causes conflict when other paralleling modules |
|
|
|
|
|
such as multiprossing.Pool, highly increase the computing time. By setting |
|
|
|
|
|
thread to 1, OpenBLAS is forced to use single thread/CPU, thus this conflict |
|
|
|
|
|
can be avoided. |
|
|
|
|
|
-e.g: |
|
|
|
|
|
with num_threads(8): |
|
|
|
|
|
np.dot(x, y) |
|
|
|
|
|
@author: ali_m |
|
|
|
|
|
@Reference: `ali_m's answer <https://stackoverflow.com/a/29582987>`__, 2018.12 |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
import contextlib |
|
|
|
|
|
import ctypes |
|
|
|
|
|
from ctypes.util import find_library |
|
|
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
# Prioritize hand-compiled OpenBLAS library over version in /usr/lib/ |
|
|
|
|
|
# from Ubuntu repos |
|
|
|
|
|
try_paths = ['/opt/OpenBLAS/lib/libopenblas.so', |
|
|
|
|
|
'/lib/libopenblas.so', |
|
|
|
|
|
'/usr/lib/libopenblas.so.0', |
|
|
|
|
|
find_library('openblas')] |
|
|
|
|
|
openblas_lib = None |
|
|
|
|
|
for libpath in try_paths: |
|
|
|
|
|
try: |
|
|
|
|
|
openblas_lib = ctypes.cdll.LoadLibrary(libpath) |
|
|
|
|
|
break |
|
|
|
|
|
except OSError: |
|
|
|
|
|
continue |
|
|
|
|
|
if openblas_lib is None: |
|
|
|
|
|
raise EnvironmentError('Could not locate an OpenBLAS shared library', 2) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_num_threads(n): |
|
|
|
|
|
"""Set the current number of threads used by the OpenBLAS server.""" |
|
|
|
|
|
openblas_lib.openblas_set_num_threads(int(n)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# At the time of writing these symbols were very new: |
|
|
|
|
|
# https://github.com/xianyi/OpenBLAS/commit/65a847c |
|
|
|
|
|
try: |
|
|
|
|
|
openblas_lib.openblas_get_num_threads() |
|
|
|
|
|
def get_num_threads(): |
|
|
|
|
|
"""Get the current number of threads used by the OpenBLAS server.""" |
|
|
|
|
|
return openblas_lib.openblas_get_num_threads() |
|
|
|
|
|
except AttributeError: |
|
|
|
|
|
def get_num_threads(): |
|
|
|
|
|
"""Dummy function (symbol not present in %s), returns -1.""" |
|
|
|
|
|
return -1 |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
len(os.sched_getaffinity(0)) |
|
|
|
|
|
def get_num_procs(): |
|
|
|
|
|
"""Get the total number of physical processors""" |
|
|
|
|
|
return len(os.sched_getaffinity(0)) |
|
|
|
|
|
except AttributeError: |
|
|
|
|
|
def get_num_procs(): |
|
|
|
|
|
"""Dummy function (symbol not present), returns -1.""" |
|
|
|
|
|
return -1 |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
|
|
|
def num_threads(n): |
|
|
|
|
|
"""Temporarily changes the number of OpenBLAS threads. |
|
|
|
|
|
|
|
|
|
|
|
Example usage: |
|
|
|
|
|
|
|
|
|
|
|
print("Before: {}".format(get_num_threads())) |
|
|
|
|
|
with num_threads(n): |
|
|
|
|
|
print("In thread context: {}".format(get_num_threads())) |
|
|
|
|
|
print("After: {}".format(get_num_threads())) |
|
|
|
|
|
""" |
|
|
|
|
|
old_n = get_num_threads() |
|
|
|
|
|
set_num_threads(n) |
|
|
|
|
|
try: |
|
|
|
|
|
yield |
|
|
|
|
|
finally: |
|
|
|
|
|
set_num_threads(old_n) |
|
|
|