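"""raft.py

Generate RAFT question/answer/context pairs from documentation chunks and export the
resulting dataset in the requested format. The heavy lifting lives in the imported
helpers: `generate_questions` produces per-chunk question lists, `add_chunk_to_dataset`
assembles the dataset (using NUM_DISTRACT_DOCS and ORACLE_P below to build each
example's context), and `DatasetConverter` writes the chosen output format.
"""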
import logging
from typing import Literal, Any
import json
import random
import os
import argparse

from raft_utils import generate_questions, add_chunk_to_dataset
from format import DatasetConverter, datasetFormats, outputDatasetTypes
from config import load_config

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

NUM_DISTRACT_DOCS = 5  # number of distractor documents to add to each example's context
ORACLE_P = 0.8  # probability that the oracle (source) chunk is kept in an example's context


def main(api_config):
    ds = None
    try:
        logging.info("Starting to generate question/answer pairs.")
        # Generate a list of questions for each chunk
        chunk_questions_zip = generate_questions(api_config)
        if not chunk_questions_zip:
            logging.warning("No questions generated from text. Please check the api_config or model configuration.")
            return
        for chunk, questions in chunk_questions_zip:
            logging.info(f"Chunk: {chunk}, number of questions: {len(questions)}")
            for question in questions:
                logging.info(f"Question: {question}")
        logging.info(f"Successfully generated {sum([len(q) for c, q in chunk_questions_zip])} question/answer pairs.")
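        # Assemble the RAFT dataset. add_chunk_to_dataset is expected to build one example
        # per question, attaching NUM_DISTRACT_DOCS distractor chunks and keeping the
        # oracle (source) chunk in the context with probability ORACLE_P.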
        ds = add_chunk_to_dataset(chunk_questions_zip, api_config, ds, NUM_DISTRACT_DOCS, ORACLE_P)
        ds.save_to_disk(args.output)
        logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
        formatter = DatasetConverter()
        # Extract format-specific params
        format_params = {}
        formatter.convert(ds=ds, format=args.output_format, output_path=args.output + "raft", output_type=args.output_type, params=format_params)
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)


def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate RAFT question/answer/context pairs from documentation."
    )
    parser.add_argument(
        "-t", "--questions_per_chunk",
        type=int,
        default=3,
        help="Number of questions to generate per chunk."
    )
    parser.add_argument(
        "-m", "--model",
        default="meta-llama/Meta-Llama-3-70B-Instruct",
        help="Model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./raft.yaml",
        help="Path to the configuration file containing the system prompt, language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-u", "--endpoint_url",
        default="http://localhost:8001/v1",
        type=str,
        help="LLM API URL for generating question/answer pairs."
    )
    parser.add_argument(
        "-k", "--api_key",
        default="EMPTY",
        type=str,
        help="LLM API key for generating question/answer pairs."
    )
    parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens.")
    parser.add_argument("-o", "--output", type=str, default="./", help="The path at which to save the dataset.")
    parser.add_argument("--output-format", type=str, default="hf", help="Format to convert the dataset to. Defaults to hf.", choices=datasetFormats)
    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
    return parser.parse_args()
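

# Example invocation (hypothetical output directory; assumes an OpenAI-compatible
# endpoint is already serving the chosen model at the given URL):
#
#   python raft.py -u http://localhost:8001/v1 -k EMPTY -t 3 --chunk_size 512 \
#       -c ./raft.yaml -o ./output/ --output-format hf --output-type jsonl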
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    api_config = load_config(args.config_path)
    api_config["questions_per_chunk"] = args.questions_per_chunk
    api_config["model"] = args.model
    api_config["chunk_size"] = args.chunk_size
    api_config["endpoint_url"] = args.endpoint_url
    api_config["output"] = args.output
    api_config["api_key"] = args.api_key
    # If API_KEY is defined in the environment, it overrides the key passed via --api_key
    if os.environ.get("API_KEY") is not None:
        api_config["api_key"] = os.environ["API_KEY"]
    logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} questions per chunk using model '{args.model}'.")
    logging.info(f"Chunk size: {args.chunk_size}.")
    logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
    logging.info(f"Output will be written to {args.output}.")
    main(api_config)