|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- '''data tools'''
- import glob
- import os
- import subprocess
- import contextlib
- import shutil
- import tempfile
- import time
-
- from typing import Any, Mapping, Optional, Sequence
-
- from absl import logging
-
- _HHBLITS_DEFAULT_P = 20  # Default for HHBlits' `p`; HHBlits.query adds '-p' only when `p` differs from this.
- _HHBLITS_DEFAULT_Z = 500  # Default for HHBlits' `z`; HHBlits.query adds '-Z' only when `z` differs from this.
-
-
- def _to_a3m(sequences: Sequence[str]) -> str:
- """Converts sequences to an a3m file."""
- names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
- a3m = []
- for sequence, name in zip(sequences, names):
- a3m.append(u'>' + name + u'\n')
- a3m.append(sequence + u'\n')
- return ''.join(a3m)
-
-
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have to
                be at least 6 residues long (Kalign requires this). Note that
                the order in which you give the sequences might alter the
                output slightly as different alignment tree might get
                constructed.

        Returns:
            A string with the alignment in a3m format.

        Raises:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info('Aligning %d sequences', len(sequences))

        # Validate up front so that no temporary files or subprocesses are
        # created for input Kalign would reject anyway.
        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    'Kalign requires all sequences to be at least 6 '
                    'residues long. Got %s (%d residues).' % (s, len(s)))

        with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
            output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

            with open(input_fasta_path, 'w') as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                '-i', input_fasta_path,
                '-o', output_a3m_path,
                '-format', 'fasta',
            ]

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with timing('Kalign query'):
                stdout, stderr = process.communicate()
                # communicate() has already waited for the process to exit.
                retcode = process.returncode
                # Decode leniently: the tool output is only used for logging
                # and error reporting, so an undecodable byte must not turn a
                # finished run into a UnicodeDecodeError.
                stdout_text = stdout.decode('utf-8', errors='replace')
                stderr_text = stderr.decode('utf-8', errors='replace')
                logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
                             stdout_text, stderr_text)

            if retcode:
                raise RuntimeError(
                    'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
                    % (stdout_text, stderr_text))

            # Read the result before the temporary directory is removed.
            with open(output_a3m_path) as f:
                a3m = f.read()

        return a3m
