|
@@ -1,15 +1,16 @@
|
|
# Usage: python main.py mydocument.pdf
|
|
# Usage: python main.py mydocument.pdf
|
|
import traceback
|
|
import traceback
|
|
import os
|
|
import os
|
|
|
|
+import json
|
|
import argparse
|
|
import argparse
|
|
import sys
|
|
import sys
|
|
import time
|
|
import time
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from typing import Optional
|
|
-
|
|
|
|
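# Add the parent of this script's directory to sys.path so the `src` package can be imported
|
|
# Add the parent of this script's directory to sys.path so the `src` package can be imported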
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
from src.parsers import PDFParser, HTMLParser, YouTubeParser, DOCXParser, TXTParser, PPTParser
|
|
from src.parsers import PDFParser, HTMLParser, YouTubeParser, DOCXParser, TXTParser, PPTParser
|
|
|
|
+from src.utils import QAGenerator
|
|
|
|
|
|
def determine_parser(file_path: str):
|
|
def determine_parser(file_path: str):
|
|
if 'youtube.com' in file_path or 'youtu.be' in file_path:
|
|
if 'youtube.com' in file_path or 'youtu.be' in file_path:
|
|
@@ -95,15 +96,66 @@ def main():
|
|
parser.add_argument('-o', '--output-dir', help='Where to save the output', default='data/output')
|
|
parser.add_argument('-o', '--output-dir', help='Where to save the output', default='data/output')
|
|
parser.add_argument('-n', '--name', help='Custom output filename')
|
|
parser.add_argument('-n', '--name', help='Custom output filename')
|
|
parser.add_argument('-v', '--verbose', action='store_true', help='Show more details')
|
|
parser.add_argument('-v', '--verbose', action='store_true', help='Show more details')
|
|
|
|
+ parser.add_argument('--generate-qa', action='store_true', help='Generate QA pairs after parsing')
|
|
|
|
+ # QA args
|
|
|
|
+ qa_group = parser.add_argument_group('QA Generation Options')
|
|
|
|
+ qa_group.add_argument('--qa-model', help='Model to use for QA generation', default='llama-3.1-70b-instruct')
|
|
|
|
+ qa_group.add_argument('--qa-pairs', type=int, help='Number of QA pairs to generate', default=25)
|
|
|
|
+ qa_group.add_argument('--qa-threshold', type=float, help='Quality threshold for QA pairs (1-10)', default=7.0)
|
|
|
|
+ qa_group.add_argument('--qa-output-dir', help='Directory to save QA pairs', default='data/qa_pairs')
|
|
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
try:
|
|
try:
|
|
output_path = process_file(args.input, args.output_dir, args.name)
|
|
output_path = process_file(args.input, args.output_dir, args.name)
|
|
- print(f"Saving text to....{output_path}")
|
|
|
|
|
|
+ print(f"Text extrctd to {output_path}")
|
|
|
|
+ if args.generate_qa:
|
|
|
|
+ try:
|
|
|
|
+ print("Generating QA Pairs")
|
|
|
|
+
|
|
|
|
+ # Check for Cerebras API key
|
|
|
|
+ if not os.environ.get("CEREBRAS_API_KEY"):
|
|
|
|
+ print("set the key")
|
|
|
|
+ return 1
|
|
|
|
+ import subprocess
|
|
|
|
+ cmd = [
|
|
|
|
+ sys.executable,
|
|
|
|
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "generate_qa.py"),
|
|
|
|
+ args.input,
|
|
|
|
+ "--text-file", output_path,
|
|
|
|
+ "--model", args.qa_model,
|
|
|
|
+ "--num-pairs", str(args.qa_pairs),
|
|
|
|
+ "--threshold", str(args.qa_threshold),
|
|
|
|
+ "--output-dir", args.qa_output_dir
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
|
+
|
|
|
|
+ if result.returncode != 0:
|
|
|
|
+ print(f"Eror:")
|
|
|
|
+ print(result.stderr)
|
|
|
|
+ return 1
|
|
|
|
+ else:
|
|
|
|
+ print(result.stdout)
|
|
|
|
+
|
|
|
|
+ except ImportError:
|
|
|
|
+ print("Error with QA")
|
|
|
|
+ if args.verbose:
|
|
|
|
+ import traceback
|
|
|
|
+ traceback.print_exc()
|
|
|
|
+ return 1
|
|
|
|
+
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Exception {e}")
|
|
|
|
+ if args.verbose:
|
|
|
|
+ import traceback
|
|
|
|
+ traceback.print_exc()
|
|
|
|
+ return 1
|
|
|
|
+
|
|
return 0
|
|
return 0
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if args.verbose:
|
|
if args.verbose:
|
|
|
|
+ import traceback
|
|
traceback.print_exc()
|
|
traceback.print_exc()
|
|
else:
|
|
else:
|
|
print(f"Error: {e}")
|
|
print(f"Error: {e}")
|