Browse Source

fix for file copy

Avik Datta 4 years ago
parent
commit
7457476eea
2 changed files with 28 additions and 11 deletions
  1. 1 0
      airflow_var/var.json
  2. 27 11
      dags/dag8_copy_ongoing_seqrun.py

+ 1 - 0
airflow_var/var.json

@@ -6,6 +6,7 @@
   "hpc_ssh_key_file":"/home/igf/.ssh/id_rsa",
   "seqrun_server":"",
   "seqrun_base_path":"",
+  "seqrun_server_user":"",
   "database_config_file":"",
   "slack_conf":"",
   "asana_conf":"",

+ 27 - 11
dags/dag8_copy_ongoing_seqrun.py

@@ -54,11 +54,13 @@ def get_ongoing_seqrun_list(**context):
     ti = context.get('ti')
     seqrun_server = Variable.get('seqrun_server')
     seqrun_base_path = Variable.get('seqrun_base_path')
+    seqrun_server_user = Variable.get('seqrun_server_user')
     database_config_file = Variable.get('database_config_file')
     ongoing_seqruns = \
       fetch_ongoing_seqruns(
         seqrun_server=seqrun_server,
         seqrun_base_path=seqrun_base_path,
+        user_name=seqrun_server_user,
         database_config_file=database_config_file)
     ti.xcom_push(key='ongoing_seqruns',value=ongoing_seqruns)
     branch_list = ['generate_seqrun_file_list_{0}'.format(i[0]) 
@@ -83,6 +85,7 @@ def get_ongoing_seqrun_list(**context):
       dag_id=context['task'].dag_id,
       comment=e,
       reaction='fail')
+    raise
 
 
 def copy_seqrun_manifest_file(**context):
@@ -91,8 +94,10 @@ def copy_seqrun_manifest_file(**context):
   """
   try:
     remote_file_path = context['params'].get('file_path')
-    seqrun_server = context['params'].get('seqrun_server')
+    seqrun_server = Variable.get('seqrun_server')
+    seqrun_server_user = Variable.get('seqrun_server_user')
     xcom_pull_task_ids = context['params'].get('xcom_pull_task_ids')
+    
     ti = context.get('ti')
     remote_file_path = ti.xcom_pull(task_ids=xcom_pull_task_ids)
     if remote_file_path is not None and \
@@ -103,10 +108,12 @@ def copy_seqrun_manifest_file(**context):
       os.path.join(
         tmp_work_dir,
         os.path.basename(remote_file_path))
+    remote_address = \
+      '{0}@{1}'.format(seqrun_server_user,seqrun_server)
     copy_remote_file(
       remote_file_path,
       local_file_path,
-      source_address=seqrun_server)
+      source_address=remote_address)
     return local_file_path
   except Exception as e:
     logging.error(e)
@@ -117,6 +124,7 @@ def copy_seqrun_manifest_file(**context):
       dag_id=context['task'].dag_id,
       comment=e,
       reaction='fail')
+    raise
 
 
 def get_seqrun_chunks(**context):
@@ -152,8 +160,9 @@ def get_seqrun_chunks(**context):
     else:
       chunk_size = int(len(file_data) / worker_size)+1
     ti.xcom_push(key=seqrun_chunk_size_key,value=chunk_size)
-    worker_branchs = ['{0}_{1}'.format(child_task_prefix,i) 
-                        for i in range(worker_size)]
+    worker_branchs = \
+      ['{0}_{1}'.format(child_task_prefix,i) 
+         for i in range(worker_size)]
     return worker_branchs
   except Exception as e:
     logging.error(e)
@@ -164,6 +173,7 @@ def get_seqrun_chunks(**context):
       dag_id=context['task'].dag_id,
       comment=e,
       reaction='fail')
+    raise
 
 
 def copy_seqrun_chunk(**context):
@@ -180,11 +190,15 @@ def copy_seqrun_chunk(**context):
     local_seqrun_path = context['params'].get('local_seqrun_path')
     seqrun_id_pull_key = context['params'].get('seqrun_id_pull_key')
     seqrun_id_pull_task_ids = context['params'].get('seqrun_id_pull_task_ids')
-    seqrun_server = Variable.get('seqrun_server'),
+    seqrun_server = Variable.get('seqrun_server')
+    seqrun_server_user = Variable.get('seqrun_server_user')
     seqrun_base_path = Variable.get('seqrun_base_path')
-    seqrun_id = ti.xcom_pull(key=seqrun_id_pull_key,task_ids=seqrun_id_pull_task_ids)[run_index_number]
-    file_path = ti.xcom_pull(task_ids=file_path_task_ids)
-    chunk_size = ti.xcom_pull(key=seqrun_chunk_size_key,task_ids=seqrun_chunk_size_task_ids)
+    seqrun_id = \
+      ti.xcom_pull(key=seqrun_id_pull_key,task_ids=seqrun_id_pull_task_ids)[run_index_number]
+    file_path = \
+      ti.xcom_pull(task_ids=file_path_task_ids)
+    chunk_size = \
+      ti.xcom_pull(key=seqrun_chunk_size_key,task_ids=seqrun_chunk_size_task_ids)
     check_file_path(file_path)
     file_data = read_json_data(file_path)
     start_index = chunk_index_number*chunk_size
@@ -193,6 +207,8 @@ def copy_seqrun_chunk(**context):
       finish_index = len(file_data) - 1
     local_seqrun_path = \
       os.path.join(local_seqrun_path,seqrun_id)
+    remote_address = \
+      '{0}@{1}'.format(seqrun_server_user,seqrun_server)
     for entry in file_data[start_index:finish_index]:
       file_path = entry.get('file_path')
       file_size = entry.get('file_size')
@@ -211,7 +227,7 @@ def copy_seqrun_chunk(**context):
         copy_remote_file(
           remote_path,
           local_path,
-          source_address=seqrun_server,
+          source_address=remote_address,
           check_file=False)
   except Exception as e:
     logging.error(e)
@@ -222,6 +238,7 @@ def copy_seqrun_chunk(**context):
       dag_id=context['task'].dag_id,
       comment=e,
       reaction='fail')
+    raise
 
 
 with dag:
@@ -267,8 +284,7 @@ with dag:
         dag=dag,
         pool='orwell_scp_pool',
         queue='hpc_4G',
-        params={'xcom_pull_task_ids':'generate_seqrun_file_list_{0}'.format(i),
-                'seqrun_server':Variable.get('seqrun_server')},
+        params={'xcom_pull_task_ids':'generate_seqrun_file_list_{0}'.format(i)},
         python_callable=copy_seqrun_manifest_file)
     ## TASK
     t3 = \