Browse Source

creation of raft dataset working

Kai Wu 1 year ago
parent
commit
c856052115
2 changed files with 61 additions and 59 deletions
  1. +19 -28 recipes/use_cases/end2end-recipes/raft/raft.py
  2. +42 -31 recipes/use_cases/end2end-recipes/raft/raft_utils.py

+19 -28
recipes/use_cases/end2end-recipes/raft/raft.py

@@ -3,8 +3,6 @@ from mdc import MDC
 import logging
 from typing import Literal, Any
 from openai import OpenAI
-import datasets
-from datasets import Dataset, load_dataset
 import json
 import random
 import os, shutil
@@ -15,44 +13,37 @@ from chat_utils import OctoAIChatService, VllmChatService
 from format import DatasetConverter, datasetFormats, outputDatasetTypes
 from config import load_config
 
-# def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
-#     """
-#     Generates the label / answer to `question` using `context` and GPT-4.
-#     """
-#     question = encode_question(question, context) if doctype == "api" else encode_question_gen(question, context)
-#     response = client.chat.completions.create(
-#         model=model,
-#         messages=question,
-#         n=1,
-#         temperature=0
-#     )
-#     response = response.choices[0].message.content
-#     return response
-# Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-
+NUM_DISTRACT_DOCS = 5 # number of distractor documents added to each datapoint
+ORACLE_P = 0.8 # probability that the oracle (golden) chunk is kept in the datapoint's context
+
 async def main(context):
+    ds = None
     if context["endpoint"]:
         chat_service = VllmChatService()
     else:
         chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question pair.")
-        # Generate question/answer pairs as list
-        chunks = await generate_questions(chat_service, context)
-        if not chunks:
+        # Generate questions as list for each chunk
+        chunk_questions_zip = await generate_questions(chat_service, context)
+        if not chunk_questions_zip:
             logging.warning("No questions generated from text. Please check the input context or model configuration.")
             return
-        logging.info(f"Successfully generated {sum([len(q) for q in chunks])} question/answer pairs.")
-        print(chunks)
-        for i, chunk in enumerate(chunks):
-            perc = ceil(i / num_chunks * 100)
-            with MDC(progress=f"{perc}%"):
-                logger.info(f"Adding chunk {i}/{num_chunks}")
-                add_chunk_to_dataset(client, chunks, chunk, args.doctype, args.questions, NUM_DISTRACT_DOCS, model=args.completion_model)
-
+        for chunk, questions in chunk_questions_zip:
+            logging.info(f"Chunk: {chunk}, number of questions: {len(questions)}")
+            for question in questions:
+                logging.info(f"Question: {question}")
+        logging.info(f"Successfully generated {sum([len(q) for _, q in chunk_questions_zip])} questions.")
+        ds = await add_chunk_to_dataset(chunk_questions_zip, context, chat_service, ds, NUM_DISTRACT_DOCS, ORACLE_P)
+        logging.info(f"First dataset example: {ds[0]}")
+        ds.save_to_disk(context["output"])
         logging.info(f"Data successfully written to {context['output']}. Process completed.")
+        formatter = DatasetConverter()
+
+        # Extract format-specific params
+        format_params = {}
+        formatter.convert(ds=ds, format=context["output_format"], output_path=context["output"], output_type=context["output_type"], params=format_params)
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
 

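Reviewer note: the two new constants drive how add_chunk_to_dataset assembles each {Q, A, D} datapoint: NUM_DISTRACT_DOCS distractor chunks per datapoint, with the oracle (golden) chunk kept with probability ORACLE_P. The sampling itself sits outside this diff's hunks, so the sketch below is a minimal reconstruction of the upstream RAFT scheme it appears to follow; build_context and oracle_idx are hypothetical names, not part of this commit.

```python
import random

NUM_DISTRACT_DOCS = 5   # distractor chunks per datapoint
ORACLE_P = 0.8          # probability the oracle chunk is kept

def build_context(chunks: list[str], oracle_idx: int,
                  num_distract: int = NUM_DISTRACT_DOCS,
                  p: float = ORACLE_P) -> list[str]:
    """Sample the document list for one {Q, A, D} datapoint."""
    docs = [chunks[oracle_idx]]  # start from the oracle (golden) chunk
    candidates = [c for i, c in enumerate(chunks) if i != oracle_idx]
    docs += random.sample(candidates, min(num_distract, len(candidates)))
    # with probability 1 - p, swap the oracle out for one more distractor
    if random.uniform(0, 1) >= p:
        docs[0] = random.choice(candidates)
    random.shuffle(docs)  # so the oracle position carries no signal
    return docs
```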
+42 -31
recipes/use_cases/end2end-recipes/raft/raft_utils.py

@@ -12,9 +12,11 @@ import json
 from doc_processor import split_text_into_chunks
 import logging
 import json
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_experimental.text_splitter import SemanticChunker
 from math import ceil
+import datasets
+from datasets import Dataset, load_dataset
 import random
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
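
Reviewer note: the HuggingFaceEmbeddings import moves from langchain.embeddings to langchain_community.embeddings, its home in current LangChain releases. get_chunks (unchanged below) feeds this embedding model to SemanticChunker; a minimal sketch of that pairing, assuming default breakpoint settings and a hypothetical input file:

```python
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_experimental.text_splitter import SemanticChunker

# the same embedding model generate_questions instantiates
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2")

# split on semantic breakpoints; returns plain-string chunks
splitter = SemanticChunker(embedding_model)
with open("data/example_doc.txt") as f:  # hypothetical path
    chunks = splitter.split_text(f.read())
```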
@@ -131,11 +133,11 @@ def get_chunks(
 
     return chunks
 # read all the files in the data folder, then split them into chunks
-# generate questions for each chunk and return a list of questions list
+# generate questions for each chunk and return a list of (chunk, questions) pairs
 async def generate_questions(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
-    if len(document_text)== 0:
-        logging.error(f"Error reading files, document_text is empty")
+    if len(document_text) == 0:
+        logging.error("Error reading files, document_text is empty")
     model_name = "sentence-transformers/all-mpnet-base-v2"
     embedding_model = HuggingFaceEmbeddings(model_name=model_name)
     document_batches = get_chunks(document_text,api_context["chunk_size"],embedding_model)
@@ -148,12 +150,15 @@ async def generate_questions(chat_service, api_context: dict):
+    kept_batches = []  # batches long enough to generate questions, kept aligned with final_result
     for batch_index, batch_content in enumerate(document_batches):
         print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
         #Distribute extra questions across the first few batches
-        print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
-        try:
-            task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
-            generation_tasks.append(task)
-        except Exception as e:
-            print(f"Error during chat request execution: {e}")
+        if len(batch_content) < 10:
+            logging.info("Context is too short, skipping this batch")
+        else:
+            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
+            try:
+                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
+                generation_tasks.append(task)
+                kept_batches.append(batch_content)
+            except Exception as e:
+                print(f"Error during chat request execution: {e}")
 
     question_generation_results = await asyncio.gather(*generation_tasks)
     final_result = []
@@ -166,35 +171,44 @@ async def generate_questions(chat_service, api_context: dict):
            # if there are more queries than questions_per_chunk, truncate and keep only the last questions_per_chunk lines
             queries = queries[-int(api_context['questions_per_chunk']):]
         final_result.append(queries)
-    return final_result
+    # pair each kept chunk with its questions; batches skipped above are excluded
+    return list(zip(kept_batches, final_result))
 
-def add_chunk_to_dataset(
-    chunks: list[str],
-    chunk: str,
-    x: int = 5,
+async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> tuple:
+    # ask the model for a chain-of-thought answer to `question` grounded in `document_content`
+    prompt = api_context['COT_prompt_template'].format(question=question, context=str(document_content))
+    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
+    chat_request_payload.append({"role": "user", "content": prompt})
+    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
+    return (document_content, question, response)
+
+async def add_chunk_to_dataset(
+    chunk_questions_zip: list,
+    context: dict,
+    chat_service,
+    ds,
     num_distract: int = 3,
     p: float = 0.8,
-    model: str = None
-) -> None:
+) -> Dataset:
     """
-    Given a chunk, create {Q, A, D} triplets and add them to the dataset.
+    Given (chunk, questions) pairs, create {Q, A, D} triplets and add them to the dataset.
     """
-    global ds
-    i = chunks.index(chunk)
-    qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
-    for q in qs:
+    COT_tasks = []
+    chunks = [chunk for chunk, _ in chunk_questions_zip]
+    for chunk, questions in chunk_questions_zip:
+        # generate a COT answer for each question given the chunk context
+        for question in questions:
+            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
+    COT_results = await asyncio.gather(*COT_tasks)
+    for chunk, q, cot in COT_results:
         datapt = {
             "id": None,
-            "type": None,
-            "question": None,
+            "type": "general",
+            "question": q,
             "context": None,
             "oracle_context": None,
-            "cot_answer": None
+            "cot_answer": cot
         }
-
+        i = chunks.index(chunk)
         datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
-        datapt["type"] = "api call" if doctype == "api" else "general"
-        datapt["question"] = q
 
         # add num_distract distractor docs
         docs = [chunk]
@@ -218,9 +232,6 @@ def add_chunk_to_dataset(
         datapt["context"] = d
         datapt["oracle_context"] = chunk
 
-        # add answer to q
-        datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)
-
         # construct model instruction
         context = ""
         for doc in docs:
@@ -241,7 +252,7 @@ def add_chunk_to_dataset(
             ds = Dataset.from_dict(datapt)
         else:
             ds = ds.add_item(datapt)
-
+    return ds
+
 # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
 async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
     prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
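
Reviewer note: generate_COT is the async replacement for the generate_label helper this commit deletes from raft.py. For reference, a synchronous stand-alone equivalent using the same openai call shape as the deleted code; COT_PROMPT here is a hypothetical stand-in for the COT_prompt_template that actually comes from the YAML config:

```python
from openai import OpenAI

# hypothetical stand-in for the configured COT_prompt_template
COT_PROMPT = ("Question: {question}\nContext: {context}\n"
              "Think step by step, then answer using only the context.")

def generate_cot_answer(client: OpenAI, question: str,
                        context: str, model: str) -> str:
    messages = [
        {"role": "system", "content": "You are a helpful question answerer "
         "who can provide an answer given a question and relevant context."},
        {"role": "user", "content": COT_PROMPT.format(question=question,
                                                      context=context)},
    ]
    # single deterministic completion, as in the removed generate_label
    response = client.chat.completions.create(
        model=model, messages=messages, n=1, temperature=0)
    return response.choices[0].message.content
```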