generate_question_answers.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import argparse
import asyncio
import json
import logging
from itertools import chain

import aiofiles  # Ensure aiofiles is installed for async file operations

from chat_utils import OctoAIChatService, VllmChatService
from config import load_config
from generator_utils import generate_data_curation, generate_question_batches

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


async def main(context):
    # Use the local vLLM endpoint if a port was provided; otherwise fall back to OctoAI.
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        # Flatten the per-batch results into a single list of QA pairs.
        data = list(chain.from_iterable(data))
        logging.info(f"Successfully generated {len(data)} question/answer pairs.")
        if context["use_curation"]:
            logging.info("Starting self-curation using the LLM.")
            data = await generate_data_curation(chat_service, context, data)
            logging.info(f"Only {len(data)} question/answer pairs passed the self-curation.")
        async with aiofiles.open(context["output_path"], "w") as output_file:
            await output_file.write(json.dumps(data, indent=4))
        logging.info(f"Data successfully written to {context['output_path']}. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)


def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=100,
        help="Specify the total number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./generation_config.yaml",
        help="Set the path to the configuration file that holds the system prompt along with the language, dataset path and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vllm endpoint on that port for generating question/answer pairs."
    )
    parser.add_argument(
        "-o", "--output_path",
        default="./data.json",
        help="Set the output path for the generated QA pairs. Default is data.json."
    )
    return parser.parse_args()


if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    # If the curation prompt is not empty, then use self-curation.
    context["use_curation"] = len(context["curation_prompt_template"]) > 0
    context["output_path"] = args.output_path
    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using the local vllm service at port '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
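
# Example invocation (illustrative only; the flags mirror parse_arguments above, and the
# port value assumes a local vLLM OpenAI-compatible server is already listening on 8000):
#
#   python generate_question_answers.py \
#       --total_questions 200 \
#       --model meta-llama-3-70b-instruct \
#       --config_path ./generation_config.yaml \
#       --vllm_endpoint 8000 \
#       --output_path ./data.json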