|
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- Created on Wed Dec 19 15:31:01 2018
- A script to set the thread number of OpenBLAS (if used).
- Some modules (such as NumPy, SciPy, sklearn) that use OpenBLAS run their
- computations in parallel automatically. This conflicts with other
- parallelization modules such as multiprocessing.Pool and can greatly increase
- the computing time. By setting the thread number to 1, OpenBLAS is forced to
- use a single thread/CPU, so this conflict can be avoided.
- e.g.:
-     with num_threads(8):
-         np.dot(x, y)
- @author: ali_m
- @Reference: ali_m, https://stackoverflow.com/a/29582987, 2018.12
- """
-
- import contextlib
- import ctypes
- from ctypes.util import find_library
- import os
-
# Candidate locations, tried in order. A hand-compiled OpenBLAS library is
# prioritized over the version in /usr/lib/ from the Ubuntu repos; whatever
# find_library() can locate is the last resort.
try_paths = ['/opt/OpenBLAS/lib/libopenblas.so',
             '/lib/libopenblas.so',
             '/usr/lib/libopenblas.so.0',
             find_library('openblas')]
openblas_lib = None
for libpath in try_paths:
    if libpath is None:
        # find_library() returns None when it finds nothing. Passing None to
        # LoadLibrary() does not raise OSError on POSIX: it yields a handle to
        # the running process instead, which would be mistaken for OpenBLAS.
        continue
    try:
        openblas_lib = ctypes.cdll.LoadLibrary(libpath)
        break
    except OSError:
        # This candidate is missing or not loadable; try the next one.
        continue
if openblas_lib is None:
    # OSError expects (errno, strerror); 2 is ENOENT ("no such file").
    raise EnvironmentError(2, 'Could not locate an OpenBLAS shared library')
-
-
def set_num_threads(n):
    """Tell the loaded OpenBLAS library how many threads to use.

    ``n`` is converted with ``int()`` and then passed to the library's
    ``openblas_set_num_threads`` function.
    """
    thread_count = int(n)
    openblas_lib.openblas_set_num_threads(thread_count)
-
-
# openblas_get_num_threads was a very new symbol when this was written
# (https://github.com/xianyi/OpenBLAS/commit/65a847c), so probe for it once
# at import time and pick the matching implementation.
try:
    openblas_lib.openblas_get_num_threads()
except AttributeError:
    def get_num_threads():
        """Placeholder used when the loaded library lacks
        ``openblas_get_num_threads``; always returns -1."""
        return -1
else:
    def get_num_threads():
        """Return the thread count reported by the library's
        ``openblas_get_num_threads`` function."""
        return openblas_lib.openblas_get_num_threads()
-
def get_num_procs():
    """Return the number of CPUs available to the current process.

    The count is the size of the affinity mask returned by
    ``os.sched_getaffinity(0)``, i.e. the CPUs this process is allowed to
    run on (not the number of physical processors).

    Where ``os`` does not provide ``sched_getaffinity``, the value of
    ``os.cpu_count()`` is returned instead. -1 is returned only when that
    count cannot be determined either.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # sched_getaffinity is not offered by ``os`` on every platform.
        fallback = os.cpu_count()
        return fallback if fallback is not None else -1
-
-
@contextlib.contextmanager
def num_threads(n):
    """Run the body of a ``with`` statement with OpenBLAS set to ``n`` threads.

    On entry the value reported by ``get_num_threads()`` is remembered and
    ``set_num_threads(n)`` is called. On exit, including when the body
    raises, the remembered value is passed back to ``set_num_threads()``.

    Example usage:

        print("Before: {}".format(get_num_threads()))
        with num_threads(n):
            print("In thread context: {}".format(get_num_threads()))
        print("After: {}".format(get_num_threads()))
    """
    previous = get_num_threads()
    set_num_threads(n)
    try:
        yield
    finally:
        # Hand back whatever get_num_threads() reported before the switch.
        set_num_threads(previous)
|