Browse Source

added refusal and adjusted prompts

Kai Wu 11 months ago
parent
commit
af53ee051e

File diff suppressed because it is too large
+ 21 - 19
recipes/use_cases/end2end-recipes/raft/README.md


+ 2 - 6
recipes/use_cases/end2end-recipes/raft/raft.py

@@ -1,7 +1,4 @@
 import logging
 import logging
-from typing import Literal, Any
-import json
-import random
 import os
 import os
 import argparse
 import argparse
 from raft_utils import generate_questions, add_chunk_to_dataset
 from raft_utils import generate_questions, add_chunk_to_dataset
@@ -10,8 +7,6 @@ from config import load_config
 
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 
-NUM_DISTRACT_DOCS = 5 # number of distracting documents to add to each chunk
-ORCALE_P = 0.8 # probability of related documents to be added to each chunk
 def main(api_config):
 def main(api_config):
     ds = None
     ds = None
     try:
     try:
@@ -26,7 +21,7 @@ def main(api_config):
             for question in questions:
             for question in questions:
                 logging.info(f"Question: {question}")
                 logging.info(f"Question: {question}")
         logging.info(f"Successfully generated {sum([len(q) for c,q in chunk_questions_zip])} question/answer pairs.")
         logging.info(f"Successfully generated {sum([len(q) for c,q in chunk_questions_zip])} question/answer pairs.")
-        ds = add_chunk_to_dataset(chunk_questions_zip,api_config,ds,NUM_DISTRACT_DOCS, ORCALE_P)
+        ds = add_chunk_to_dataset(chunk_questions_zip,api_config,ds)
         ds.save_to_disk(args.output)
         ds.save_to_disk(args.output)
         logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
         logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
         formatter = DatasetConverter()
         formatter = DatasetConverter()
@@ -92,6 +87,7 @@ if __name__ == "__main__":
         api_config["api_key"] = os.environ["API_KEY"]
         api_config["api_key"] = os.environ["API_KEY"]
     logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question per chunk using model '{args.model}'.")
     logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question per chunk using model '{args.model}'.")
     logging.info(f"Chunk size: {args.chunk_size}.")
     logging.info(f"Chunk size: {args.chunk_size}.")
+    logging.info(f"num_distract_docs: {api_config['num_distract_docs']}, oracle_p: {api_config['oracle_p']}")
     logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
     logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
     logging.info(f"Output will be written to {args.output}.")
     logging.info(f"Output will be written to {args.output}.")
     main(api_config)
     main(api_config)

+ 34 - 31
recipes/use_cases/end2end-recipes/raft/raft.yaml

@@ -1,40 +1,43 @@
 COT_prompt_template: >
 COT_prompt_template: >
-  <|begin_of_text|><|start_header_id|>system<|end_header_id|> Answer the following question using the information given in the context below. Here is things to pay attention to:
-    - First provide step-by-step reasoning on how to answer the question.
-    - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
-    - End your response with final answer in the form <ANSWER>: $answer, the answer should less than 60 words.
-    You MUST begin your final answer with the tag "<ANSWER>: <|eot_id|>
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful chatbot who can provide an answer to every question from the user given a relevant context.<|eot_id|>
   <|start_header_id|>user<|end_header_id|>
   <|start_header_id|>user<|end_header_id|>
-  Question: {question}\nContext: {context}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
-# question_prompt_template: >
-#   <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a synthetic question-answer pair generator. Given a chunk of context about
-#   some topic(s), generate {num_questions} example questions a user could ask and would be answered
-#   using information from the chunk. For example, if the given context was a Wikipedia
-#   paragraph about the United States, an example question could be 'How many states are
-#   in the United States?
-#   The questions should be able to be answered in 100 words or less. Include only the
-#   questions in your response.<|eot_id|>
-#   <|start_header_id|>user<|end_header_id|>
-#   Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+  Question: {question}\nContext: {context}\n
+  Answer this question using the information given by multiple documents in the context above. Here are things to pay attention to:
+  - The context contains many documents, each document starts with <DOCUMENT> and ends </DOCUMENT>.
+  - First provide step-by-step reasoning on how to answer the question.
+  - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
+  - End your response with final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+  You MUST begin your final answer with the tag "<ANSWER> <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
 
 question_prompt_template: >
 question_prompt_template: >
