raft.py

import mdc
from mdc import MDC
import logging
from typing import Literal, Any
from openai import OpenAI
import json
import random
import os
import shutil
import argparse
import asyncio
from raft_utils import generate_questions, add_chunk_to_dataset
from chat_utils import OctoAIChatService, VllmChatService
from format import DatasetConverter, datasetFormats, outputDatasetTypes
from config import load_config

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

NUM_DISTRACT_DOCS = 5  # number of distractor documents to add to each chunk
ORACLE_P = 0.8  # probability that the oracle (related) document is included with each chunk
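
# Orchestrate the RAFT data generation: select a chat service, generate questions
# for each document chunk, attach oracle/distractor documents, and export the dataset.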
async def main(context):
    ds = None
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        # Generate a list of questions for each chunk
        chunk_questions_zip = await generate_questions(chat_service, context)
        if not chunk_questions_zip:
            logging.warning("No questions generated from text. Please check the input context or model configuration.")
            return
        for chunk, questions in chunk_questions_zip:
            logging.info(f"Chunk: {chunk}, question length: {len(questions)}")
            for question in questions:
                logging.info(f"Question: {question}")
        logging.info(f"Successfully generated {sum([len(q) for c, q in chunk_questions_zip])} question/answer pairs.")
        ds = await add_chunk_to_dataset(chunk_questions_zip, context, chat_service, ds, NUM_DISTRACT_DOCS, ORACLE_P)
        ds.save_to_disk(context["output"])
        logging.info(f"Data successfully written to {context['output']}. Process completed.")
        formatter = DatasetConverter()
        # Extract format-specific params
        format_params = {}
        formatter.convert(ds=ds, format=context["output_format"], output_path=context["output"] + "raft", output_type=context["output_type"], params=format_params)
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--questions_per_chunk",
        type=int,
        default=3,
        help="Specify the number of question/answer pairs to generate per chunk."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./raft.yaml",
        help="Set the path to the configuration file that holds the system prompt along with the language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use a local vllm endpoint for generating question/answer pairs."
    )
    parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens")
    parser.add_argument("-o", "--output", type=str, default="./", help="The path at which to save the dataset")
    parser.add_argument("--output-format", type=str, default="hf", help="Format to convert the dataset to. Defaults to hf.", choices=datasetFormats)
    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
    return parser.parse_args()
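
# Entry point: parse CLI arguments, merge them into the YAML config, and run the
# asynchronous generation pipeline.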
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["questions_per_chunk"] = args.questions_per_chunk
    context["model"] = args.model
    context["chunk_size"] = args.chunk_size
    context["endpoint"] = args.vllm_endpoint
    context["output"] = args.output
    context["output_format"] = args.output_format
    context["output_type"] = args.output_type
    logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} questions per chunk using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using local vllm endpoint at port: '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
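
# Example invocation (illustrative; the port and output directory are assumptions):
#   python raft.py -t 3 -m meta-llama-3-70b-instruct -c ./raft.yaml -v 8000 -o ./output/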