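"""
dag1_calculate_hpc_worker: an Airflow DAG that runs every five minutes and
right-sizes the celery worker pool on the HPC cluster. It compares the queued
tasks in Redis with the active PBS jobs and the workers registered in Flower,
submits new qsub worker jobs per queue over SSH, and shuts down idle workers
through the Flower API.
"""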
import os, json, logging, requests
from datetime import timedelta
from requests.auth import HTTPBasicAuth
from airflow.models import DAG, Variable
from airflow.utils.dates import days_ago
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import BranchPythonOperator
from airflow.contrib.operators.ssh_operator import SSHOperator
from igf_airflow.celery.check_celery_queue import fetch_queue_list_from_redis_server
from igf_airflow.celery.check_celery_queue import calculate_new_workers
CELERY_FLOWER_BASE_URL = Variable.get('celery_flower_base_url')

args = {
    'owner': 'airflow',
    'start_date': days_ago(2),
    'retries': 1,
    'retry_delay': timedelta(minutes=2),
    'provide_context': True,
}

hpc_hook = SSHHook(ssh_conn_id='hpc_conn')
dag = DAG(
    dag_id='dag1_calculate_hpc_worker',
    catchup=False,
    max_active_runs=1,
    schedule_interval="*/5 * * * *",
    default_args=args,
    tags=['igf-lims'])
def airflow_utils_for_redis(**kwargs):
    """
    A function for fetching the list of celery queues with pending jobs
    from the Redis server listed in the conf file
    """
    try:
        if 'redis_conf_file' not in kwargs:
            raise ValueError('redis_conf_file info is not present in the kwargs')
        redis_conf_file = kwargs.get('redis_conf_file')
        json_data = dict()
        with open(redis_conf_file, 'r') as jp:
            json_data = json.load(jp)
        if 'redis_db' not in json_data:
            raise ValueError('redis_db key not present in the conf file')
        url = json_data.get('redis_db')
        queue_list = fetch_queue_list_from_redis_server(url=url)
        return queue_list
    except Exception as e:
        logging.error('Failed to run, error: {0}'.format(e))
        raise
def get_new_workers(**kwargs):
    """
    A branch function for calculating the number of new HPC workers to submit
    per queue and for marking idle celery workers for cleanup
    """
    try:
        if 'ti' not in kwargs:
            raise ValueError('ti not present in kwargs')
        ti = kwargs.get('ti')
        active_tasks = ti.xcom_pull(task_ids='fetch_active_jobs_from_hpc')
        active_tasks = active_tasks.decode()
        active_tasks = json.loads(active_tasks)
        queued_tasks = ti.xcom_pull(task_ids='fetch_queue_list_from_redis')
        worker_to_submit, unique_queue_list = \
            calculate_new_workers(
                queue_list=queued_tasks,
                active_jobs_dict=active_tasks,
                max_workers_per_queue=int(Variable.get('hpc_max_workers_per_queue')),  # Variable.get returns a string
                max_total_workers=int(Variable.get('hpc_max_total_workers')))
        for key, value in worker_to_submit.items():
            ti.xcom_push(key=key, value=value)
        unique_queue_list = \
            [q for q in unique_queue_list if q.startswith('hpc')]
        celery_worker_key = kwargs['params'].get('celery_worker_key')
        base_queue = kwargs['params'].get('base_queue')
        empty_celery_worker_key = kwargs['params'].get('empty_celery_worker_key')
        celery_workers = \
            ti.xcom_pull(task_ids='fetch_celery_workers', key=celery_worker_key)
        cleanup_list = list()
        for flower_entry in celery_workers:
            active_jobs_count = flower_entry.get('active_jobs')
            worker_id = flower_entry.get('worker_id')
            queue_list = flower_entry.get('queue_lists')
            ## mark a worker for cleanup if it is not on the base queue, serves
            ## no queue that still needs workers and has no active job
            if base_queue not in queue_list and \
               len(set(queue_list).intersection(set(unique_queue_list))) == 0 and \
               active_jobs_count == 0:
                cleanup_list.append(worker_id)
        if len(cleanup_list) > 0:
            ti.xcom_push(key=empty_celery_worker_key, value=cleanup_list)
            unique_queue_list.append('cleanup_celery_workers')
        return unique_queue_list
    except Exception as e:
        logging.error('Failed to get new workers, error: {0}'.format(e))
        raise
def fetch_celery_worker_list(**context):
    """
    A function for fetching the list of celery workers from the flower server
    """
    try:
        ti = context.get('ti')
        celery_worker_key = context['params'].get('celery_worker_key')
        celery_basic_auth = os.environ.get('AIRFLOW__CELERY__FLOWER_BASIC_AUTH')
        if celery_basic_auth is None:
            raise ValueError('Missing env for flower basic auth')
        flower_user, flower_pass = celery_basic_auth.split(':')
        celery_url = '{0}/api/workers'.format(CELERY_FLOWER_BASE_URL)
        res = requests.get(celery_url, auth=HTTPBasicAuth(flower_user, flower_pass))
        if res.status_code != 200:
            raise ValueError('Failed to fetch celery workers')
        data = res.content.decode()
        data = json.loads(data)
        worker_list = list()
        for worker_id, val in data.items():
            worker_list.append({
                'worker_id': worker_id,
                'active_jobs': len(val.get('active')),
                'queue_lists': [i.get('name') for i in val.get('active_queues')]})
        ti.xcom_push(key=celery_worker_key, value=worker_list)
    except Exception as e:
        logging.error('Failed to get celery workers, error: {0}'.format(e))
        raise
def stop_celery_workers(**context):
    """
    A function for stopping celery workers
    """
    try:
        ti = context.get('ti')
        empty_celery_worker_key = context['params'].get('empty_celery_worker_key')
        celery_basic_auth = os.environ['AIRFLOW__CELERY__FLOWER_BASIC_AUTH']
        flower_user, flower_pass = celery_basic_auth.split(':')
        celery_workers = \
            ti.xcom_pull(
                task_ids='calculate_new_worker_size_and_branch',
                key=empty_celery_worker_key)
        for worker_id in celery_workers:
            flower_shutdown_url = \
                '{0}/api/worker/shutdown/{1}'.\
                format(CELERY_FLOWER_BASE_URL, worker_id)
            res = requests.post(
                flower_shutdown_url,
                auth=HTTPBasicAuth(flower_user, flower_pass))
            if res.status_code != 200:
                raise ValueError('Failed to delete worker {0}'.format(worker_id))
    except Exception as e:
        logging.error('Failed to stop celery workers, error: {0}'.format(e))
        raise
with dag:
    ## TASK
    fetch_queue_list_from_redis = \
        PythonOperator(
            task_id='fetch_queue_list_from_redis',
            dag=dag,
            python_callable=airflow_utils_for_redis,
            op_kwargs={"redis_conf_file": Variable.get('redis_conn_file')},
            queue='generic')
    ## TASK
    check_hpc_queue = \
        SSHOperator(
            task_id='check_hpc_queue',
            ssh_hook=hpc_hook,
            dag=dag,
            command='source /etc/bashrc;qstat',
            queue='generic')
    ## TASK
    fetch_active_jobs_from_hpc = \
        SSHOperator(
            task_id='fetch_active_jobs_from_hpc',
            ssh_hook=hpc_hook,
            dag=dag,
            command="""
            source /etc/bashrc;\
            source /project/tgu/data2/airflow_test/secrets/hpc_env.sh;\
            python /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/count_active_jobs_in_hpc.py """,
            do_xcom_push=True,
            queue='generic')
    ## TASK
    fetch_celery_workers = \
        PythonOperator(
            task_id='fetch_celery_workers',
            dag=dag,
            queue='generic',
            python_callable=fetch_celery_worker_list,
            params={'celery_worker_key': 'celery_workers'})
    ## TASK
    calculate_new_worker_size_and_branch = \
        BranchPythonOperator(
            task_id='calculate_new_worker_size_and_branch',
            dag=dag,
            python_callable=get_new_workers,
            queue='generic',
            params={'celery_worker_key': 'celery_workers',
                    'empty_celery_worker_key': 'empty_celery_worker',
                    'base_queue': 'generic'})
    ## TASK
    ## one SSHOperator per HPC queue; the Jinja template submits a PBS job
    ## array when more than one worker is requested, otherwise a single job
    queue_tasks = list()
    hpc_queue_list = Variable.get('hpc_queue_list', deserialize_json=True)  # stored as a JSON dict
    for q, data in hpc_queue_list.items():
        pbs_resource = data.get('pbs_resource')
        airflow_queue = data.get('airflow_queue')
        t = SSHOperator(
            task_id=q,
            ssh_hook=hpc_hook,
            dag=dag,
            queue='generic',
            command="""
            {% if ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") > 1 %}
              source /etc/bashrc; \
              qsub \
                -o /dev/null \
                -e /dev/null \
                -k n -m n \
                -N {{ params.job_name }} \
                -J 1-{{ ti.xcom_pull(key=params.job_name, task_ids="calculate_new_worker_size_and_branch") }} {{ params.pbs_resource }} -- \
                /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
            {% else %}
              source /etc/bashrc;\
              qsub \
                -o /dev/null \
                -e /dev/null \
                -k n -m n \
                -N {{ params.job_name }} {{ params.pbs_resource }} -- \
                /project/tgu/data2/airflow_test/github/data-management-python/scripts/hpc/airflow_worker.sh {{ params.airflow_queue }} {{ params.job_name }}
            {% endif %}
            """,
            params={'pbs_resource': pbs_resource,
                    'airflow_queue': airflow_queue,
                    'job_name': q})
        queue_tasks.append(t)
    ## TASK
    cleanup_celery_workers = \
        PythonOperator(
            task_id='cleanup_celery_workers',
            dag=dag,
            queue='generic',
            params={'empty_celery_worker_key': 'empty_celery_worker'},
            python_callable=stop_celery_workers)
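    ## the branch task pulls the redis backlog, the active HPC jobs and the
    ## flower worker list, then fans out to the per-queue worker submission
    ## tasks and/or the idle-worker cleanup task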
    ## PIPELINE
    check_hpc_queue >> fetch_active_jobs_from_hpc
    calculate_new_worker_size_and_branch << \
        [fetch_queue_list_from_redis,
         fetch_active_jobs_from_hpc,
         fetch_celery_workers]
    calculate_new_worker_size_and_branch >> queue_tasks
    calculate_new_worker_size_and_branch >> cleanup_celery_workers
|