You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

workflow.py 2.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. Created on Fri Nov 27 19:33:51 2020
  5. @author: ljia
  6. """
  7. import os
  8. import numpy as np
  9. import pickle
  10. from gklearn.dataset import Dataset
  11. from gklearn.model_learning import NestedCV
  12. from gklearn.kernels import GRAPH_KERNELS
  13. class Workflow(object):
  14. def __init__(self, **kwargs):
  15. self._job_prefix = kwargs.get('job_prefix', 'gktask')
  16. self._max_num_running_tasks = kwargs.get('max_num_running_tasks', np.inf)
  17. self._root_dir = kwargs.get('root_dir', 'outputs/')
  18. def run(self, tasks):
  19. ### Check inputs.
  20. if self._check_inputs(tasks):
  21. self._tasks = tasks
  22. else:
  23. raise ValueError('The input "tasks" is not correct.')
  24. ### Sort tasks.
  25. self.sort_tasks_by_complexity()
  26. ### The main process.
  27. complete = False
  28. while not complete:
  29. self.get_running_tasks()
  30. if self._num_running_tasks < self._max_num_running_tasks:
  31. ### Load results from table.
  32. self.load_results_from_table()
  33. for task in self._tasks:
  34. state = self.get_task_state(task)
  35. if state != 'complete' and state != 'runnning':
  36. self.run_task(task)
  37. if self._num_running_tasks >= self._max_num_running_tasks:
  38. break
  39. ### Save results.
  40. self.save_results()
  41. complete = self.check_completeness()
  42. # sleep()
  43. def _check_inputs(self, tasks):
  44. if not isinstance(tasks, list):
  45. return False
  46. else:
  47. for i in tasks:
  48. if not 'kernel' in i or not 'dataset' in i:
  49. return False
  50. return True
  51. def sort_tasks_by_complexity(self):
  52. return
  53. def get_running_tasks(self):
  54. command = 'squeue --user $USER --format "%.50j" --noheader'
  55. stream = os.popen(command)
  56. output = stream.readlines()
  57. running_tasks = [o for o in output if o.strip().startswith(self._job_prefix)]
  58. self._num_running_tasks = len(running_tasks)
  59. def load_results_from_table(self):
  60. pass
  61. def get_task_state(self, task):
  62. task_dir = os.path.join(self._root_dir, task['kernel'] + '.' + task['dataset'] + '/')
  63. fn_summary = os.path.join(task_dir, 'results_summary.pkl')
  64. if os.path.isfile(fn_summary):
  65. output = pickle.loads(fn_summary)
  66. state = output['state']
  67. return state
  68. else:
  69. return 'unstarted'
  70. def run_task(self, task):
  71. ds_name = task['dataset']
  72. k_name = task['kernel']
  73. # Get dataset.
  74. ds = Dataset(ds_name)
  75. graph_kernel = GRAPH_KERNELS[k_name]
  76. # Start CV.
  77. results = NestedCV(ds, graph_kernel)

A Python package for graph kernels, graph edit distances, and the graph pre-image problem.