-
-
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Yields a newly created temporary directory and removes it afterwards.

    The directory is created with tempfile.mkdtemp(dir=base_dir). It is
    deleted recursively, ignoring removal errors, when the with-block ends,
    whether it ends normally or through an exception.
    """
    scratch_dir = tempfile.mkdtemp(dir=base_dir)
    try:
        yield scratch_dir
    finally:
        shutil.rmtree(scratch_dir, ignore_errors=True)
-
-
@contextlib.contextmanager
def timing(msg: str):
    """Logs when a with-block starts and how long it took to complete.

    Args:
        msg: Label included in both log messages.

    The 'Finished' message is only logged when the body completes normally;
    if the body raises, the exception propagates and no duration is logged.
    """
    logging.info('Started %s', msg)
    # perf_counter() is monotonic, so the measured duration cannot be
    # distorted by system clock adjustments while the block runs.
    started_at = time.perf_counter()
    yield
    elapsed = time.perf_counter() - started_at
    logging.info('Finished %s in %.3f seconds', msg, elapsed)
-
-
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 databases: Sequence[str],
                 n_cpu: int = 4,
                 n_iter: int = 3,
                 e_value: float = 0.001,
                 maxseq: int = 1_000_000,
                 realign_max: int = 100_000,
                 maxfilt: int = 100_000,
                 min_prefilter_hits: int = 1000,
                 all_seqs: bool = False,
                 alt: Optional[int] = None,
                 p: int = _HHBLITS_DEFAULT_P,
                 z: int = _HHBLITS_DEFAULT_Z):
        """Initializes the Python HHblits wrapper.

        Args:
            binary_path: The path to the HHblits executable.
            databases: A sequence of HHblits database paths. This should be
                the common prefix for the database files (i.e. up to but not
                including _hhm.ffindex etc.)
            n_cpu: The number of CPUs to give HHblits.
            n_iter: The number of HHblits iterations.
            e_value: The E-value, see HHblits docs for more details.
            maxseq: The maximum number of rows in an input alignment. Note
                that this parameter is only supported in HHBlits version 3.1
                and higher.
            realign_max: Max number of HMM-HMM hits to realign.
                HHblits default: 500.
            maxfilt: Max number of hits allowed to pass the 2nd prefilter.
                HHblits default: 20000.
            min_prefilter_hits: Min number of hits to pass prefilter.
                HHblits default: 100.
            all_seqs: Return all sequences in the MSA / Do not filter the
                result MSA. HHblits default: False.
            alt: Show up to this many alternative alignments.
            p: Minimum Prob for a hit to be included in the output hhr file.
                HHblits default: 20.
            z: Hard cap on number of hits reported in the hhr file.
                HHblits default: 500. NB: The relevant HHblits flag is -Z
                not -z.

        Raises:
            ValueError: If, for any entry of databases, no file matching
                the database path followed by '_*' exists.
        """
        self.binary_path = binary_path
        self.databases = databases

        for database_path in self.databases:
            # A database is a family of files sharing a common prefix, so
            # probe for any file that starts with that prefix plus '_'.
            if not glob.glob(database_path + '_*'):
                logging.error('Could not find HHBlits database %s',
                              database_path)
                raise ValueError(
                    f'Could not find HHBlits database {database_path}')

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using HHblits.

        Args:
            input_fasta_path: Path of the query file handed to HHblits
                through its '-i' flag.

        Returns:
            A mapping with the keys 'a3m' (text of the alignment file written
            by HHblits), 'output' (raw stdout bytes), 'stderr' (raw stderr
            bytes), 'n_iter' and 'e_value'.

        Raises:
            RuntimeError: If HHblits exits with a non-zero return code.
        """
        with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append('-d')
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                '-i', input_fasta_path,
                '-cpu', str(self.n_cpu),
                '-oa3m', a3m_path,
                '-o', '/dev/null',
                '-n', str(self.n_iter),
                '-e', str(self.e_value),
                '-maxseq', str(self.maxseq),
                '-realign_max', str(self.realign_max),
                '-maxfilt', str(self.maxfilt),
                '-min_prefilter_hits', str(self.min_prefilter_hits)]
            if self.all_seqs:
                cmd += ['-all']
            if self.alt:
                cmd += ['-alt', str(self.alt)]
            # Only pass -p / -Z when they differ from the module defaults.
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ['-p', str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ['-Z', str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

            with timing('HHblits query'):
                stdout, stderr = process.communicate()
                # communicate() has already waited for the process to exit.
                retcode = process.returncode

            if retcode:
                # Decode leniently: a strict decode could raise
                # UnicodeDecodeError (on invalid bytes, or on a multi-byte
                # character cut in half by the truncation below) and hide
                # the RuntimeError that callers expect.
                stderr_text = stderr.decode('utf-8', errors='replace')
                # Logs have a 15k character limit, so log HHblits error line
                # by line.
                logging.error('HHblits failed. HHblits stderr begin:')
                for error_line in stderr_text.splitlines():
                    if error_line.strip():
                        logging.error('%s', error_line.strip())
                logging.error('HHblits stderr end')
                raise RuntimeError(
                    'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8', errors='replace'),
                        stderr[:500_000].decode('utf-8', errors='replace')))

            # Read the result before the temporary directory is removed.
            with open(a3m_path) as f:
                a3m = f.read()

        raw_output = dict(
            a3m=a3m,
            output=stdout,
            stderr=stderr,
            n_iter=self.n_iter,
            e_value=self.e_value)
        return raw_output
