@@ -12,9 +12,12 @@ import json
 from doc_processor import split_text_into_chunks
 import logging
 import json
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_experimental.text_splitter import SemanticChunker
 from math import ceil
+import datasets
+from datasets import Dataset, load_dataset
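+# Dataset is used below to accumulate the generated {Q, A, D} examples row by row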
 import random
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -131,11 +134,11 @@ def get_chunks(
 
     return chunks
 # read all the files in the data folder, then split them into chunks
-# generate questions for each chunk and return a list of questions list
+# generate questions for each chunk and return a zip of chunks and their related question lists
 async def generate_questions(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
-    if len(document_text)== 0:
-        logging.error(f"Error reading files, document_text is empty")
+    if len(document_text) == 0:
+        logging.info(f"Error reading files, document_text is {len(document_text)}")
     model_name = "sentence-transformers/all-mpnet-base-v2"
     embedding_model = HuggingFaceEmbeddings(model_name=model_name)
     document_batches = get_chunks(document_text,api_context["chunk_size"],embedding_model)
@@ -148,12 +151,16 @@ async def generate_questions(chat_service, api_context: dict):
     for batch_index, batch_content in enumerate(document_batches):
         print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
         #Distribute extra questions across the first few batches
-        print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
-        try:
-            task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
-            generation_tasks.append(task)
-        except Exception as e:
-            print(f"Error during chat request execution: {e}")
+        if len(batch_content) < 10:
+            logging.info("Context is not enough, ignoring this batch")
+        else:
+            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
+            try:
+                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
+                generation_tasks.append(task)
+            except Exception as e:
+                print(f"Error during chat request execution: {e}")
 
     question_generation_results = await asyncio.gather(*generation_tasks)
     final_result = []
@@ -166,35 +173,47 @@ async def generate_questions(chat_service, api_context: dict):
         # if queries is more than questions_per_chunk, then we need to truncate it and only keep last questions_per_chunk lines
         queries = queries[-int(api_context['questions_per_chunk']):]
         final_result.append(queries)
-    return final_result
+    return list(zip(document_batches, final_result))
 
-def add_chunk_to_dataset(
-    chunks: list[str],
-    chunk: str,
-    x: int = 5,
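+# request a chain-of-thought (COT) answer for a single question against its source chunk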
+async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> tuple:
+    prompt = api_context['COT_prompt_template'].format(question=question, context=str(document_content))
+    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
+    chat_request_payload.append({"role": "user", "content": prompt})
+    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
+    return (document_content, question, response)
+async def add_chunk_to_dataset(
+    chunk_questions_zip: list,
+    context: dict,
+    chat_service,
+    ds,
     num_distract: int = 3,
     p: float = 0.8,
-    model: str = None
-) -> None:
+):
     """
-    Given a chunk, create {Q, A, D} triplets and add them to the dataset.
+    Given chunks and their related question lists, create {Q, A, D} triplets and add them to the dataset.
     """
-    global ds
-    i = chunks.index(chunk)
-    qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
-    for q in qs:
+    COT_tasks = []
+    chunks = [chunk for chunk, _ in chunk_questions_zip]
+    for i, chunk_questions in enumerate(chunk_questions_zip):
+        chunk, questions = chunk_questions
+        # generate a COT answer for each question given the chunk context
+        for question in questions:
+            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
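+    # run all COT requests concurrently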
+    COT_results = await asyncio.gather(*COT_tasks)
+    for chunk, q, cot in COT_results:
         datapt = {
             "id": None,
-            "type": None,
-            "question": None,
+            "type": "general",
+            "question": q,
             "context": None,
             "oracle_context": None,
-            "cot_answer": None
+            "cot_answer": cot
         }
-
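+        # index of the oracle chunk among all chunks, used when sampling distractor docs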
+        i = chunks.index(chunk)
         datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
-        datapt["type"] = "api call" if doctype == "api" else "general"
-        datapt["question"] = q
 
         # add num_distract distractor docs
         docs = [chunk]
@@ -218,9 +237,6 @@ def add_chunk_to_dataset(
         datapt["context"] = d
         datapt["oracle_context"] = chunk
 
-        # add answer to q
-        datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)
-
         # construct model instruction
         context = ""
         for doc in docs:
@@ -241,7 +257,8 @@ def add_chunk_to_dataset(
             ds = Dataset.from_dict(datapt)
         else:
             ds = ds.add_item(datapt)
-
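+    # return the updated dataset instead of mutating a global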
+    return ds
 # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
 async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
     prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])