generate_question_answers.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import argparse
import asyncio
import json
import logging

import aiofiles  # Ensure aiofiles is installed for async file operations

from config import load_config
from generator_utils import generate_question_batches, generate_data_curation
from chat_utils import OctoAIChatService, VllmChatService

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        # Generate question/answer pairs as a list
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        logging.info(f"Successfully generated {len(data)} question/answer pairs.")
        if context["use_curation"]:
            logging.info("Starting self-curation using the LLM.")
            data = await generate_data_curation(chat_service, context, data)
            logging.info(f"Only {len(data)} question/answer pairs passed the self-curation.")
        async with aiofiles.open(context['output_path'], "w") as output_file:
            await output_file.write(json.dumps(data, indent=4))
        logging.info(f"Data successfully written to {context['output_path']}. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)


def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=100,
        help="Specify the total number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./generation_config.yaml",
        help="Set the configuration file path that has the system prompt along with the language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vLLM endpoint for generating question/answer pairs."
    )
    parser.add_argument(
        "-o", "--output_path",
        default="./data.json",
        help="Set the output path for the generated QA pairs. Default is data.json."
    )
    return parser.parse_args()
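
# Note (reading of the code below, not a full spec of generation_config.yaml): load_config()
# is expected to return a dict-like config from which this script reads
# "curation_prompt_template" (an empty string disables the LLM self-curation step),
# in addition to the system prompt, language, dataset path, and question count
# mentioned in the --config_path help text.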

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()

    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    # If the curation prompt is not empty, then use self-curation
    context["use_curation"] = len(context["curation_prompt_template"]) > 0
    context["output_path"] = args.output_path

    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using local vLLM service at port: '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
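
# Example invocation (illustrative only; the port number below is a placeholder for a
# locally running vLLM server, and the OctoAI service is used when no port is given):
#   python generate_question_answers.py \
#       --total_questions 100 \
#       --model meta-llama-3-70b-instruct \
#       --config_path ./generation_config.yaml \
#       --vllm_endpoint 8000 \
#       --output_path ./data.json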