dag1_calculate_hpc_worker.py

import json, logging
from airflow.models import DAG, Variable
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.utils.dates import days_ago
from igf_airflow.celery.check_celery_queue import fetch_queue_list_from_redis_server
from igf_airflow.celery.check_celery_queue import calculate_new_workers

args = {
  'owner': 'airflow',
  'start_date': days_ago(2),
  'provide_context': True,
}
hpc_hook = SSHHook(ssh_conn_id='hpc_conn')

dag = \
  DAG(
    dag_id='dag1_calculate_hpc_worker',
    catchup=False,
    max_active_runs=1,
    schedule_interval="*/15 * * * *",
    default_args=args,
    tags=['igf-lims',])
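
## NOTE: this DAG reads the following Airflow connection and Variables; the
## example values below are illustrative assumptions, not deployment values:
##   - connection 'hpc_conn'                : SSH connection to the HPC login node
##   - Variable 'redis_conn_file'           : path to a JSON file holding the Redis URL
##   - Variable 'hpc_max_workers_per_queue' : e.g. "10"
##   - Variable 'hpc_max_total_workers'     : e.g. "70"
##   - Variable 'hpc_queue_list'            : JSON map of hpc queue name -> PBS settings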
def airflow_utils_for_redis(**kwargs):
  """
  A function for dag1, fetches the list of celery queues with pending jobs
  from the Redis server listed in the config file
  """
  try:
    if 'redis_conf_file' not in kwargs:
      raise ValueError('redis_conf_file info is not present in the kwargs')
    redis_conf_file = kwargs.get('redis_conf_file')
    json_data = dict()
    with open(redis_conf_file, 'r') as jp:
      json_data = json.load(jp)
    if 'redis_db' not in json_data:
      raise ValueError('redis_db key not present in the conf file')
    url = json_data.get('redis_db')
    queue_list = fetch_queue_list_from_redis_server(url=url)
    return queue_list
  except Exception as e:
    logging.error('Failed to run, error: {0}'.format(e))
    raise
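
## Expected shape of the JSON file behind Variable 'redis_conn_file'; only the
## 'redis_db' key is read above, and the URL shown here is a made-up example:
##
##   {"redis_db": "redis://:<password>@<redis-host>:6379/0"}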
def get_new_workers(**kwargs):
  """
  A function for dag1, compares the queued celery tasks against the active
  HPC jobs and returns the list of hpc queues which need new workers
  """
  try:
    if 'ti' not in kwargs:
      raise ValueError('ti not present in kwargs')
    ti = kwargs.get('ti')
    active_tasks = ti.xcom_pull(task_ids='fetch_active_jobs_from_hpc')           # stdout of the HPC job counting script
    active_tasks = active_tasks.decode()
    active_tasks = json.loads(active_tasks)
    queued_tasks = ti.xcom_pull(task_ids='fetch_queue_list_from_redis')          # queue list fetched from Redis
    worker_to_submit, unique_queue_list = \
      calculate_new_workers(
        queue_list=queued_tasks,
        active_jobs_dict=active_tasks,
        max_workers_per_queue=Variable.get('hpc_max_workers_per_queue'),
        max_total_workers=Variable.get('hpc_max_total_workers'))
    for key, value in worker_to_submit.items():
      ti.xcom_push(key=key, value=value)                                         # one xcom entry per hpc queue
    unique_queue_list = \
      [q for q in unique_queue_list if q.startswith('hpc')]
    return unique_queue_list
  except Exception as e:
    logging.error('Failed to get new workers, error: {0}'.format(e))
    raise
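
## NOTE: BranchPythonOperator follows the downstream tasks whose task_ids match
## the list returned by get_new_workers, so each qsub task defined below uses
## its hpc queue name as the task_id; the per-queue worker counts pushed to
## xcom above are read back inside the templated qsub commands.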
with dag:
  ## TASK
  fetch_queue_list_from_redis = \
    PythonOperator(
      task_id='fetch_queue_list_from_redis',
      dag=dag,
      python_callable=airflow_utils_for_redis,
      op_kwargs={"redis_conf_file": Variable.get('redis_conn_file')},
      queue='igf-lims')
  ## TASK
  check_hpc_queue = \
    SSHOperator(
      task_id='check_hpc_queue',
      ssh_hook=hpc_hook,
      dag=dag,
      command='source /etc/bashrc;qstat',
      queue='igf-lims')
  ## TASK
  fetch_active_jobs_from_hpc = \
    SSHOperator(
      task_id='fetch_active_jobs_from_hpc',
      ssh_hook=hpc_hook,
      dag=dag,
      command="""
        source /etc/bashrc;\
        source /project/tgu/data2/airflow_test/secrets/hpc_env.sh;\
        python /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/count_active_jobs_in_hpc.py """,
      do_xcom_push=True,
      queue='igf-lims')
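  ## NOTE: count_active_jobs_in_hpc.py is assumed to print a JSON dictionary of
  ## active job counts per hpc queue to stdout; that output is pushed to xcom
  ## and parsed as active_tasks inside get_new_workers above.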
  ## TASK
  calculate_new_worker_size_and_branch = \
    BranchPythonOperator(
      task_id='calculate_new_worker_size_and_branch',
      dag=dag,
      python_callable=get_new_workers,
      queue='igf-lims')
  ## TASK
  queue_tasks = list()
  hpc_queue_list = Variable.get('hpc_queue_list', deserialize_json=True)         # JSON variable; .items() below needs a dict, not a string
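  ## Assumed layout of the 'hpc_queue_list' Variable (the queue name and PBS
  ## resource string below are placeholders); each key becomes a branch task_id
  ## and an airflow worker queue:
  ##
  ##   {"hpc_4G": {"pbs_resource": "-lselect=1:ncpus=1:mem=4gb -lwalltime=12:00:00",
  ##               "airflow_queue": "hpc_4G"}}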
  for q, data in hpc_queue_list.items():
    pbs_resource = data.get('pbs_resource')
    airflow_queue = data.get('airflow_queue')
    t = SSHOperator(
      task_id=q,
      ssh_hook=hpc_hook,
      dag=dag,
      queue='igf-lims',
      command="""
        {% if ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") > 1 %}
          source /etc/bashrc; \
          qsub \
            -o /dev/null \
            -e /dev/null \
            -k n -m n \
            -N {{ params.job_name }} \
            -J 1-{{ ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") }} {{ params.pbs_resource }} -- \
            /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
        {% else %}
          source /etc/bashrc;\
          qsub \
            -o /dev/null \
            -e /dev/null \
            -k n -m n \
            -N {{ params.job_name }} {{ params.pbs_resource }} -- \
            /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
        {% endif %}
        """,
      params={
        'pbs_resource': pbs_resource,
        'airflow_queue': airflow_queue,
        'job_name': q})
    queue_tasks.append(t)
  ## PIPELINE
  check_hpc_queue >> fetch_active_jobs_from_hpc
  calculate_new_worker_size_and_branch << \
    [fetch_queue_list_from_redis,
     fetch_active_jobs_from_hpc]
  calculate_new_worker_size_and_branch >> queue_tasks
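
## NOTE: with the Airflow 1.10.x CLI (an assumption based on the contrib
## imports above), a single task can be exercised outside the scheduler, for
## example:
##
##   airflow test dag1_calculate_hpc_worker fetch_queue_list_from_redis 2020-01-01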