@@ -14,16 +14,7 @@ import json
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-# Since OctoAI has different naming for llama models, get the huggingface offical model name using OctoAI names.
-def get_model_name(model):
-    if model == "meta-llama-3-70b-instruct":
-        return "meta-llama/Meta-Llama-3-70B-Instruct"
-    elif model == "meta-llama-3-8b-instruct":
-        return "meta-llama/Meta-Llama-3-8B-Instruct"
-    elif model == "llama-2-7b-chat":
-        return "meta-llama/Llama-2-7b-chat-hf"
-    else:
-        return "meta-llama/Llama-2-70b-chat-hf"
+

def read_text_file(file_path):
    try:
        with open(file_path, 'r') as f:
@@ -88,8 +79,13 @@ def read_file_content(context):
    if len(text) == 0:
        logging.error(f"Error reading files, text is empty")
    return ' '.join(file_strings)
-
-
+# clean the text by removing all parts that do not contain any alphanumeric characters
+def clean(s):
+    result = []
+    for item in s.split('"'):
+        if any(c.isalnum() for c in item):
+            result.append(item)
+    return " ".join(result)

def parse_qa_to_json(response_string):
    split_lines = response_string.split("\n")
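
A quick standalone sanity check of the new clean() helper (the sample strings below are invented for illustration, not taken from this PR):

    # clean() keeps only the quote-delimited fragments that contain at least
    # one alphanumeric character, joining them back with single spaces
    def clean(s):
        result = []
        for item in s.split('"'):
            if any(c.isalnum() for c in item):
                result.append(item)
        return " ".join(result)

    assert clean(' "What is Llama 3?", ') == 'What is Llama 3?'
    assert clean('{}[], --') == ''   # pure punctuation disappears entirely
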
@@ -109,21 +105,21 @@ def parse_qa_to_json(response_string):
            end = i
        # found Question means we have reached the end of the question, so add it to qa_list
        elif '"Question":' in line:
-            question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
-            answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
+            question = " ".join(split_lines[start:end]).split('"Question":')[1]
+            answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
            start,end = i,None
-            qa_set.add((question, answer))
+            qa_set.add((clean(question), clean(answer)))
    # adding last question back to qa_list
-    if start and end:
-        question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
-        answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
-        qa_set.add((question, answer))
+    if start and end:
+        question = " ".join(split_lines[start:end]).split('"Question":')[1]
+        answer = " ".join(split_lines[end:]).split('"Answer":')[1]
+        qa_set.add((clean(question), clean(answer)))
    qa_list = [{"question": q, "answer":a} for q,a in qa_set]
    return json.dumps(qa_list, indent=4)


-async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
-    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
+async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
+    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    if not result:
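
A toy trace (invented model output) of the new parsing path: the plain split on '"Question":' keeps stray quotes and commas around the text, which is exactly the residue the clean() helper added above strips off:

    def clean(s):   # repeated from the earlier hunk so this sketch runs standalone
        result = []
        for item in s.split('"'):
            if any(c.isalnum() for c in item):
                result.append(item)
        return " ".join(result)

    # roughly what " ".join(split_lines[start:end]).split('"Question":')[1] yields
    raw_question = ' "What is Llama 3?",'
    print(clean(raw_question))   # -> What is Llama 3?
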
@@ -142,7 +138,8 @@ async def generate_question_batches(chat_service, api_context: dict):

    total_questions = api_context["total_questions"]
    batches_count = len(document_batches)
-    base_questions_per_batch = total_questions // batches_count
+    # each batch should have at least 1 question
+    base_questions_per_batch = max(total_questions // batches_count, 1)
    extra_questions = total_questions % batches_count

    print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
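
Worked example of the new batch sizing (numbers invented): 10 total questions over 4 batches gives a base of 2 per batch, with the first 2 batches taking one extra, i.e. [3, 3, 2, 2]. Note the max(..., 1) floor means more than total_questions can be generated when there are more batches than questions; that is the trade-off for never sending a zero-question request:

    total_questions, batches_count = 10, 4
    base = max(total_questions // batches_count, 1)   # 2
    extra = total_questions % batches_count           # 2
    sizes = [base + (1 if i < extra else 0) for i in range(batches_count)]
    assert sizes == [3, 3, 2, 2] and sum(sizes) == total_questions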