raft.py

import mdc
from mdc import MDC
import logging
from math import ceil
from typing import Literal, Any
from openai import OpenAI
import datasets
from datasets import Dataset, load_dataset
import json
import random
import os, shutil
import argparse
import asyncio
from raft_utils import generate_questions, add_chunk_to_dataset
from chat_utils import OctoAIChatService, VllmChatService
from format import DatasetConverter, datasetFormats, outputDatasetTypes
from config import load_config
# def generate_label(client: OpenAI, question: str, context: Any, doctype: DocType = "pdf", model: str = None) -> str | None:
#     """
#     Generates the label / answer to `question` using `context` and GPT-4.
#     """
#     question = encode_question(question, context) if doctype == "api" else encode_question_gen(question, context)
#     response = client.chat.completions.create(
#         model=model,
#         messages=question,
#         n=1,
#         temperature=0
#     )
#     response = response.choices[0].message.content
#     return response

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

async def main(context):
    # Use the local vllm endpoint when a port was provided, otherwise fall back to OctoAI.
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        # Generate question/answer pairs as a list, one entry per document chunk
        chunks = await generate_questions(chat_service, context)
        if not chunks:
            logging.warning("No questions generated from text. Please check the input context or model configuration.")
            return
        logging.info(f"Successfully generated {sum([len(q) for q in chunks])} question/answer pairs.")
        print(chunks)
        num_chunks = len(chunks)
        for i, chunk in enumerate(chunks):
            perc = ceil(i / num_chunks * 100)
            with MDC(progress=f"{perc}%"):
                logging.info(f"Adding chunk {i}/{num_chunks}")
                # Assumed call shape for raft_utils.add_chunk_to_dataset: pass the chat
                # service, the loaded config, the full chunk list and the current chunk.
                add_chunk_to_dataset(chat_service, context, chunks, chunk)
  51. logging.info(f"Data successfully written to {context['output']}. Process completed.")
  52. except Exception as e:
  53. logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)

def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--questions_per_chunk",
        type=int,
        default=3,
        help="Specify the number of question/answer pairs to generate per chunk."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./raft.yaml",
        help="Set the path to the configuration file that holds the system prompt along with the language, dataset path and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vllm endpoint on that port for generating question/answer pairs."
    )
    parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens.")
    parser.add_argument("-o", "--output", type=str, default="./", help="The path at which to save the dataset.")
    parser.add_argument("--output-format", type=str, default="hf", help="Format to convert the dataset to. Defaults to hf.", choices=datasetFormats)
    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
    return parser.parse_args()
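
# Example invocation (illustrative only; the values below are the parser defaults,
# except for the vllm port, which assumes a local deployment on 8001):
#   python raft.py -c ./raft.yaml -t 3 -m meta-llama-3-70b-instruct -v 8001 \
#       --chunk_size 512 -o ./ --output-format hf --output-type jsonl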

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["questions_per_chunk"] = args.questions_per_chunk
    context["model"] = args.model
    context["chunk_size"] = args.chunk_size
    context["endpoint"] = args.vllm_endpoint
    context["output"] = args.output
    logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question/answer pair(s) per chunk using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using local vllm service at port: '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
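
# For reference, a minimal raft.yaml sketch matching the --config_path help text
# (key names are hypothetical; the actual schema is whatever config.load_config expects):
#   system_prompt: "You are a question generator for RAFT fine-tuning data ..."
#   language: "English"
#   data_dir: "./data"
#   num_questions: 3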