
llama working example checkpoint

Kai Wu 1 year ago
parent
commit
8baf6b5d99

+ 28 - 8
recipes/finetuning/datasets/raft_dataset.py

@@ -8,20 +8,39 @@ from datasets import Dataset, load_dataset, DatasetDict
 import itertools
 
 B_INST, E_INST = "[INST]", "[/INST]"
+# Check whether the system prompt or user prompt header token sequence appears in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
 def tokenize_dialog(dialog, tokenizer):
     # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
     if tokenizer.vocab_size >= 128000:
         dialog_tokens = tokenizer.apply_chat_template(dialog)
-        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
         eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
         labels = copy.copy(dialog_tokens)
         last_idx = 0
+        token_length = len(dialog_tokens)
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
         for n, idx in enumerate(eot_indices):
-            if n % 2 == 1:
-                last_idx = idx
-            else:
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-
+            else:
+                last_idx = idx
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
         dialog_tokens = [dialog_tokens]
         labels_tokens = [labels]
     else:
@@ -51,15 +70,16 @@ def raft_tokenize(q_a_pair, tokenizer):
     documents = q_a_pair["instruction"].split('\n')[:-1]
     # output is the label
     answer = q_a_pair["output"]
-    system_prompt = "You are a helpful question answerer who can provide an answer given a question and relevant context."
-    user_prompt = prompt = """
+    system_prompt = "You are a helpful chatbot who can provide an answer to every question from the user given a relevant context."
+    user_prompt = """
         Question: {question}\nContext: {context}\n
-        Answer this question using the information given in the context above. Here is things to pay attention to:
+        Answer this question using the information given in the multiple documents in the context above. Here are things to pay attention to:
         - First provide step-by-step reasoning on how to answer the question.
         - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
         - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
         You MUST begin your final answer with the tag "<ANSWER>:".
     """.format(question=question, context=str(documents))
+
     chat = [
     {"role": "system", "content": system_prompt},
     {"role": "user", "content": user_prompt},

File diff suppressed because it is too large
+ 3 - 3
recipes/use_cases/end2end-recipes/raft/README.md


File diff suppressed because it is too large
+ 46 - 0
recipes/use_cases/end2end-recipes/raft/data/website_data


+ 37 - 12
recipes/use_cases/end2end-recipes/raft/eval_config.yaml

@@ -1,22 +1,47 @@
 eval_prompt_template: >
-  You are a AI assistant that skilled in answering questions related to Llama language models,
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an AI assistant skilled in answering questions related to Llama language models,
   which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2,
-  Below is a question from a llama user, think step by step, make the answer as concise as possible,
-  The returned answer should be no more than 100 words.Please return the answers in text directly without any special tokens.
-
+  Below is a question from a llama user, please answer it to the best of your knowledge,
+  The returned answer should be no more than 100 words. Please return the answer in text directly without any special tokens.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+# judge_prompt_template: >
+#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>You have been provided with a question, a teacher's answer and a student's answer above. Given that question, you need to score how good the student's answer is compared to
+#   the teacher's answer. If the student's answer is correct based on the teacher's answer, then return YES, else return NO.
+#   Review it carefully to make sure that the keywords and numerical values are exactly the same.
+#   Only respond with "YES" or "NO", do not respond with anything else.<|eot_id|>
+#   <|start_header_id|>user<|end_header_id|>
+#   Question: {question} \n Teacher's Answer: {gold} \n Student's Answer: {prediction} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 judge_prompt_template: >
-  You have been provided with a question, a teacher's answer and a student's answer above. Given that question, you need to score the how good the student answer is compare to
-  the teacher's answer. If the student's answer is correct based on the teacher's answer, then return YES, else return NO.
-  Review it carefully to make sure that the keywords and numerical vaules are exactly the same.
-  Only respond with "YES" or "NO", do not respond with anything else.
+    <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a teacher grading a quiz.
+
+    You will be given a QUESTION, the GROUND TRUTH (correct) ANSWER, and the STUDENT ANSWER.
+
+    Here are the grading criteria to follow:
+    (1) Grade the student answers based ONLY on their factual accuracy relative to the ground truth answer.
+    (2) Ensure that the student answer does not contain any conflicting statements.
+    (3) It is OK if the student answer contains more information than the ground truth answer, as long as it is factually accurate relative to the  ground truth answer.
+
+    Score:
+    YES means that the student's answer meets all of the criteria. This is the highest (best) score.
+    NO means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.
+
+    Explain your reasoning in a step-by-step manner to ensure your reasoning and conclusion are correct.
 
+    Avoid simply stating the correct answer at the outset.
+    End your response with final answer in the form <ANSWER>: $answer, answer must be YES or NO  <|eot_id|>
+    <|start_header_id|>user<|end_header_id|>
+    QUESTION: {{question}}
+    GROUND TRUTH ANSWER: {{gold}}
+    STUDENT ANSWER: {{prediction}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 RAG_prompt_template: >
-  Question: {question}\n Context: {context}\n
-  Answer this question using the information given in the context above. Here is things to pay attention to:
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> Answer the following question using the information given in the context below. Here are things to pay attention to:
     - First provide step-by-step reasoning on how to answer the question.
     - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
-    - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
-  You MUST begin your final answer with the tag "<ANSWER>:
+    - End your response with final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+    You MUST begin your final answer with the tag "<ANSWER>:<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question: {question}\nContext: {context}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 eval_json: "./evalset.json"
 
 raft_model_name: "raft-8b"
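For reference (not part of the commit), the eval_prompt_template above is rendered per question with plain str.format in eval_raft.py before being sent to the model; a minimal sketch with an abbreviated template and a hypothetical question:

# Sketch only: the template string is truncated and the question is hypothetical.
eval_prompt_template = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an AI assistant ... <|eot_id|>\n"
    "<|start_header_id|>user<|end_header_id|>\n"
    "Question:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)
question = "What is the context length of Meta Llama 3?"  # hypothetical eval question
prompt = eval_prompt_template.format(question=question)   # only {question} is substituted
# The resulting string is what generate_answers_model_only() passes to llm.batch([...]).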

+ 96 - 51
recipes/use_cases/end2end-recipes/raft/eval_raft.py

@@ -1,13 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
-from chat_utils import OctoAIChatService, VllmChatService
 import logging
 import evaluate
 import argparse
 from config import load_config
 import json
 from itertools import chain
-from langchain_community.llms import VLLMOpenAI
+from langchain_openai import ChatOpenAI
 
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
@@ -26,17 +25,17 @@ def generate_answers_model_only(model_name,question_list,api_url="http://localho
         # Use langchain to load the documents from data directory
     # Load the RAFT model
 
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
         openai_api_key=key,
         openai_api_base=api_url,
         model_name=model_name,
         temperature=0.0,
-        max_tokens=100
+        max_tokens=1000
         )
-    system_prompt = SystemMessage(content=context['eval_prompt_template'])
-    generated_answers = []
-    all_tasks = [[system_prompt, HumanMessage(content=question)] for question in question_list]
+
+    all_tasks = [api_config['eval_prompt_template'].format(question=question) for question in question_list]
     generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
     if len(generated_answers) == 0:
         logging.error("No model answers generated. Please check the input context or model configuration in ",model_name)
         return []
@@ -48,7 +47,12 @@ def format_docs_raft(docs):
     return context
 def format_docs(docs):
     return "\n\n".join(doc.page_content for doc in docs)
-def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,api_url="http://localhost:8000/v1",key="EMPTY"):
+def generate_answers_with_RAG(model_name, question_list,api_config,api_url_overwrite=None):
+    data_dir = api_config['data_dir']
+    api_url = "http://localhost:"+str(api_config['vllm_endpoint'])+"/v1"
+    if api_url_overwrite:
+        api_url = api_url_overwrite
+    key = api_config['api_key']
     # Use langchain to load the documents from data directory
     loader = DirectoryLoader(data_dir)
     docs = loader.load()
@@ -62,12 +66,12 @@ def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,ap
         search_kwargs={"k": 5}
     )
     # Load the RAFT model
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
         openai_api_key=key,
         openai_api_base=api_url,
         model_name=model_name,
         temperature=0.0,
-        max_tokens=100
+        max_tokens=1000
         )
     all_tasks = []
     for q in question_list:
@@ -79,9 +83,10 @@ def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,ap
         else:
             documents = format_docs_raft(retrieved_docs)
         # create a prompt
-        text = rag_template.format(context=documents,question=q)
+        text = api_config["RAG_prompt_template"].format(context=documents,question=q)
         all_tasks.append(text)
     generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
     if len(generated_answers) == 0:
         logging.error("No RAG answers generated. Please check the input context or model configuration in ",model_name)
         return []
@@ -101,6 +106,7 @@ def clean_text_list(text_list):
         index = text.rfind("<ANSWER>")
         if index!= -1:
             text = text[index:]
+            text = text.replace("</ANSWER>:","")
         text = text.replace("begin_quote","")
         text = text.replace("end_quote","")
         text = text.replace("##","")
@@ -144,27 +150,24 @@ def compute_bert_score(generated : list, reference: list):
     precision = score["precision"]
     recall = score["recall"]
     return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
-def compute_judge_score(questions: list, generated : list, reference: list, context,api_url="http://localhost:8001/v1",key="EMPTY"):
+def compute_judge_score(questions: list, generated : list, reference: list, api_config,api_url="http://localhost:8001/v1",key="EMPTY"):
     correct_num = 0
     model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
         openai_api_key=key,
         openai_api_base=api_url,
         model_name=model_name,
+        max_tokens=1000,
         temperature=0.0)
     all_tasks = []
-    for q,pred,gold in zip(questions, generated,reference):
-        messages = [
-            HumanMessage(content=f"Question: {q} \n Teacher's Answer: {gold} \n Student's Answer: {pred} "),
-            SystemMessage(content=context['judge_prompt_template'])
-        ]
-        all_tasks.append(messages)
-    response = llm.batch(all_tasks)
-    for response in response:
-        if  "YES" in response:
-            correct_num += 1
-    return correct_num/len(questions)
-def score_single(context,generated,reference,questions, run_exact_match=True,run_rouge=True, run_bert=True, run_llm_as_judge=False):
+    for question,prediction,gold in zip(questions, generated,reference):
+        message = api_config['judge_prompt_template'].format(question=question,prediction=prediction,gold=gold)
+        all_tasks.append(message)
+    judge_responses = llm.batch(all_tasks)
+    judge_responses = ["YES" in item.content.split("<ANSWER>")[-1] for item in judge_responses]
+    correct_num = sum(judge_responses)
+    return correct_num/len(questions),judge_responses
+def score_single(api_config,generated,reference,questions, run_exact_match=True,run_rouge=True, run_bert=True, run_llm_as_judge=True):
     # set metric to default -1, means no metric is computed
     metric = {
         "Rouge_score": -1,
@@ -184,22 +187,23 @@ def score_single(context,generated,reference,questions, run_exact_match=True,run
         metric["BERTScore_Precision"] = P
         metric["BERTScore_Recall"] = R
         metric["BERTScore_F1"] = F1
-    if context["judge_endpoint"] and run_llm_as_judge:
-        api_url = "http://localhost:"+str(context["judge_endpoint"])+"/v1"
-        LLM_judge_score = compute_judge_score(questions, generated, reference, context,api_url=api_url)
+    if api_config["judge_endpoint"] and run_llm_as_judge:
+        api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+        LLM_judge_score,judge_responses = compute_judge_score(questions, generated, reference, api_config,api_url=api_url)
         metric["LLM_judge_score"] = LLM_judge_score
+        metric["LLM_judge_responses"] = judge_responses
         print(f"LLM_judge_score: {LLM_judge_score}")
     if run_exact_match:
         exact_match = exact_match_score(generated,reference)
         print(f"Exact_match_percentage: {exact_match:.4f}")
         metric["Exact_match"] = exact_match
     return metric
-def main(context):
+def main(api_config):
     # Since the eval set is small, we can run the eval without async functions
     try:
-        api_url = "http://localhost:"+str(context["vllm_endpoint"])+"/v1"
+        api_url = "http://localhost:"+str(api_config["vllm_endpoint"])+"/v1"
         logging.info("Starting to generate answer given the eval set.")
-        with open(context["eval_json"]) as fp:
+        with open(api_config["eval_json"]) as fp:
             eval_json = json.load(fp)
         questions,groud_truth = [],[]
         for index, item in enumerate(eval_json):
@@ -210,38 +214,73 @@ def main(context):
             "RAFT_RAG": [],
             "Baseline": [],
             "Baseline_RAG": [],
+            "70B_RAG": [],
+            "70B_Base": [],
+            
         }
         # Generate answers for baseline
-        base_model_name = context["base_model_name"]
-        #generated_answers["Baseline"] = generate_answers_model_only(base_model_name,questions,api_url)
-        generated_answers["Baseline_RAG"] = generate_answers_with_RAG(base_model_name, context["data_dir"],questions,context['RAG_prompt_template'],api_url)
+        base_model_name = api_config["base_model_name"]
+        generated_answers["Baseline"] = generate_answers_model_only(base_model_name,questions,api_url)
+        generated_answers["Baseline_RAG"] = generate_answers_with_RAG(base_model_name, questions,api_config)
         # Generate answers for RAFT
-        raft_model_name = context["raft_model_name"]
-        #generated_answers["RAFT"] = generate_answers_model_only(raft_model_name,questions,api_url)
-        generated_answers["RAFT_RAG"] = generate_answers_with_RAG(raft_model_name, context["data_dir"],questions,context['RAG_prompt_template'],api_url)
+        raft_model_name = api_config["raft_model_name"]
+        generated_answers["RAFT"] = generate_answers_model_only(raft_model_name,questions,api_url)
+        generated_answers["RAFT_RAG"] = generate_answers_with_RAG(raft_model_name, questions,api_config)
+
+        large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
+        large_api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
+        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,large_api_url,)
         logging.info(f"Successfully generated {len(generated_answers['Baseline_RAG'])} answers for all models.")
         # for generate answer from each model, compute the score metric
+        all_metrics = []
         for model_name,model_answer in generated_answers.items():
             if len(model_answer) != len(groud_truth):
                 print(f"The length of {model_name} answer is not equal to the length of ground truth.")
                 continue
-            metric = score_single(context,model_answer,groud_truth,questions)
+            metric = score_single(api_config,model_answer,groud_truth,questions)
             print(f"The eval result for {model_name} is: {metric}")
-            with open(context["output_log"],"a") as fp:
+            with open(api_config["output_log"],"a") as fp:
                 fp.write(f"Eval_result for {model_name} \n")
                 fp.write(f"Rouge_score: {metric['Rouge_score']} \n")
                 fp.write(f"BERTScore Precision: {metric['BERTScore_Precision']:.4f}, Recall: {metric['BERTScore_Recall']:.4f}, F1: {metric['BERTScore_F1']:.4f} \n")
                 fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
-                if context["judge_endpoint"]:
+                judge_responses = ["None"] * len(questions)
+                if api_config["judge_endpoint"]:
                     fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
+                    judge_responses = metric["LLM_judge_responses"]
+                    all_metrics.append((model_name,metric['LLM_judge_score'],metric["LLM_judge_responses"]))
                 fp.write(f"QA details: \n")
-                for item in zip(questions,model_answer,groud_truth):
+                for item in zip(questions,model_answer,groud_truth,judge_responses):
                     fp.write(f"question: {item[0]} \n")
                     fp.write(f"generated_answers: {item[1]} \n")
                     fp.write(f"groud_truth: {item[2]} \n")
+                    fp.write(f"LLM_judge_response: {item[3]} \n")
                     fp.write("\n")
                 fp.write("\n------------------------------------\n")
-        logging.info(f"Eval successfully, the eval result is saved to {context['output_log']}.")
+        # Now we want to take a closer look at the questions that are not answered the same by all the models.
+        judge_zip = list(zip(*[item[-1] for item in all_metrics]))
+        with open(api_config["output_log"],"a") as fp:
+            for item in all_metrics:
+                fp.write(f"Model_Name: {item[0]}, LLM_SCORE: {item[1]} \n")
+            for idx,item in enumerate(judge_zip):
+                # if all the responses are "YES" or all the responses are "NO", then we skip this question
+                if sum([r=="YES" for r in item]) == len(item) or sum([r=="YES" for r in item]) == 0:
+                    continue 
+                else:
+                    fp.write(f"Comparing interested question: {questions[idx]} \n")
+                    fp.write(f"groud_truth: {groud_truth[idx]} \n")
+                    fp.write(f"{item[2]} Baseline_answers: {generated_answers['Baseline'][idx]} \n")
+                    fp.write(f"{item[3]} Baseline_RAG_answers: {generated_answers['Baseline_RAG'][idx]} \n")
+                    fp.write(f"{item[0]} RAFT_answers: {generated_answers['RAFT'][idx]} \n")
+                    fp.write(f"{item[1]} RAFT_RAG_answers: {generated_answers['RAFT_RAG'][idx]} \n")
+                    fp.write(f"{item[4]} 70B_Base_answers: {generated_answers['70B_Base'][idx]} \n")
+                    fp.write(f"{item[5]} 70B_RAG_answers: {generated_answers['70B_RAG'][idx]} \n")
+                    fp.write("-------\n")
+
+
+
+        logging.info(f"Eval successfully, the eval result is saved to {api_config['output_log']}.")
         # Saving the eval result to a log file
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
@@ -283,20 +322,26 @@ def parse_arguments():
         default="eval_result.log",
         help="save the eval result to a log file. Default is eval_result.log"
     )
-
+    parser.add_argument(
+        "-k", "--api_key",
+        default="EMPTY",
+        type=str,
+        help="LLM API key for generating question/answer pairs."
+    )
     return parser.parse_args()
 
 if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
-    context = load_config(args.config_path)
-    context["vllm_endpoint"] = args.vllm_endpoint
+    api_config = load_config(args.config_path)
+    api_config["vllm_endpoint"] = args.vllm_endpoint
     if args.data_dir:
-        context["data_dir"] = args.data_dir
+        api_config["data_dir"] = args.data_dir
     if args.raft_model_name:
-        context["raft_model_name"] = args.raft_model_name
-    context["judge_endpoint"] = args.judge_endpoint
-    context["output_log"] = args.output_log
-    if context["judge_endpoint"]:
+        api_config["raft_model_name"] = args.raft_model_name
+    api_config["judge_endpoint"] = args.judge_endpoint
+    api_config["output_log"] = args.output_log
+    api_config["api_key"] = args.api_key
+    if api_config["judge_endpoint"]:
         logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
-    main(context)
+    main(api_config)
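A minimal sketch (not part of the commit) of the ChatOpenAI batch pattern the rewritten script relies on, assuming a local vLLM OpenAI-compatible server on port 8000 serving the raft-8b model with the dummy "EMPTY" key, as in the defaults above:

from langchain_openai import ChatOpenAI

# Assumption: vLLM exposes an OpenAI-compatible API at localhost:8000 for the "raft-8b" model.
llm = ChatOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://localhost:8000/v1",
    model_name="raft-8b",
    temperature=0.0,
    max_tokens=1000,
)

prompts = ["<formatted eval prompt 1>", "<formatted eval prompt 2>"]  # placeholders
responses = llm.batch(prompts)             # a chat model returns AIMessage objects, not raw strings
answers = [r.content for r in responses]   # hence the added `.content` extraction in the diff

# Judge parsing as in compute_judge_score(): look for YES after the final <ANSWER> tag.
verdicts = ["YES" in r.content.split("<ANSWER>")[-1] for r in responses]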

File diff suppressed because it is too large
+ 128 - 176
recipes/use_cases/end2end-recipes/raft/evalset.json


+ 10 - 8
recipes/use_cases/end2end-recipes/raft/raft.yaml

@@ -2,19 +2,21 @@ COT_prompt_template: >
   <|begin_of_text|><|start_header_id|>system<|end_header_id|> Answer the following question using the information given in the context below. Here is things to pay attention to:
     - First provide step-by-step reasoning on how to answer the question.
     - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
-    - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
-    You MUST begin your final answer with the tag "<ANSWER>:<|eot_id|>
+    - End your response with final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+    You MUST begin your final answer with the tag "<ANSWER>: <|eot_id|>
   <|start_header_id|>user<|end_header_id|>
   Question: {question}\nContext: {context}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 question_prompt_template: >
-  You are a synthetic question-answer pair generator. Given a chunk of context about
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a synthetic question-answer pair generator. Given a chunk of context about
   some topic(s), generate {num_questions} example questions a user could ask and would be answered
-  \using information from the chunk. For example, if the given context was a Wikipedia
+  using information from the chunk. For example, if the given context was a Wikipedia
   paragraph about the United States, an example question could be 'How many states are
   in the United States?
   The questions should be able to be answered in 100 words or less. Include only the
-  questions in your response.
+  questions in your response.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 # question_prompt_template: >
 #   You are a language model skilled in creating quiz questions.
@@ -29,13 +31,13 @@ question_prompt_template: >
 #   4. Never use any abbreviation.
 #   5. Include only the questions in your response.
 
-data_dir: "./data"
+data_dir: "/home/kaiwu/work/pytorch/docs"
 
 xml_path: ""
 
-chunk_size: 512
+chunk_size: 1000
 
-questions_per_chunk: 3
+questions_per_chunk: 5
 
 num_distract_docs: 5 # number of distracting documents to add to each chunk
 

+ 6 - 7
recipes/use_cases/end2end-recipes/raft/raft_utils.py

@@ -15,8 +15,10 @@ from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
 from bs4 import BeautifulSoup
 from langchain_openai import ChatOpenAI
 from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_community.llms import VLLMOpenAI
+from langchain_community.llms import ChatOpenAI
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_openai import ChatOpenAI
+
 
 
 # Initialize logging
@@ -126,17 +128,14 @@ def generate_questions(api_config):
     batches_count = len(document_batches)
     total_questions = api_config["questions_per_chunk"] * batches_count
     # use OpenAI API protocol to handle the chat request, including local VLLM openai compatible server
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
         openai_api_key=key,
         openai_api_base=api_url,
         model_name=api_config["model"],
         temperature=0.0,
         max_tokens=250
         )
-    prompt = api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']))
-    system_prompt = SystemMessage(content=prompt)
-    generated_answers = []
-    all_tasks = [[system_prompt, HumanMessage(content=batch)] for batch in document_batches]
+    all_tasks = [api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']),context=document) for document in document_batches]
     generated_answers = llm.batch(all_tasks)
     if len(generated_answers) == 0:
         logging.error("No model answers generated. Please check the input context or model configuration in ",model_name)
@@ -163,7 +162,7 @@ def generate_COT(chunk_questions_zip,api_config) -> dict:
             all_tasks.append(prompt)
             chunk_questions.append((document_content,question))
     # use OpenAI API protocol to handle the chat request, including local VLLM openai compatible server
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
         openai_api_key=api_config["api_key"],
         openai_api_base=api_config["endpoint_url"],
         model_name=api_config["model"],

+ 1 - 0
requirements.txt

@@ -34,3 +34,4 @@ coloredlogs==15.0.1
 sentence_transformers
 faiss-gpu
 unstructured[pdf]
+langchain_openai