# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import string
import asyncio
import json
import logging
import random
from math import ceil

import magic
from PyPDF2 import PdfReader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_experimental.text_splitter import SemanticChunker
from datasets import Dataset

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def strip_str(s: str) -> str:
    """
    Helper for cleaning up strings returned by GPT-4: trims everything before the
    first alphabetic character and after the last one (keeping trailing punctuation).
    """
    l, r = 0, len(s) - 1
    beg_found = False
    for i in range(len(s)):
        if s[i].isalpha():
            if not beg_found:
                l = i
                beg_found = True
            else:
                r = i
    # move one past the last alphabetic character so trailing punctuation such as '?' is kept
    r += 2
    return s[l:min(r, len(s))]
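# Example (illustrative): strip_str("1. What is RAFT?\n") returns "What is RAFT?",
# dropping the leading numbering and trailing newline while keeping the final "?".
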
def read_text_file(file_path):
    try:
        with open(file_path, 'r') as f:
            text = f.read().strip()
            if len(text) == 0:
                logging.warning(f"File is empty: {file_path}")
            return text + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''

def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            num_pages = len(pdf_reader.pages)
            # extract_text() can return None for pages without a text layer, so fall back to ''
            file_text = [(pdf_reader.pages[page_num].extract_text() or '').strip() + ' ' for page_num in range(num_pages)]
            text = ''.join(file_text)
            if len(text.strip()) == 0:
                logging.warning(f"File is empty: {file_path}")
            return text
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''

def read_json_file(file_path):
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)
        # Assuming each item in the list has a 'question' and an 'answer' key,
        # concatenate the question-answer pairs into a single space-separated string.
        file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
        if len(file_text) == 0:
            logging.warning(f"File is empty: {file_path}")
        return file_text
    except Exception as e:
        logging.error(f"Error reading JSON file {file_path}: {e}")
        return ''
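# Input shape assumed by read_json_file (field names as expected above, values illustrative):
# [
#     {"question": "What is RAFT?", "answer": "Retrieval Augmented Fine Tuning, a recipe for ..."},
#     {"question": "...", "answer": "..."}
# ]
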
def process_file(file_path):
    print("starting to process file: ", file_path)
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/json':
        # JSON files are typically reported by libmagic as 'application/json'; route them to the JSON reader
        return read_json_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''

def read_file_content(context):
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    text = '\n'.join(file_strings)
    return remove_non_printable(text)

def remove_non_printable(s):
    printable = set(string.printable)
    return ''.join(filter(lambda x: x in printable, s))

async def generate_question_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
    if num_questions == 0:
        logging.error("num_questions is 0, skipping question generation")
        return {}
    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions)
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': str(document_content)}]
    # The raw result string is parsed into a list of questions by the caller (generate_questions)
    return await chat_service.execute_chat_request_async(api_context, chat_request_payload)

def get_chunks(
    text: str,
    chunk_size: int = 512,
    embedding_model=None
) -> list[str]:
    """
    Splits `text` into roughly `chunk_size`-character chunks using a SemanticChunker
    built on `embedding_model`, and returns the chunks as a list of strings.
    """
    if len(text) == 0:
        raise ValueError("Cannot get chunks from empty text")
    num_chunks = ceil(len(text) / chunk_size)
    logging.info(f"Splitting text into {num_chunks} chunks")
    text_splitter = SemanticChunker(embedding_model, number_of_chunks=num_chunks)
    chunks = text_splitter.create_documents([text])
    chunks = [chunk.page_content for chunk in chunks]
    return chunks
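# Minimal usage sketch (illustrative; assumes the sentence-transformers model can be downloaded):
#     embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
#     chunks = get_chunks(document_text, chunk_size=512, embedding_model=embedding_model)
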
# Read all the files in the data folder, split them into chunks, generate questions
# for each chunk, and return a list pairing each chunk with its list of questions.
async def generate_questions(chat_service, api_context: dict):
    document_text = read_file_content(api_context)
    if len(document_text) == 0:
        logging.error("Error reading files, document_text is empty")
        return []
    model_name = "sentence-transformers/all-mpnet-base-v2"
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
    document_batches = get_chunks(document_text, api_context["chunk_size"], embedding_model)

    batches_count = len(document_batches)
    total_questions = api_context["questions_per_chunk"] * batches_count
    print(f"Questions per batch: {api_context['questions_per_chunk']}, Total questions: {total_questions}, Batches: {batches_count}")
    generation_tasks = []
    kept_batches = []
    for batch_index, batch_content in enumerate(document_batches):
        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
        # Skip chunks that are too short to yield meaningful questions
        if len(batch_content) < 10:
            logging.info("Context is not enough, ignoring this batch")
        else:
            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
            try:
                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
                generation_tasks.append(task)
                # Track which batches actually produced a request so chunks and questions stay aligned
                kept_batches.append(batch_content)
            except Exception as e:
                print(f"Error during chat request execution: {e}")

    question_generation_results = await asyncio.gather(*generation_tasks)
    final_result = []
    for result in question_generation_results:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
        if len(queries) > int(api_context['questions_per_chunk']):
            # The model may emit unrelated text at the beginning of the result; if there are more
            # lines than questions_per_chunk, keep only the last questions_per_chunk of them.
            queries = queries[-int(api_context['questions_per_chunk']):]
        final_result.append(queries)
    return list(zip(kept_batches, final_result))
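# The return value is a list of (chunk_text, questions) pairs, for example (illustrative):
#     [("RAFT adapts a model to domain-specific RAG ...", ["What is RAFT?", "How are distractors used?"]), ...]
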
async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> tuple:
    prompt = api_context['COT_prompt_template'].format(question=question, context=str(document_content))
    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
    chat_request_payload.append({"role": "user", "content": prompt})
    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    return (document_content, question, response)
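# The returned (chunk, question, chain-of-thought answer) tuples are consumed by add_chunk_to_dataset below.
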
async def add_chunk_to_dataset(
    chunk_questions_zip: list,
    context: dict,
    chat_service,
    ds,
    num_distract: int = 3,
    p: float = 0.8,
):
    """
    Given a list of (chunk, questions) pairs, create {Question, Answer, Distractor documents}
    triplets and add them to the dataset. Returns the updated dataset.
    """
    COT_tasks = []
    chunks = [chunk for chunk, _ in chunk_questions_zip]
    for i, chunk_questions in enumerate(chunk_questions_zip):
        chunk, questions = chunk_questions
        # generate a COT answer for each question, using the chunk as context
        for question in questions:
            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
    COT_results = await asyncio.gather(*COT_tasks)
    for chunk, q, cot in COT_results:
        datapt = {
            "id": None,
            "type": "general",
            "question": q,
            "context": None,
            "oracle_context": None,
            "cot_answer": cot
        }
        i = chunks.index(chunk)
        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
        # add up to num_distract distractor docs drawn from the other chunks
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, min(num_distract, len(indices))):
            docs.append(chunks[j])
        # with probability (1 - p), replace the oracle document with another distractor
        oracle = random.uniform(0, 1) < p
        if not oracle and indices:
            docs[0] = chunks[random.sample(indices, 1)[0]]
        random.shuffle(docs)
        d = {
            "title": [],
            "sentences": []
        }
        d["title"].append(["placeholder_title"] * len(docs))
        d["sentences"].append(docs)
        datapt["context"] = d
        datapt["oracle_context"] = chunk
        # construct the model instruction: all documents wrapped in <DOCUMENT> tags, followed by the question
        instruction = ""
        for doc in docs:
            instruction += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        instruction += q
        datapt["instruction"] = instruction
        # add to dataset
        if not ds:
            # initialize the dataset from the first datapoint; each column must be a list
            datapt["id"] = [datapt["id"]]
            datapt["type"] = [datapt["type"]]
            datapt["question"] = [datapt["question"]]
            datapt["context"] = [datapt["context"]]
            datapt["oracle_context"] = [datapt["oracle_context"]]
            datapt["cot_answer"] = [datapt["cot_answer"]]
            datapt["instruction"] = [datapt["instruction"]]
            ds = Dataset.from_dict(datapt)
        else:
            ds = ds.add_item(datapt)
    return ds
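# Shape of one resulting dataset row (values illustrative):
#     {
#         "id": "seed_task_0",
#         "type": "general",
#         "question": "What is RAFT?",
#         "context": {"title": [["placeholder_title", ...]], "sentences": [["<doc 1>", "<doc 2>", ...]]},
#         "oracle_context": "<the chunk the question was generated from>",
#         "cot_answer": "<chain-of-thought answer from generate_COT>",
#         "instruction": "<DOCUMENT>...</DOCUMENT>\n<DOCUMENT>...</DOCUMENT>\nWhat is RAFT?"
#     }
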
# Evaluate the quality of a generated QA pair with an LLM judge. Returns the parsed judge
# response (expected to contain a "Result" field), or an empty dict if the response is empty
# or does not contain a "Result" field.
async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
    prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Teacher's Answer: {document_content['Ground_truth']}\n Student's Answer: {document_content['Generated_answer']} "}]
    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    if not result:
        return {}
    # No parsing needed beyond loading the JSON response into a dict
    result = json.loads(result)
    if "Result" not in result:
        print("Error: eval response does not contain answer")
        print(document_content, result)
        return {}
    return result
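# Expected shapes (illustrative): each judge_list item looks like
#     {"Question": "...", "Ground_truth": "...", "Generated_answer": "..."}
# and the judge is expected to reply with JSON containing at least a "Result" field, e.g. {"Result": "YES"}.
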
async def generate_LLM_eval(chat_service, api_context: dict, judge_list: list):
    eval_tasks = []
    for batch_content in judge_list:
        try:
            task = LLM_judge_request(chat_service, api_context, batch_content)
            eval_tasks.append(task)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")

    judge_results = await asyncio.gather(*eval_tasks)
    return judge_results
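# End-to-end sketch of how these helpers are typically wired together by a driver script
# (illustrative only; the chat_service implementation and the api_context contents live outside this module):
#
#     chunk_questions_zip = await generate_questions(chat_service, api_context)
#     ds = await add_chunk_to_dataset(chunk_questions_zip, api_context, chat_service, ds=None)
#     ds.save_to_disk(api_context["output_dir"])   # "output_dir" is a hypothetical config key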