-  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a language model skilled in creating quiz questions.
-  You will be provided with a document,
-  read it and please generate factoid question and answer pairs that are most likely be asked by a user of Llama language models
-  which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2
-  Your factoid questions should be answerable with a specific, concise piece of factual information from the context.
-  Your factoid questions should be formulated in the same style as questions users could ask in a search engine.
-  This means that your factoid questions MUST NOT mention something like "according to the passage" or "context".
-  please make sure you follow those rules:
-  1. Generate {num_questions} question answer pairs, you can generate less answer if there is nothing related to
-  model, training, fine-tuning and evaluation details of Llama language models,
-  2. The questions can be answered based *solely* on the given passage.
-  3. Avoid asking questions with similar meaning.
-  4. Never use any abbreviation.
-  5. The questions should be able to be answered in 60 words or less. Include only the questions in your response. <|eot_id|>
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a synthetic question-answer pair generator. Given a chunk of context about
+  some topic(s), generate {num_questions} example questions a user could ask and would be answered
+  using information from the chunk. For example, if the given context was a Wikipedia
+  paragraph about the United States, an example question could be 'How many states are
+  in the United States?
+  Your questions should be formulated in the same style as questions that users could ask in a search engine.
+  This means that your questions MUST NOT mention something like "according to the passage" or "context".
+  The questions should be able to be answered in 60 words or less. Include only the questions in your response.<|eot_id|>
   <|start_header_id|>user<|end_header_id|>
   <|start_header_id|>user<|end_header_id|>
   Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
   Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+# question_prompt_template: >
+#   <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a language model skilled in creating quiz questions.
+#   You will be provided with a document,
+#   read it and please generate factoid question and answer pairs that are most likely be asked by a user of Llama language models
+#   which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2
+#   Your factoid questions should be answerable with a specific, concise piece of factual information from the context.
+#   Your factoid questions should be formulated in the same style as questions users could ask in a search engine.
+#   This means that your factoid questions MUST NOT mention something like "according to the passage" or "context".
+#   please make sure you follow those rules:
+#   1. Generate {num_questions} question answer pairs, you can generate less answer if there is nothing related to
+#   model, training, fine-tuning and evaluation details of Llama language models,
+#   2. The questions can be answered based *solely* on the given passage.
+#   3. Avoid asking questions with similar meaning.
+#   4. Never use any abbreviation.
+#   5. The questions should be able to be answered in 60 words or less. Include only the questions in your response. <|eot_id|>
+#   <|start_header_id|>user<|end_header_id|>
+#   Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 data_dir: "./data"
 data_dir: "./data"
 
 
 xml_path: ""
 xml_path: ""

+ 1 - 1
recipes/use_cases/end2end-recipes/raft/raft_eval.py

@@ -8,7 +8,7 @@ import json
 from langchain_openai import ChatOpenAI
 from langchain_openai import ChatOpenAI
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain_community.vectorstores import FAISS
-from langchain.text_splitter import RecursiveCharacterTextSplitter,TokenTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores.utils import DistanceStrategy
 from langchain_community.vectorstores.utils import DistanceStrategy
 from datetime import datetime
 from datetime import datetime
 from langchain_community.document_loaders import DirectoryLoader
 from langchain_community.document_loaders import DirectoryLoader

+ 9 - 8
recipes/use_cases/end2end-recipes/raft/raft_eval_config.yaml

@@ -2,7 +2,7 @@ eval_prompt_template: >
   <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a AI assistant that skilled in answering questions related to Llama language models,
   <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a AI assistant that skilled in answering questions related to Llama language models,
   which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2,
   which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2,
   Below is a question from a llama user, please the answer it with best of your knowledge,
   Below is a question from a llama user, please the answer it with best of your knowledge,
-  The returned answer should be no more than 100 words.Please return the answers in text directly without any special tokens.<|eot_id|>
+  The returned answer should be no more than 60 words. Please return the answers in text directly without any special tokens.<|eot_id|>
   <|start_header_id|>user<|end_header_id|>
   <|start_header_id|>user<|end_header_id|>
   Question:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
   Question:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 judge_prompt_template: >
 judge_prompt_template: >
@@ -19,14 +19,15 @@ judge_prompt_template: >
     <|start_header_id|>user<|end_header_id|>
     <|start_header_id|>user<|end_header_id|>
     Question: {question} \n Teacher's Answer: {gold} \n Student's Answer: {prediction} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
     Question: {question} \n Teacher's Answer: {gold} \n Student's Answer: {prediction} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 RAG_prompt_template: >
 RAG_prompt_template: >