-
-
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 databases: Sequence[str],
                 maxseq: int = 1_000_000,
                 n_cpu: int = 8):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database paths. This should be
                the common prefix for the database files (i.e. up to but not
                including _hhm.ffindex etc.)
            maxseq: The maximum number of rows in an input alignment. Note
                that this parameter is only supported in HHBlits version 3.1
                and higher.
            n_cpu: The number of CPUs to give HHsearch (passed as '-cpu').

        Raises:
            ValueError: If, for any entry of databases, no file matching
                the database path followed by '_*' exists.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.maxseq = maxseq
        self.n_cpu = n_cpu

        for database_path in self.databases:
            # A database is a family of files sharing a common prefix, so
            # probe for any file that starts with that prefix plus '_'.
            if not glob.glob(database_path + '_*'):
                logging.error(
                    'Could not find HHsearch database %s',
                    database_path)
                raise ValueError(
                    f'Could not find HHsearch database {database_path}')

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m.

        Args:
            a3m: Query alignment text; it is written to a temporary file that
                is handed to HHsearch through its '-i' flag.

        Returns:
            The text of the hhr file written by HHsearch.

        Raises:
            RuntimeError: If HHsearch exits with a non-zero return code.
        """
        with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, 'query.a3m')
            hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
            with open(input_path, 'w') as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append('-d')
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                '-i', input_path,
                '-o', hhr_path,
                '-maxseq', str(self.maxseq),
                '-cpu', str(self.n_cpu),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with timing('HHsearch query'):
                stdout, stderr = process.communicate()
                # communicate() has already waited for the process to exit.
                retcode = process.returncode

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                # Decode leniently: the byte-level cut may split a multi-byte
                # character, and a strict decode would then raise
                # UnicodeDecodeError instead of the RuntimeError below.
                raise RuntimeError(
                    'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
                        stdout.decode('utf-8', errors='replace'),
                        stderr[:100_000].decode('utf-8', errors='replace')))

            # Read the result before the temporary directory is removed.
            with open(hhr_path) as f:
                hhr = f.read()
        return hhr
-
-
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(self,
                 *,
                 binary_path: str,
                 database_path: str,
                 n_cpu: int = 8,
                 n_iter: int = 1,
                 e_value: float = 0.0001,
                 z_value: Optional[int] = None,
                 get_tblout: bool = False,
                 filter_f1: float = 0.0005,
                 filter_f2: float = 0.00005,
                 filter_f3: float = 0.0000005,
                 incdom_e: Optional[float] = None,
                 dom_e: Optional[float] = None):
        """Initializes the Python Jackhmmer wrapper.

        Args:
            binary_path: The path to the jackhmmer executable.
            database_path: The path to the jackhmmer database (FASTA format).
            n_cpu: The number of CPUs to give Jackhmmer.
            n_iter: The number of Jackhmmer iterations.
            e_value: The E-value, see Jackhmmer docs for more details.
            z_value: The Z-value, see Jackhmmer docs for more details.
            get_tblout: Whether to save tblout string.
            filter_f1: MSV and biased composition pre-filter, set to >1.0 to
                turn off.
            filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
            filter_f3: Forward pre-filter, set to >1.0 to turn off.
            incdom_e: Domain e-value criteria for inclusion of domains in
                MSA/next round.
            dom_e: Domain e-value criteria for inclusion in tblout.

        Raises:
            ValueError: If database_path does not exist.
        """
        self.binary_path = binary_path
        self.database_path = database_path

        if not os.path.exists(self.database_path):
            logging.error(
                'Could not find Jackhmmer database %s',
                database_path)
            raise ValueError(
                f'Could not find Jackhmmer database {database_path}')

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout

    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using Jackhmmer.

        Args:
            input_fasta_path: Path of the query file, passed to Jackhmmer as
                the first positional argument.

        Returns:
            A mapping with the keys 'sto' (text of the alignment file written
            through '-A'), 'tbl' (text of the '--tblout' file, or '' when
            get_tblout is false), 'stderr' (raw stderr bytes), 'n_iter' and
            'e_value'.

        Raises:
            RuntimeError: If Jackhmmer exits with a non-zero return code.
        """
        with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, 'output.sto')

            # The F1/F2/F3 are the expected proportion to pass each of the
            # filtering stages (which get progressively more expensive),
            # reducing these speeds up the pipeline at the expense of
            # sensitivity. They are currently set very low to make querying
            # Mgnify run in a reasonable amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                '-o', '/dev/null',
                '-A', sto_path,
                '--noali',
                '--F1', str(self.filter_f1),
                '--F2', str(self.filter_f2),
                '--F3', str(self.filter_f3),
                '--incE', str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence
                # output.
                '-E', str(self.e_value),
                '--cpu', str(self.n_cpu),
                '-N', str(self.n_iter)
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
                cmd_flags.extend(['--tblout', tblout_path])

            if self.z_value:
                cmd_flags.extend(['-Z', str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(['--domE', str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(['--incdomE', str(self.incdom_e)])

            cmd = ([self.binary_path] + cmd_flags
                   + [input_fasta_path, self.database_path])

            logging.info('Launching subprocess "%s"', ' '.join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            with timing(
                    f'Jackhmmer ({os.path.basename(self.database_path)}) query'):
                _, stderr = process.communicate()
                # communicate() has already waited for the process to exit.
                retcode = process.returncode

            if retcode:
                # Decode leniently so that undecodable bytes in stderr cannot
                # replace the RuntimeError with a UnicodeDecodeError.
                raise RuntimeError(
                    'Jackhmmer failed\nstderr:\n%s\n'
                    % stderr.decode('utf-8', errors='replace'))

            # Get e-values for each target name
            tbl = ''
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            # Read the result before the temporary directory is removed.
            with open(sto_path) as f:
                sto = f.read()

        raw_output = dict(
            sto=sto,
            tbl=tbl,
            stderr=stderr,
            n_iter=self.n_iter,
            e_value=self.e_value)
        return raw_output
|