You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_tools.py 17 kB

2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''data tools'''
  16. import glob
  17. import os
  18. import subprocess
  19. import contextlib
  20. import shutil
  21. import tempfile
  22. import time
  23. from typing import Any, Mapping, Optional, Sequence
  24. from absl import logging
# Default for the HHBlits `p` argument (minimum hit probability). HHBlits.query
# only adds the `-p` flag when the configured value differs from this.
_HHBLITS_DEFAULT_P = 20
# Default for the HHBlits `z` argument (cap on reported hits). HHBlits.query
# only adds the `-Z` flag when the configured value differs from this.
_HHBLITS_DEFAULT_Z = 500
  27. def _to_a3m(sequences: Sequence[str]) -> str:
  28. """Converts sequences to an a3m file."""
  29. names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
  30. a3m = []
  31. for sequence, name in zip(sequences, names):
  32. a3m.append(u'>' + name + u'\n')
  33. a3m.append(sequence + u'\n')
  34. return ''.join(a3m)
  35. class Kalign:
  36. """Python wrapper of the Kalign binary."""
  37. def __init__(self, *, binary_path: str):
  38. """Initializes the Python Kalign wrapper.
  39. Args:
  40. binary_path: The path to the Kalign binary.
  41. """
  42. self.binary_path = binary_path
  43. def align(self, sequences: Sequence[str]) -> str:
  44. """Aligns the sequences and returns the alignment in A3M string.
  45. Args:
  46. sequences: A list of query sequence strings. The sequences have to be at
  47. least 6 residues long (Kalign requires this). Note that the order in
  48. which you give the sequences might alter the output slightly as
  49. different alignment tree might get constructed.
  50. Returns:
  51. A string with the alignment in a3m format.
  52. Raises:
  53. RuntimeError: If Kalign fails.
  54. ValueError: If any of the sequences is less than 6 residues long.
  55. """
  56. logging.info('Aligning %d sequences', len(sequences))
  57. for s in sequences:
  58. if len(s) < 6:
  59. raise ValueError('Kalign requires all sequences to be at least 6 '
  60. 'residues long. Got %s (%d residues).' % (s, len(s)))
  61. with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
  62. input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
  63. output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
  64. with open(input_fasta_path, 'w') as f:
  65. f.write(_to_a3m(sequences))
  66. cmd = [self.binary_path, '-i', input_fasta_path, '-o', output_a3m_path, '-format', 'fasta',]
  67. logging.info('Launching subprocess "%s"', ' '.join(cmd))
  68. process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  69. with timing('Kalign query'):
  70. stdout, stderr = process.communicate()
  71. retcode = process.wait()
  72. logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n', stdout.decode('utf-8'), stderr.decode('utf-8'))
  73. if retcode:
  74. raise RuntimeError(
  75. 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % (stdout.decode('utf-8'), stderr.decode('utf-8')))
  76. with open(output_a3m_path) as f:
  77. a3m = f.read()
  78. return a3m
  79. @contextlib.contextmanager
  80. def tmpdir_manager(base_dir: Optional[str] = None):
  81. """Context manager that deletes a temporary directory on exit."""
  82. tmpdir = tempfile.mkdtemp(dir=base_dir)
  83. try:
  84. yield tmpdir
  85. finally:
  86. shutil.rmtree(tmpdir, ignore_errors=True)
  87. @contextlib.contextmanager
  88. def timing(msg: str):
  89. logging.info('Started %s', msg)
  90. tic = time.time()
  91. yield
  92. toc = time.time()
  93. logging.info('Finished %s in %.3f seconds', msg, toc - tic)
  94. class HHBlits:
  95. """Python wrapper of the HHblits binary."""
  96. def __init__(self,
  97. *,
  98. binary_path: str,
  99. databases: Sequence[str],
  100. n_cpu: int = 4,
  101. n_iter: int = 3,
  102. e_value: float = 0.001,
  103. maxseq: int = 1_000_000,
  104. realign_max: int = 100_000,
  105. maxfilt: int = 100_000,
  106. min_prefilter_hits: int = 1000,
  107. all_seqs: bool = False,
  108. alt: Optional[int] = None,
  109. p: int = _HHBLITS_DEFAULT_P,
  110. z: int = _HHBLITS_DEFAULT_Z):
  111. """Initializes the Python HHblits wrapper.
  112. Args:
  113. binary_path: The path to the HHblits executable.
  114. databases: A sequence of HHblits database paths. This should be the
  115. common prefix for the database files (i.e. up to but not including
  116. _hhm.ffindex etc.)
  117. n_cpu: The number of CPUs to give HHblits.
  118. n_iter: The number of HHblits iterations.
  119. e_value: The E-value, see HHblits docs for more details.
  120. maxseq: The maximum number of rows in an input alignment. Note that this
  121. parameter is only supported in HHBlits version 3.1 and higher.
  122. realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
  123. maxfilt: Max number of hits allowed to pass the 2nd prefilter.
  124. HHblits default: 20000.
  125. min_prefilter_hits: Min number of hits to pass prefilter.
  126. HHblits default: 100.
  127. all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
  128. HHblits default: False.
  129. alt: Show up to this many alternative alignments.
  130. p: Minimum Prob for a hit to be included in the output hhr file.
  131. HHblits default: 20.
  132. z: Hard cap on number of hits reported in the hhr file.
  133. HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.
  134. Raises:
  135. RuntimeError: If HHblits binary not found within the path.
  136. """
  137. self.binary_path = binary_path
  138. self.databases = databases
  139. for database_path in self.databases:
  140. if not glob.glob(database_path + '_*'):
  141. logging.error('Could not find HHBlits database %s', database_path)
  142. raise ValueError(f'Could not find HHBlits database {database_path}')
  143. self.n_cpu = n_cpu
  144. self.n_iter = n_iter
  145. self.e_value = e_value
  146. self.maxseq = maxseq
  147. self.realign_max = realign_max
  148. self.maxfilt = maxfilt
  149. self.min_prefilter_hits = min_prefilter_hits
  150. self.all_seqs = all_seqs
  151. self.alt = alt
  152. self.p = p
  153. self.z = z
  154. def query(self, input_fasta_path: str) -> Mapping[str, Any]:
  155. """Queries the database using HHblits."""
  156. with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
  157. a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
  158. db_cmd = []
  159. for db_path in self.databases:
  160. db_cmd.append('-d')
  161. db_cmd.append(db_path)
  162. cmd = [
  163. self.binary_path,
  164. '-i', input_fasta_path,
  165. '-cpu', str(self.n_cpu),
  166. '-oa3m', a3m_path,
  167. '-o', '/dev/null',
  168. '-n', str(self.n_iter),
  169. '-e', str(self.e_value),
  170. '-maxseq', str(self.maxseq),
  171. '-realign_max', str(self.realign_max),
  172. '-maxfilt', str(self.maxfilt),
  173. '-min_prefilter_hits', str(self.min_prefilter_hits)]
  174. if self.all_seqs:
  175. cmd += ['-all']
  176. if self.alt:
  177. cmd += ['-alt', str(self.alt)]
  178. if self.p != _HHBLITS_DEFAULT_P:
  179. cmd += ['-p', str(self.p)]
  180. if self.z != _HHBLITS_DEFAULT_Z:
  181. cmd += ['-Z', str(self.z)]
  182. cmd += db_cmd
  183. logging.info('Launching subprocess "%s"', ' '.join(cmd))
  184. process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  185. with timing('HHblits query'):
  186. stdout, stderr = process.communicate()
  187. retcode = process.wait()
  188. if retcode:
  189. # Logs have a 15k character limit, so log HHblits error line by
  190. # line.
  191. logging.error('HHblits failed. HHblits stderr begin:')
  192. for error_line in stderr.decode('utf-8').splitlines():
  193. if error_line.strip():
  194. logging.error(error_line.strip())
  195. logging.error('HHblits stderr end')
  196. raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
  197. stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
  198. with open(a3m_path) as f:
  199. a3m = f.read()
  200. raw_output = dict(
  201. a3m=a3m,
  202. output=stdout,
  203. stderr=stderr,
  204. n_iter=self.n_iter,
  205. e_value=self.e_value)
  206. return raw_output
  207. class HHSearch:
  208. """Python wrapper of the HHsearch binary."""
  209. def __init__(self,
  210. *,
  211. binary_path: str,
  212. databases: Sequence[str],
  213. maxseq: int = 1_000_000):
  214. """Initializes the Python HHsearch wrapper.
  215. Args:
  216. binary_path: The path to the HHsearch executable.
  217. databases: A sequence of HHsearch database paths. This should be the
  218. common prefix for the database files (i.e. up to but not including
  219. _hhm.ffindex etc.)
  220. maxseq: The maximum number of rows in an input alignment. Note that this
  221. parameter is only supported in HHBlits version 3.1 and higher.
  222. Raises:
  223. RuntimeError: If HHsearch binary not found within the path.
  224. """
  225. self.binary_path = binary_path
  226. self.databases = databases
  227. self.maxseq = maxseq
  228. for database_path in self.databases:
  229. if not glob.glob(database_path + '_*'):
  230. logging.error(
  231. 'Could not find HHsearch database %s',
  232. database_path)
  233. raise ValueError(
  234. f'Could not find HHsearch database {database_path}')
  235. def query(self, a3m: str) -> str:
  236. """Queries the database using HHsearch using a given a3m."""
  237. with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
  238. input_path = os.path.join(query_tmp_dir, 'query.a3m')
  239. hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
  240. with open(input_path, 'w') as f:
  241. f.write(a3m)
  242. db_cmd = []
  243. for db_path in self.databases:
  244. db_cmd.append('-d')
  245. db_cmd.append(db_path)
  246. cmd = [self.binary_path,
  247. '-i', input_path,
  248. '-o', hhr_path,
  249. '-maxseq', str(self.maxseq),
  250. '-cpu', '8',] + db_cmd
  251. logging.info('Launching subprocess "%s"', ' '.join(cmd))
  252. process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  253. with timing('HHsearch query'):
  254. stdout, stderr = process.communicate()
  255. retcode = process.wait()
  256. if retcode:
  257. # Stderr is truncated to prevent proto size errors in Beam.
  258. raise RuntimeError('HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
  259. stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
  260. with open(hhr_path) as f:
  261. hhr = f.read()
  262. return hhr
  263. class Jackhmmer:
  264. """Python wrapper of the Jackhmmer binary."""
  265. def __init__(self,
  266. *,
  267. binary_path: str,
  268. database_path: str,
  269. n_cpu: int = 8,
  270. n_iter: int = 1,
  271. e_value: float = 0.0001,
  272. z_value: Optional[int] = None,
  273. get_tblout: bool = False,
  274. filter_f1: float = 0.0005,
  275. filter_f2: float = 0.00005,
  276. filter_f3: float = 0.0000005,
  277. incdom_e: Optional[float] = None,
  278. dom_e: Optional[float] = None):
  279. """Initializes the Python Jackhmmer wrapper.
  280. Args:
  281. binary_path: The path to the jackhmmer executable.
  282. database_path: The path to the jackhmmer database (FASTA format).
  283. n_cpu: The number of CPUs to give Jackhmmer.
  284. n_iter: The number of Jackhmmer iterations.
  285. e_value: The E-value, see Jackhmmer docs for more details.
  286. z_value: The Z-value, see Jackhmmer docs for more details.
  287. get_tblout: Whether to save tblout string.
  288. filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
  289. filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
  290. filter_f3: Forward pre-filter, set to >1.0 to turn off.
  291. incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
  292. round.
  293. dom_e: Domain e-value criteria for inclusion in tblout.
  294. """
  295. self.binary_path = binary_path
  296. self.database_path = database_path
  297. if not os.path.exists(self.database_path):
  298. logging.error(
  299. 'Could not find Jackhmmer database %s',
  300. database_path)
  301. raise ValueError(
  302. f'Could not find Jackhmmer database {database_path}')
  303. self.n_cpu = n_cpu
  304. self.n_iter = n_iter
  305. self.e_value = e_value
  306. self.z_value = z_value
  307. self.filter_f1 = filter_f1
  308. self.filter_f2 = filter_f2
  309. self.filter_f3 = filter_f3
  310. self.incdom_e = incdom_e
  311. self.dom_e = dom_e
  312. self.get_tblout = get_tblout
  313. def query(self, input_fasta_path: str) -> Mapping[str, Any]:
  314. """Queries the database using Jackhmmer."""
  315. with tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
  316. sto_path = os.path.join(query_tmp_dir, 'output.sto')
  317. # The F1/F2/F3 are the expected proportion to pass each of the filtering
  318. # stages (which get progressively more expensive), reducing these
  319. # speeds up the pipeline at the expensive of sensitivity. They are
  320. # currently set very low to make querying Mgnify run in a reasonable
  321. # amount of time.
  322. cmd_flags = [
  323. # Don't pollute stdout with Jackhmmer output.
  324. '-o', '/dev/null',
  325. '-A', sto_path,
  326. '--noali',
  327. '--F1', str(self.filter_f1),
  328. '--F2', str(self.filter_f2),
  329. '--F3', str(self.filter_f3),
  330. '--incE', str(self.e_value),
  331. # Report only sequences with E-values <= x in per-sequence
  332. # output.
  333. '-E', str(self.e_value),
  334. '--cpu', str(self.n_cpu),
  335. '-N', str(self.n_iter)
  336. ]
  337. if self.get_tblout:
  338. tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
  339. cmd_flags.extend(['--tblout', tblout_path])
  340. if self.z_value:
  341. cmd_flags.extend(['-Z', str(self.z_value)])
  342. if self.dom_e is not None:
  343. cmd_flags.extend(['--domE', str(self.dom_e)])
  344. if self.incdom_e is not None:
  345. cmd_flags.extend(['--incdomE', str(self.incdom_e)])
  346. cmd = [self.binary_path] + cmd_flags + [input_fasta_path, self.database_path]
  347. logging.info('Launching subprocess "%s"', ' '.join(cmd))
  348. process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  349. with timing(f'Jackhmmer ({os.path.basename(self.database_path)}) query'):
  350. _, stderr = process.communicate()
  351. retcode = process.wait()
  352. if retcode:
  353. raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
  354. # Get e-values for each target name
  355. tbl = ''
  356. if self.get_tblout:
  357. with open(tblout_path) as f:
  358. tbl = f.read()
  359. with open(sto_path) as f:
  360. sto = f.read()
  361. raw_output = dict(sto=sto, tbl=tbl, stderr=stderr, n_iter=self.n_iter, e_value=self.e_value)
  362. return raw_output