Browse Source

add QA logic

Sanyam Bhutani 1 month ago
parent
commit
714d3a5cee

+ 76 - 0
end-to-end-use-cases/data-tool/src/generate_qa.py

@@ -0,0 +1,76 @@
+# This script eats the parsed texts and gives QA pairs
+import os
+import json
+import argparse
+import sys
+import traceback
+from pathlib import Path
+
+# Annoying path bug fix
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from src.utils import QAGenerator
+
+def process_file(file_path, output_dir=None, model="llama-3.3-70b", 
+                api_key=None, text_file=None, num_pairs=25, threshold=7.0):
+    if output_dir is None:
+        output_dir = 'data/qa_pairs'
+    os.makedirs(output_dir, exist_ok=True)
+    if text_file and os.path.exists(text_file):
+        print(f"Using pre-processed text from {text_file}")
+        with open(text_file, 'r', encoding='utf-8') as f:
+            document_text = f.read()
+    else:
+        from src.main import process_file as parse_file
+        print(f"Parsing {file_path} to extract text...")
+        parsed_path = parse_file(file_path)
+        
+        print(f"Reading parsed content from {parsed_path}")
+        with open(parsed_path, 'r', encoding='utf-8') as f:
+            document_text = f.read()
+    generator = QAGenerator(api_key=api_key, model=model)
+    
+    result = generator.process_document(document_text,num_pairs=num_pairs,quality_threshold=threshold)
+    base_name = os.path.basename(file_path).split('.')[0]
+    output_path = os.path.join(output_dir, f"{base_name}_qa_pairs.json")
+    with open(output_path, 'w', encoding='utf-8') as f:
+        json.dump(result, f, indent=2)
+    print("Writing summary")
+    print(f"QA Pairs Generated: {result['metrics']['total']}")
+    print(f"Quality Pairs: {result['metrics']['filtered']} ({result['metrics']['retention_rate']*100:.1f}%)")
+    return output_path
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Generate QA pairs from documents using Llama',
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    
+    parser.add_argument('input', help='File to parse and generate QA pairs from')
+    parser.add_argument('-o', '--output-dir', help='Where to save the output', default='data/qa_pairs')
+    parser.add_argument('-m', '--model', help='Model to use', default='llama-3.3-70b')
+    parser.add_argument('-k', '--api-key', help='API key (defaults to env var)')
+    parser.add_argument('-t', '--text-file', help='Path to already parsed text file (skips parsing step)')
+    parser.add_argument('-n', '--num-pairs', type=int, help='Target number of QA pairs to generate', default=25)
+    parser.add_argument('--threshold', type=float, help='Quality threshold for filtering pairs (1-10)', default=7.0)
+    
+    args = parser.parse_args()
+    
+    try:
+        output_path = process_file(
+            args.input,
+            args.output_dir,
+            args.model,
+            args.api_key,
+            args.text_file,
+            args.num_pairs,
+            args.threshold
+        )
+        print(f"QA pairs saved to {output_path}")
+        return 0
+    except Exception as e:
+        print(f"Error: {e}")
+        traceback.print_exc()
+        return 1
+
+if __name__ == '__main__':
+    sys.exit(main())

+ 54 - 2
end-to-end-use-cases/data-tool/src/main.py

@@ -1,15 +1,16 @@
 # Usage: python main.py mydocument.pdf
 import traceback
 import os
+import json
 import argparse
 import sys
 import time
 from pathlib import Path
 from typing import Optional
-
 # fix the annoying path bug
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from src.parsers import PDFParser, HTMLParser, YouTubeParser, DOCXParser, TXTParser, PPTParser
+from src.utils import QAGenerator
 
 def determine_parser(file_path: str):
     if 'youtube.com' in file_path or 'youtu.be' in file_path:
@@ -95,15 +96,66 @@ def main():
     parser.add_argument('-o', '--output-dir', help='Where to save the output', default='data/output')
     parser.add_argument('-n', '--name', help='Custom output filename')
     parser.add_argument('-v', '--verbose', action='store_true', help='Show more details')
+    parser.add_argument('--generate-qa', action='store_true', help='Generate QA pairs after parsing')
+    # QA args
+    qa_group = parser.add_argument_group('QA Generation Options')
+    qa_group.add_argument('--qa-model', help='Model to use for QA generation', default='llama-3.1-70b-instruct')
+    qa_group.add_argument('--qa-pairs', type=int, help='Number of QA pairs to generate', default=25)
+    qa_group.add_argument('--qa-threshold', type=float, help='Quality threshold for QA pairs (1-10)', default=7.0)
+    qa_group.add_argument('--qa-output-dir', help='Directory to save QA pairs', default='data/qa_pairs')
     
     args = parser.parse_args()
     
     try:
         output_path = process_file(args.input, args.output_dir, args.name)
-        print(f"Saving text to....{output_path}")
+        print(f"Text extrctd to {output_path}")
+        if args.generate_qa:
+            try:
+                print("Generating QA Pairs")
+                
+                # Check for Cerebras API key
+                if not os.environ.get("CEREBRAS_API_KEY"):
+                    print("set the key")
+                    return 1
+                import subprocess
+                cmd = [
+                    sys.executable,
+                    os.path.join(os.path.dirname(os.path.abspath(__file__)), "generate_qa.py"),
+                    args.input,
+                    "--text-file", output_path,
+                    "--model", args.qa_model,
+                    "--num-pairs", str(args.qa_pairs),
+                    "--threshold", str(args.qa_threshold),
+                    "--output-dir", args.qa_output_dir
+                ]
+                
+                result = subprocess.run(cmd, capture_output=True, text=True)
+                
+                if result.returncode != 0:
+                    print(f"Eror:")
+                    print(result.stderr)
+                    return 1
+                else:
+                    print(result.stdout)
+                    
+            except ImportError:
+                print("Error with QA")
+                if args.verbose:
+                    import traceback
+                    traceback.print_exc()
+                return 1
+                    
+            except Exception as e:
+                print(f"Exception {e}")
+                if args.verbose:
+                    import traceback
+                    traceback.print_exc()
+                return 1
+                
         return 0
     except Exception as e:
         if args.verbose:
+            import traceback
             traceback.print_exc()
         else:
             print(f"Error: {e}")

end-to-end-use-cases/data-tool/src/parsers/utils/__init__.py → end-to-end-use-cases/data-tool/src/utils/__init__.py


end-to-end-use-cases/data-tool/src/parsers/utils/qa_generator.py → end-to-end-use-cases/data-tool/src/utils/qa_generator.py