-  <|begin_of_text|><|start_header_id|>system<|end_header_id|> Answer the following question using the information given in the context below. Here is things to pay attention to:
-    1.The context contains many documents, each document starts with <DOCUMENT> and ends </DOCUMENT>.
-    2.First provide step-by-step reasoning on how to answer the question.
-    3.In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
-    4.End your response with final answer in the form <ANSWER>: $answer, the answer should less than 60 words.
-    You MUST begin your final answer with the tag "<ANSWER>:<|eot_id|>
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful chatbot who can provide an answer to every question from the user given a relevant context.<|eot_id|>
   <|start_header_id|>user<|end_header_id|>
   <|start_header_id|>user<|end_header_id|>
-  Question: {question}\nContext: {context}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+  Question: {question}\nContext: {context}\n
+  Answer this question using the information given by multiple documents in the context above. Here are things to pay attention to:
+  - The context contains many documents, each document starts with <DOCUMENT> and ends </DOCUMENT>.
+  - First provide step-by-step reasoning on how to answer the question.
+  - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
+  - End your response with final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+  You MUST begin your final answer with the tag "<ANSWER> <|eot_id|><|start_header_id|>assistant<|end_header_id|>
 eval_file: "./eval_llama.json"
 eval_file: "./eval_llama.json"
 
 
 model_name: "raft-8b"
 model_name: "raft-8b"

+ 18 - 9
recipes/use_cases/end2end-recipes/raft/raft_utils.py

@@ -9,7 +9,7 @@ from datasets import Dataset
 import random
 import random
 from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
 from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
 from bs4 import BeautifulSoup
 from bs4 import BeautifulSoup
-
+import copy
 from langchain_openai import ChatOpenAI
 from langchain_openai import ChatOpenAI
 
 
 
 
@@ -171,12 +171,12 @@ def add_chunk_to_dataset(
     chunk_questions_zip: list,
     chunk_questions_zip: list,
     api_config: dict,
     api_config: dict,
     ds,
     ds,
-    num_distract: int = 3,
-    p: float = 0.8,
 ) -> None:
 ) -> None:
     """
     """
     Given a chunk and related questions lists, create {Q, A, D} triplets and add them to the dataset.
     Given a chunk and related questions lists, create {Q, A, D} triplets and add them to the dataset.
     """
     """
+    num_distract = api_config["num_distract_docs"]
+    p = api_config["oracle_p"]
     chunks = [chunk for chunk, _ in chunk_questions_zip]
     chunks = [chunk for chunk, _ in chunk_questions_zip]
     COT_results = generate_COT(chunk_questions_zip,api_config)
     COT_results = generate_COT(chunk_questions_zip,api_config)
     for chunk, q , cot in COT_results:
     for chunk, q , cot in COT_results:
@@ -198,12 +198,8 @@ def add_chunk_to_dataset(
         indices.remove(i)
         indices.remove(i)
         for j in random.sample(indices, num_distract):
         for j in random.sample(indices, num_distract):
             docs.append(chunks[j])
             docs.append(chunks[j])
-        # decides whether to add oracle document
-        oracle = random.uniform(0, 1) < p
-        if not oracle:
-            docs[0] = chunks[random.sample(indices, 1)[0]]
+        doc_copy = docs.copy()
         random.shuffle(docs)
         random.shuffle(docs)
-
         d = {
         d = {
             "title": [],
             "title": [],
             "sentences": []
             "sentences": []
@@ -221,7 +217,7 @@ def add_chunk_to_dataset(
         context += q
         context += q
         # This instruction will be used in the fine-tuning stage
         # This instruction will be used in the fine-tuning stage
         datapt["instruction"] = context
         datapt["instruction"] = context
-
+        datapt_copy = copy.deepcopy(datapt)
         # add to dataset
         # add to dataset
         if not ds:
         if not ds:
             # init ds
             # init ds
@@ -235,4 +231,17 @@ def add_chunk_to_dataset(
             ds = Dataset.from_dict(datapt)
             ds = Dataset.from_dict(datapt)
         else:
         else:
             ds = ds.add_item(datapt)
             ds = ds.add_item(datapt)
+        # decides whether to add refusal example where the related documents are not provided
+        oracle = random.uniform(0, 1) < p
+        if not oracle:
+            doc_copy[0] = chunks[random.sample(indices, 1)[0]]
+            random.shuffle(doc_copy)
+            context = ""
+            for doc in doc_copy:
+                context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
+            context += q
+            # This instruction will be used in the fine-tuning stage
+            datapt_copy["instruction"] = context
+            datapt_copy["cot_answer"] = "Sorry, I don't know the answer to this question because related documents are not found. Please try again."
+            ds = ds.add_item(datapt_copy)
     return ds
     return ds