dag1_calculate_hpc_worker.py

import os, json, logging, requests
from airflow.models import DAG, Variable
from airflow.operators.bash_operator import BashOperator
from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.utils.dates import days_ago
from requests.auth import HTTPBasicAuth
from igf_airflow.celery.check_celery_queue import fetch_queue_list_from_redis_server
from igf_airflow.celery.check_celery_queue import calculate_new_workers

CELERY_FLOWER_BASE_URL = Variable.get('celery_flower_base_url')
  12. args = {
  13. 'owner':'airflow',
  14. 'start_date':days_ago(2),
  15. 'provide_context': True,
  16. }
  17. hpc_hook = SSHHook(ssh_conn_id='hpc_conn')
  18. dag = DAG(
  19. dag_id='dag1_calculate_hpc_worker',
  20. catchup=False,
  21. max_active_runs=1,
  22. schedule_interval="*/15 * * * *",
  23. default_args=args,
  24. tags=['igf-lims',]
  25. )
def airflow_utils_for_redis(**kwargs):
    """
    A function for fetching the list of pending task queues from the Redis server
    """
    try:
        if 'redis_conf_file' not in kwargs:
            raise ValueError('redis_conf_file info is not present in the kwargs')
        redis_conf_file = kwargs.get('redis_conf_file')
        json_data = dict()
        with open(redis_conf_file, 'r') as jp:
            json_data = json.load(jp)
        if 'redis_db' not in json_data:
            raise ValueError('redis_db key not present in the conf file')
        url = json_data.get('redis_db')
        queue_list = fetch_queue_list_from_redis_server(url=url)
        return queue_list
    except Exception as e:
        logging.error('Failed to run, error: {0}'.format(e))
        raise
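
## NOTE (illustrative only): the conf file read above just needs to be a JSON
## document with a 'redis_db' key holding the Redis connection URL, e.g.
##   {"redis_db": "redis://:<password>@<redis-host>:6379/0"}
## The password and host are placeholders, not values from this deployment.
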
def get_new_workers(**kwargs):
    """
    A function for calculating the number of new HPC workers to submit per queue
    """
    try:
        if 'ti' not in kwargs:
            raise ValueError('ti not present in kwargs')
        ti = kwargs.get('ti')
        active_tasks = ti.xcom_pull(task_ids='fetch_active_jobs_from_hpc')
        active_tasks = active_tasks.decode()
        active_tasks = json.loads(active_tasks)
        queued_tasks = ti.xcom_pull(task_ids='fetch_queue_list_from_redis')
        # Airflow Variables are plain strings; cast the worker limits to int
        # before handing them to calculate_new_workers
        worker_to_submit, unique_queue_list = \
            calculate_new_workers(
                queue_list=queued_tasks,
                active_jobs_dict=active_tasks,
                max_workers_per_queue=int(Variable.get('hpc_max_workers_per_queue')),
                max_total_workers=int(Variable.get('hpc_max_total_workers')))
        for key, value in worker_to_submit.items():
            ti.xcom_push(key=key, value=value)
        unique_queue_list = \
            [q for q in unique_queue_list if q.startswith('hpc')]
        return unique_queue_list
    except Exception as e:
        logging.error('Failed to get new workers, error: {0}'.format(e))
        raise
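
## NOTE: get_new_workers above pushes one XCom per HPC queue (key = queue name,
## value = number of workers to submit) and returns the matching queue names.
## Because the per-queue tasks defined further down use the queue name as their
## task_id, the returned list doubles as the branch selection for
## calculate_new_worker_size_and_branch.
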
def fetch_celery_worker_list(**context):
    """
    A function for fetching list of celery workers from flower server
    """
    try:
        ti = context.get('ti')
        celery_worker_key = context['params'].get('celery_worker_key')
        celery_basic_auth = os.environ.get('AIRFLOW__CELERY__FLOWER_BASIC_AUTH')
        if celery_basic_auth is None:
            raise ValueError('Missing env for flower basic auth')
        flower_user, flower_pass = celery_basic_auth.split(':')
        celery_url = '{0}/api/workers'.format(CELERY_FLOWER_BASE_URL)
        res = requests.get(celery_url, auth=HTTPBasicAuth(flower_user, flower_pass))
        if res.status_code != 200:
            raise ValueError('Failed to fetch celery workers')
        data = res.content.decode()
        data = json.loads(data)
        worker_list = list()
        for worker_id, val in data.items():
            worker_list.append({
                'worker_id': worker_id,
                'active_jobs': len(val.get('active')),
                'queue_lists': [i.get('name') for i in val.get('active_queues')]})
        ti.xcom_push(key=celery_worker_key, value=worker_list)
    except Exception as e:
        logging.error('Failed to get celery workers, error: {0}'.format(e))
        raise
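
## NOTE (assumed response shape): the parsing above relies on Flower's
## /api/workers payload mapping each worker id to a dict with an 'active'
## list of running tasks and an 'active_queues' list of queue descriptors
## carrying a 'name' field, roughly:
##   {"celery@worker1": {"active": [...], "active_queues": [{"name": "hpc_4G"}, ...]}}
## Worker and queue names here are placeholders.
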
with dag:
    ## TASK
    fetch_queue_list_from_redis = \
        PythonOperator(
            task_id='fetch_queue_list_from_redis',
            dag=dag,
            python_callable=airflow_utils_for_redis,
            op_kwargs={"redis_conf_file": Variable.get('redis_conn_file')},
            queue='igf-lims')
    ## TASK
    check_hpc_queue = \
        SSHOperator(
            task_id='check_hpc_queue',
            ssh_hook=hpc_hook,
            dag=dag,
            command='source /etc/bashrc;qstat',
            queue='igf-lims')
    ## TASK
    fetch_active_jobs_from_hpc = \
        SSHOperator(
            task_id='fetch_active_jobs_from_hpc',
            ssh_hook=hpc_hook,
            dag=dag,
            command="""
                source /etc/bashrc;\
                source /project/tgu/data2/airflow_test/secrets/hpc_env.sh;\
                python /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/count_active_jobs_in_hpc.py """,
            do_xcom_push=True,
            queue='igf-lims')
    ## TASK
    fetch_celery_workers = \
        PythonOperator(
            task_id='fetch_celery_workers',
            dag=dag,
            python_callable=fetch_celery_worker_list,
            params={'celery_worker_key': 'celery_workers'})
    ## TASK
    calculate_new_worker_size_and_branch = \
        BranchPythonOperator(
            task_id='calculate_new_worker_size_and_branch',
            dag=dag,
            python_callable=get_new_workers,
            queue='igf-lims')
    ## TASK
    queue_tasks = list()
    # deserialize the JSON Variable value so it can be iterated as a dict of
    # queue name -> queue settings
    hpc_queue_list = Variable.get('hpc_queue_list', deserialize_json=True)
    for q, data in hpc_queue_list.items():
        pbs_resource = data.get('pbs_resource')
        airflow_queue = data.get('airflow_queue')
        t = SSHOperator(
            task_id=q,
            ssh_hook=hpc_hook,
            dag=dag,
            queue='igf-lims',
            command="""
            {% if ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") > 1 %}
              source /etc/bashrc; \
              qsub \
                -o /dev/null \
                -e /dev/null \
                -k n -m n \
                -N {{ params.job_name }} \
                -J 1-{{ ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") }} {{ params.pbs_resource }} -- \
                /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
            {% else %}
              source /etc/bashrc; \
              qsub \
                -o /dev/null \
                -e /dev/null \
                -k n -m n \
                -N {{ params.job_name }} {{ params.pbs_resource }} -- \
                /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
            {% endif %}
            """,
            params={
                'pbs_resource': pbs_resource,
                'airflow_queue': airflow_queue,
                'job_name': q})
        queue_tasks.append(t)
    ## PIPELINE
    check_hpc_queue >> fetch_active_jobs_from_hpc
    calculate_new_worker_size_and_branch << \
        [fetch_queue_list_from_redis,
         fetch_active_jobs_from_hpc,
         fetch_celery_workers]
    calculate_new_worker_size_and_branch >> queue_tasks
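
## NOTE (illustrative only): the per-queue loop above expects the
## 'hpc_queue_list' Variable to deserialize into a mapping of queue name to
## PBS and Airflow settings. Queue names and resource strings below are
## placeholders, not values from this deployment; such a value could be
## loaded once with, for example:
##
##   Variable.set(
##       'hpc_queue_list',
##       {'hpc_4G': {
##           'pbs_resource': '-lselect=1:ncpus=1:mem=4gb -lwalltime=24:00:00',
##           'airflow_queue': 'hpc_4G'}},
##       serialize_json=True)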