dag8_copy_ongoing_seqrun.py

from datetime import timedelta
import os, logging
from airflow.models import DAG, Variable
from airflow.utils.dates import days_ago
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.operators.dummy_operator import DummyOperator
from igf_airflow.seqrun.ongoing_seqrun_processing import fetch_ongoing_seqruns
from igf_airflow.logging.upload_log_msg import send_log_to_channels, log_sleep
from igf_data.utils.fileutils import get_temp_dir, copy_remote_file, check_file_path, read_json_data
## DEFAULT ARGS
default_args = {
  'owner': 'airflow',
  'depends_on_past': False,
  'start_date': days_ago(2),
  'email_on_failure': False,
  'email_on_retry': False,
  'retries': 1,
  'retry_delay': timedelta(minutes=5),
  'provide_context': True,
}
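## NOTE: 'provide_context': True is needed on Airflow 1.10.x (matching the
## airflow.contrib imports above) so the PythonOperator callables below
## receive the task context as **kwargs; Airflow 2.x passes the context
## automatically and drops this argument.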
## SSH HOOKS
orwell_ssh_hook = \
  SSHHook(
    key_file=Variable.get('hpc_ssh_key_file'),
    username=Variable.get('hpc_user'),
    remote_host=Variable.get('seqrun_server'))
## DAG
dag = \
  DAG(
    dag_id='dag8_copy_ongoing_seqrun',
    catchup=False,
    schedule_interval="0 */2 * * *",
    max_active_runs=1,
    tags=['hpc'],
    default_args=default_args,
    orientation='LR')
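## NOTE: the cron expression "0 */2 * * *" starts a run every two hours;
## catchup=False skips missed intervals, and max_active_runs=1 prevents a
## slow copy from overlapping with the next scheduled run.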
## FUNCTIONS
def get_ongoing_seqrun_list(**context):
  """
  A function for fetching ongoing sequencing run ids
  """
  try:
    ti = context.get('ti')
    seqrun_server = Variable.get('seqrun_server')
    seqrun_base_path = Variable.get('seqrun_base_path')
    database_config_file = Variable.get('database_config_file')
    ongoing_seqruns = \
      fetch_ongoing_seqruns(
        seqrun_server=seqrun_server,
        seqrun_base_path=seqrun_base_path,
        database_config_file=database_config_file)
    ti.xcom_push(key='ongoing_seqruns', value=ongoing_seqruns)
    branch_list = \
      ['generate_seqrun_file_list_{0}'.format(i)
         for i, _ in enumerate(ongoing_seqruns)]
    if len(branch_list) == 0:
      branch_list = ['no_ongoing_seqrun']
    else:
      send_log_to_channels(
        slack_conf=Variable.get('slack_conf'),
        ms_teams_conf=Variable.get('ms_teams_conf'),
        task_id=context['task'].task_id,
        dag_id=context['task'].dag_id,
        comment='Ongoing seqruns found: {0}'.format(ongoing_seqruns),
        reaction='pass')
    return branch_list
  except Exception as e:
    logging.error(e)
    send_log_to_channels(
      slack_conf=Variable.get('slack_conf'),
      ms_teams_conf=Variable.get('ms_teams_conf'),
      task_id=context['task'].task_id,
      dag_id=context['task'].dag_id,
      comment=str(e),
      reaction='fail')
    raise  # re-raise so the task is marked failed
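## NOTE: as a branch callable, the function above returns downstream task
## ids; e.g. two ongoing runs yield ['generate_seqrun_file_list_0',
## 'generate_seqrun_file_list_1']. Only the first 5 runs can be handled,
## since the matching tasks below are created with 'for i in range(5)'.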
def copy_seqrun_manifest_file(**context):
  """
  A function for copying filesize manifest for ongoing sequencing runs to hpc
  """
  try:
    ti = context.get('ti')
    seqrun_server = context['params'].get('seqrun_server')
    xcom_pull_task_ids = context['params'].get('xcom_pull_task_ids')
    remote_file_path = ti.xcom_pull(task_ids=xcom_pull_task_ids)
    if remote_file_path is not None and \
       not isinstance(remote_file_path, str):
      remote_file_path = remote_file_path.decode()  # SSHOperator xcom may be bytes
    tmp_work_dir = get_temp_dir(use_ephemeral_space=True)
    local_file_path = \
      os.path.join(
        tmp_work_dir,
        os.path.basename(remote_file_path))
    copy_remote_file(
      remote_file_path,
      local_file_path,
      source_address=seqrun_server)
    return local_file_path
  except Exception as e:
    logging.error(e)
    send_log_to_channels(
      slack_conf=Variable.get('slack_conf'),
      ms_teams_conf=Variable.get('ms_teams_conf'),
      task_id=context['task'].task_id,
      dag_id=context['task'].dag_id,
      comment=str(e),
      reaction='fail')
    raise  # re-raise so the task is marked failed
def get_seqrun_chunks(**context):
  """
  A function for setting file chunk size for seqrun files copy
  """
  try:
    ti = context.get('ti')
    worker_size = context['params'].get('worker_size')
    child_task_prefix = context['params'].get('child_task_prefix')
    seqrun_chunk_size_key = context['params'].get('seqrun_chunk_size_key')
    xcom_pull_task_ids = context['params'].get('xcom_pull_task_ids')
    file_path = ti.xcom_pull(task_ids=xcom_pull_task_ids)
    if file_path is not None and \
       not isinstance(file_path, str):
      file_path = file_path.decode()
    check_file_path(file_path)
    file_data = read_json_data(file_path)
    if worker_size is None or \
       worker_size == 0:
      raise ValueError(
        'Incorrect worker size: {0}'.format(worker_size))
    if len(file_data) == 0:
      raise ValueError(
        'No data present in seqrun list file {0}'.format(file_path))
    if len(file_data) < int(5 * worker_size):
      worker_size = 1  # use a single worker for low input
    if len(file_data) % worker_size == 0:
      chunk_size = int(len(file_data) / worker_size)
    else:
      chunk_size = int(len(file_data) / worker_size) + 1
    ti.xcom_push(key=seqrun_chunk_size_key, value=chunk_size)
    worker_branches = \
      ['{0}_{1}'.format(child_task_prefix, i)
         for i in range(worker_size)]
    return worker_branches
  except Exception as e:
    logging.error(e)
    send_log_to_channels(
      slack_conf=Variable.get('slack_conf'),
      ms_teams_conf=Variable.get('ms_teams_conf'),
      task_id=context['task'].task_id,
      dag_id=context['task'].dag_id,
      comment=str(e),
      reaction='fail')
    raise  # re-raise so the task is marked failed
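## NOTE: worked example for the chunking above: a manifest of 95 files with
## worker_size=10 gives chunk_size = int(95 / 10) + 1 = 10 and branches
## <child_task_prefix>_0 .. <child_task_prefix>_9, while 30 files (< 5 * 10)
## drops worker_size to 1, giving a single branch with chunk_size=30.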
def copy_seqrun_chunk(**context):
  """
  A function for copying seqrun chunks
  """
  try:
    ti = context.get('ti')
    file_path_task_ids = context['params'].get('file_path_task_ids')
    seqrun_chunk_size_key = context['params'].get('seqrun_chunk_size_key')
    seqrun_chunk_size_task_ids = context['params'].get('seqrun_chunk_size_task_ids')
    chunk_index_number = context['params'].get('chunk_index_number')
    run_index_number = context['params'].get('run_index_number')
    local_seqrun_path = context['params'].get('local_seqrun_path')
    seqrun_id_pull_key = context['params'].get('seqrun_id_pull_key')
    seqrun_id_pull_task_ids = context['params'].get('seqrun_id_pull_task_ids')
    seqrun_server = Variable.get('seqrun_server')  # fixed: a trailing comma here made this a tuple
    seqrun_base_path = Variable.get('seqrun_base_path')
    seqrun_id = \
      ti.xcom_pull(
        key=seqrun_id_pull_key,
        task_ids=seqrun_id_pull_task_ids)[run_index_number]
    file_path = ti.xcom_pull(task_ids=file_path_task_ids)
    chunk_size = \
      ti.xcom_pull(
        key=seqrun_chunk_size_key,
        task_ids=seqrun_chunk_size_task_ids)
    check_file_path(file_path)
    file_data = read_json_data(file_path)
    start_index = chunk_index_number * chunk_size
    finish_index = ((chunk_index_number + 1) * chunk_size) - 1  # inclusive end
    if finish_index > len(file_data) - 1:
      finish_index = len(file_data) - 1
    local_seqrun_path = \
      os.path.join(local_seqrun_path, seqrun_id)
    # fixed off-by-one: slice ends are exclusive, so +1 keeps the last entry
    for entry in file_data[start_index:finish_index + 1]:
      rel_file_path = entry.get('file_path')
      file_size = entry.get('file_size')
      remote_path = \
        os.path.join(
          seqrun_base_path,
          rel_file_path)
      local_path = \
        os.path.join(
          local_seqrun_path,
          rel_file_path)
      if os.path.exists(local_path) and \
         os.path.getsize(local_path) == file_size:
        continue  # skip files already copied with the expected size
      copy_remote_file(
        remote_path,
        local_path,
        source_address=seqrun_server,
        check_file=False)
  except Exception as e:
    logging.error(e)
    send_log_to_channels(
      slack_conf=Variable.get('slack_conf'),
      ms_teams_conf=Variable.get('ms_teams_conf'),
      task_id=context['task'].task_id,
      dag_id=context['task'].dag_id,
      comment=str(e),
      reaction='fail')
    raise  # re-raise so the task is marked failed
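## NOTE: worked example for the slice above: with chunk_size=10, the worker
## with chunk_index_number=2 copies manifest entries 20..29 (start_index=20,
## finish_index=29); the final chunk is clipped to the end of file_data.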
with dag:
  ## TASK
  generate_seqrun_list = \
    BranchPythonOperator(
      task_id='generate_seqrun_list',
      dag=dag,
      queue='hpc_4G',
      python_callable=get_ongoing_seqrun_list)
  ## TASK
  no_ongoing_seqrun = \
    DummyOperator(
      task_id='no_ongoing_seqrun',
      dag=dag,
      queue='hpc_4G',
      on_success_callback=log_sleep)
  ## TASK
  tasks = list()
  for i in range(5):
    t1 = \
      SSHOperator(
        task_id='generate_seqrun_file_list_{0}'.format(i),
        dag=dag,
        pool='orwell_exe_pool',
        ssh_hook=orwell_ssh_hook,
        do_xcom_push=True,
        queue='hpc_4G',
        params={'source_task_id': 'generate_seqrun_list',
                'pull_key': 'ongoing_seqruns',
                'index_number': i},
        command="""
          source /home/igf/igf_code/airflow/env.sh; \
          python /home/igf/igf_code/airflow/data-management-python/scripts/seqrun_processing/create_file_list_for_ongoing_seqrun.py \
            --seqrun_base_dir /home/igf/seqrun/illumina \
            --output_path /home/igf/ongoing_run_tracking \
            --seqrun_id {{ ti.xcom_pull(key=params.pull_key,task_ids=params.source_task_id)[ params.index_number ] }}
          """)
    ## TASK
    t2 = \
      PythonOperator(
        task_id='copy_seqrun_file_list_{0}'.format(i),
        dag=dag,
        pool='orwell_scp_pool',
        queue='hpc_4G',
        params={'xcom_pull_task_ids': 'generate_seqrun_file_list_{0}'.format(i),
                'seqrun_server': Variable.get('seqrun_server')},
        python_callable=copy_seqrun_manifest_file)
    ## TASK
    t3 = \
      BranchPythonOperator(
        task_id='decide_copy_branch_{0}'.format(i),
        dag=dag,
        queue='hpc_4G',
        params={'xcom_pull_task_ids': 'copy_seqrun_file_list_{0}'.format(i),
                'worker_size': 10,
                'seqrun_chunk_size_key': 'seqrun_chunk_size',
                # fixed: the prefix had a trailing '_' which produced branch
                # names like 'copy_file_run_0_chunk__0' matching no task id
                'child_task_prefix': 'copy_file_run_{0}_chunk'.format(i)},
        python_callable=get_seqrun_chunks)
    ## TASK
    t4 = list()
    for j in range(10):
      t4j = \
        PythonOperator(
          task_id='copy_file_run_{0}_chunk_{1}'.format(i, j),
          dag=dag,
          queue='hpc_4G',
          pool='orwell_scp_pool',
          params={'file_path_task_ids': 'copy_seqrun_file_list_{0}'.format(i),
                  'seqrun_chunk_size_key': 'seqrun_chunk_size',
                  'seqrun_chunk_size_task_ids': 'decide_copy_branch_{0}'.format(i),
                  'run_index_number': i,
                  'chunk_index_number': j,
                  'seqrun_id_pull_key': 'ongoing_seqruns',
                  'seqrun_id_pull_task_ids': 'generate_seqrun_list',
                  'local_seqrun_path': Variable.get('hpc_seqrun_path')},
          python_callable=copy_seqrun_chunk)
      t4.append(t4j)
    #tasks.append([ t1 >> t2 >> t3 >> t4 ])
    generate_seqrun_list >> t1 >> t2 >> t3 >> t4
  ## PIPELINE
  generate_seqrun_list >> no_ongoing_seqrun
  #generate_seqrun_list >> tasks
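## NOTE: this DAG expects the following Airflow Variables to be defined:
## seqrun_server, seqrun_base_path, database_config_file, hpc_ssh_key_file,
## hpc_user, hpc_seqrun_path, slack_conf and ms_teams_conf. Capacity is
## capped at 5 concurrent sequencing runs (for i in range(5)) with 10 copy
## workers per run (for j in range(10)).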