
Add QA gen

Sanyam Bhutani 1 month ago
parent commit 1fa09829b6

+ 3 - 0
end-to-end-use-cases/data-tool/src/parsers/utils/__init__.py

@@ -0,0 +1,3 @@
+from .qa_generator import QAGenerator
+
+__all__ = ['QAGenerator']

+ 269 - 0
end-to-end-use-cases/data-tool/src/parsers/utils/qa_generator.py

@@ -0,0 +1,269 @@
+import os
+import json
+import re
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+from cerebras.cloud.sdk import Cerebras
+
+class QAGenerator:
+    def __init__(self, api_key: Optional[str] = None, model: str = "llama-3.3-70b"):
+        if api_key is None:
+            api_key = os.environ.get("CEREBRAS_API_KEY")
+            if not api_key:
+                raise ValueError("No API key provided and CEREBRAS_API_KEY is not set")
+        self.model = model
+        self.client = Cerebras(api_key=api_key)
+    
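+    # Summarize the document once; the summary is reused to steer QA generation and rating.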
+    def generate_summary(self, document_text: str) -> str:
+        print("Summarizing document...")
+        prompt = """You are summarizing a document for LLM training data preparation.
+        
+Write a comprehensive summary that captures:
+1. The document's title and main topic
+2. Key concepts and terminology
+3. Main arguments or findings
+4. Notable data points or examples
+5. Overall structure and purpose
+
+Your summary should be detailed enough that it could be used to understand what kinds of questions can be asked about this document. Focus on factual information rather than opinions."""
+        
+        response = self.client.chat.completions.create(
+            messages=[
+                {"role": "system", "content": prompt},
+                {"role": "user", "content": document_text}
+            ],
+            model=self.model,
+            temperature=0.1,
+        )
+        
+        summary = response.choices[0].message.content
+        print(f"Summary generated ({len(summary)} chars)")
+        return summary
+    
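+    # Chunk the document and ask the model for JSON-formatted QA pairs per chunk.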
+    def generate_qa_pairs(self, document_text: str, summary: str, chunk_size: int = 4000, num_pairs: int = 25) -> List[Dict[str, str]]:
+        print("Generating QA pairs...")
+        chunks = self._split_into_chunks(document_text, chunk_size)
+        if not chunks:
+            print("No text chunks to process")
+            return []
+        print(f"Document split into {len(chunks)} chunks")
+        all_qa_pairs = []
+        pairs_per_chunk = max(1, round(num_pairs / len(chunks)))
+
+        for i, chunk in enumerate(chunks):
+            print(f"Processing chunk {i+1}/{len(chunks)} (target: ~{pairs_per_chunk} pairs)")
+            prompt = f"""You are creating question-answer pairs for fine-tuning LLMs.
+
+Below is a chunk of text from a document about: {summary[:100]}...
+
+Your task is to create {pairs_per_chunk} high-quality question-answer pairs based ONLY on the information in this text chunk.
+
+For each pair:
+1. Create a specific, clear question about important information
+2. Provide a comprehensive but concise answer
+3. Focus on technical details and specific facts
+4. Ensure answers are directly supported by the text
+
+Return ONLY valid JSON formatted as:
+[
+  {{
+    "question": "Detailed question about the content?",
+    "answer": "Comprehensive answer based on the text."
+  }},
+  ...
+]
+
+Here is the text chunk:
+---
+{chunk}
+---"""
+            response = self.client.chat.completions.create(
+                messages=[
+                    {"role": "system", "content": prompt}
+                ],
+                model=self.model,
+                temperature=0.7,
+            )
+            try:
+                qa_text = response.choices[0].message.content
+                chunk_pairs = self._extract_qa_pairs(qa_text)
+                all_qa_pairs.extend(chunk_pairs)
+                print(f"  Generated {len(chunk_pairs)} pairs from chunk {i+1}")
+            except Exception as e:
+                print(f"  Error processing chunk {i+1}: {str(e)}")
+        
+        print(f"Generated {len(all_qa_pairs)} QA pairs")
+        return all_qa_pairs
+    
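+    # Rate pairs in batches and keep only those scoring at or above the threshold.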
+    def rate_qa_pairs(self, qa_pairs: List[Dict[str, str]], summary: str, threshold: float = 7.0) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        if not qa_pairs:
+            return [], {"total": 0, "filtered": 0, "retention_rate": 0, "avg_score": 0}
+
+        print(f"Evaluating {len(qa_pairs)} pairs")
+        batch_size = 8
+        batches = [qa_pairs[i:i+batch_size] for i in range(0, len(qa_pairs), batch_size)]
+        
+        rated_pairs = []
+        total_score = 0
+        rated_count = 0
+        for i, batch in enumerate(batches):
+            print(f"Rating batch {i+1}/{len(batches)}...")
+            batch_json = json.dumps(batch, indent=2)
+            prompt = f"""Rate these question-answer pairs for LLM fine-tuning on a scale of 1-10.
+
+Here is a summary of the document these pairs are based on:
+---
+{summary[:500]}...
+---
+
+For each pair below, rate its quality on a scale of 1-10 based on:
+1. Relevance to the document topic
+2. Question specificity and clarity
+3. Answer accuracy and completeness
+4. Overall educational value
+
+The pairs to evaluate:
+{batch_json}
+
+IMPORTANT: Return ONLY a JSON array containing each original pair with an added "rating" field.
+Use this exact format:
+[
+  {{"question": "...", "answer": "...", "rating": 8}},
+  ...
+]"""
+            
+            try:
+                response = self.client.chat.completions.create(
+                    messages=[
+                        {"role": "system", "content": prompt}
+                    ],
+                    model=self.model,
+                    temperature=0.0
+                )
+                
+                rated_batch = self._extract_rated_pairs(response.choices[0].message.content)
+                for pair in rated_batch:
+                    if "rating" in pair:
+                        total_score += pair["rating"]
+                        rated_count += 1
+                        if pair["rating"] >= threshold:
+                            rated_pairs.append(pair)
+            
+            except Exception as e:
+                print(f"Error rating batch {i+1}: {str(e)}")
+            time.sleep(0.5)
+        
+        metrics = {
+            "total": len(qa_pairs),
+            "filtered": len(rated_pairs),
+            "retention_rate": round(len(rated_pairs) / len(qa_pairs), 2) if qa_pairs else 0,
+            "avg_score": round(total_score / len(qa_pairs), 1) if qa_pairs else 0
+        }
+        
+        print(f"Keeping {len(rated_pairs)} out of {len(qa_pairs)} pairs")
+        return rated_pairs, metrics
+    
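+    # Greedily pack paragraphs into chunks of roughly chunk_size characters.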
+    def _split_into_chunks(self, text: str, chunk_size: int) -> List[str]:
+        paragraphs = text.split("\n\n")
+        chunks = []
+        current_chunk = ""
+        
+        for para in paragraphs:
+            if len(current_chunk) + len(para) > chunk_size and current_chunk:
+                chunks.append(current_chunk)
+                current_chunk = para
+            else:
+                if current_chunk:
+                    current_chunk += "\n\n" + para
+                else:
+                    current_chunk = para
+        if current_chunk:
+            chunks.append(current_chunk)
+        
+        return chunks
+    
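+    # Parse the model's JSON output; tolerate responses wrapped in a markdown code fence.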
+    def _extract_qa_pairs(self, text: str) -> List[Dict[str, str]]:
+        text = text.strip()
+        if text.startswith('[') and text.endswith(']'):
+            try:
+                return json.loads(text)
+            except json.JSONDecodeError:
+                pass
+        # Handle responses wrapped in a ```json code fence
+        if '```' in text:
+            match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
+            if match:
+                try:
+                    return json.loads(match.group(1))
+                except json.JSONDecodeError:
+                    pass
+        return []
+    
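+    # Recover rated pairs: try the whole JSON array, then a fenced block, then a per-pair regex.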
+    def _extract_rated_pairs(self, text: str) -> List[Dict[str, Any]]:
+        array_pattern = r'\[\s*\{\s*"question".*"rating".*\}\s*\]'
+        # re.DOTALL lets .* span the newlines of a pretty-printed JSON array
+        array_match = re.search(array_pattern, text, re.DOTALL)
+        if array_match:
+            try:
+                return json.loads(array_match.group(0))
+            except json.JSONDecodeError:
+                pass
+        
+        # Handle responses wrapped in a ```json code fence
+        if '```' in text:
+            code_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', text)
+            if code_match:
+                code_content = code_match.group(1).strip()
+                if code_content.startswith('[') and code_content.endswith(']'):
+                    try:
+                        return json.loads(code_content)
+                    except json.JSONDecodeError:
+                        pass
+
+        print("Using regex to extract rating pairs")
+        pair_pattern = r'{\s*"question":\s*"([^"]*)"\s*,\s*"answer":\s*"([^"]*)"\s*,\s*"rating":\s*(\d+)'
+        pairs = []
+        for match in re.finditer(pair_pattern, text):
+            try:
+                q = match.group(1).replace('\\"', '"')
+                a = match.group(2).replace('\\"', '"')
+                r = int(match.group(3))
+                pairs.append({"question": q, "answer": a, "rating": r})
+            except (ValueError, IndexError):
+                pass
+        return pairs
+    
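+    # End-to-end pipeline: summarize, generate pairs, filter by rating, convert to conversations.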
+    def process_document(self, document_text: str, num_pairs: int = 25, quality_threshold: float = 7.0) -> Dict[str, Any]:
+        summary = self.generate_summary(document_text)
+        qa_pairs = self.generate_qa_pairs(document_text, summary, num_pairs=num_pairs)
+        filtered_pairs, metrics = self.rate_qa_pairs(qa_pairs, summary, threshold=quality_threshold)
+        conversations = self._convert_to_conversations(filtered_pairs)
+        result = {
+            "summary": summary,
+            "qa_pairs": qa_pairs,
+            "filtered_pairs": filtered_pairs,
+            "conversations": conversations,
+            "metrics": metrics
+        }
+        
+        return result
+    
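+    # Wrap each QA pair in a system/user/assistant message list ready for fine-tuning.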
+    def _convert_to_conversations(self, qa_pairs: List[Dict[str, Any]]) -> List[List[Dict[str, str]]]:
+        conversations = []
+        
+        for pair in qa_pairs:
+            conversation = [
+                {
+                    "role": "system",
+                    "content": "You are a helpful AI assistant that provides accurate, detailed responses."
+                },
+                {
+                    "role": "user",
+                    "content": pair["question"]
+                },
+                {
+                    "role": "assistant",
+                    "content": pair["answer"]
+                }
+            ]
+            conversations.append(conversation)
+        
+        return conversations
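
A minimal usage sketch for the QAGenerator added above (not part of the commit; the import path, the example file name, and reliance on the CEREBRAS_API_KEY environment variable are assumptions based on the files in this diff):

from parsers.utils import QAGenerator  # assumes data-tool/src is on the Python path

document_text = open("parsed_document.txt", encoding="utf-8").read()  # any plain-text document

generator = QAGenerator()  # reads CEREBRAS_API_KEY from the environment
result = generator.process_document(document_text, num_pairs=25, quality_threshold=7.0)

print(result["metrics"])  # total, filtered, retention_rate, avg_score
conversations = result["conversations"]  # system/user/assistant message lists for fine-tuning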