generate_question_answers.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import argparse
import asyncio
import json
from config import load_config
from generator_utils import generate_question_batches, generate_data_eval
from chat_utils import OctoAIChatService, VllmChatService
from itertools import chain
import logging
import aiofiles  # Ensure aiofiles is installed for async file operations

# Configure logging to include the timestamp, log level, and message
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
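
# The helpers imported from generator_utils are assumed to behave as follows
# (a sketch inferred from how they are used in main(), not a definitive contract):
#   generate_question_batches(chat_service, context) -> a list of batches, where
#       each batch is a list of question/answer records
#   generate_data_eval(chat_service, context, pairs) -> the subset of pairs that
#       passes the model's self-curation check
# chain.from_iterable then flattens the batches into a single list, e.g.:
#   list(chain.from_iterable([[1, 2], [3]])) == [1, 2, 3]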

async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate question/answer pairs.")
        data = await generate_question_batches(chat_service, context)
        if not data:
            logging.warning("No data generated. Please check the input context or model configuration.")
            return
        flattened_list = list(chain.from_iterable(data))
        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
        curated_data = await generate_data_eval(chat_service, context, flattened_list)
        logging.info(f"Only {len(curated_data)} question/answer pairs passed the self-curation.")
        # Use an asynchronous file operation to write the curated pairs
        async with aiofiles.open("curated_data.json", "w") as output_file:
            await output_file.write(json.dumps(curated_data, indent=4))
        logging.info("Data successfully written to 'curated_data.json'. Process completed.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Generate question/answer pairs from documentation."
    )
    parser.add_argument(
        "-t", "--total_questions",
        type=int,
        default=100,
        help="Specify the total number of question/answer pairs to generate."
    )
    parser.add_argument(
        "-m", "--model",
        choices=["meta-llama-3-70b-instruct", "meta-llama-3-8b-instruct", "llama-2-13b-chat", "llama-2-70b-chat"],
        default="meta-llama-3-70b-instruct",
        help="Select the model to use for generation."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="./generation_config.yaml",
        help="Path to the configuration file containing the system prompt, language, dataset path, and number of questions."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vLLM endpoint on that port to generate question/answer pairs."
    )
    return parser.parse_args()
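
# Example invocation (a minimal sketch; the question count and port are
# illustrative values, not this script's defaults):
#   python generate_question_answers.py -t 200 -m meta-llama-3-70b-instruct \
#       -c ./generation_config.yaml -v 8000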

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["total_questions"] = args.total_questions
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
    if context["endpoint"]:
        logging.info(f"Using local vLLM service at port '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
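
# The expected shape of generation_config.yaml, based on the --config_path help
# text (a hypothetical sketch; the actual keys are defined by load_config in
# config.py and may differ):
#   question_prompt: "You are a helpful assistant that writes Q/A pairs..."
#   language: "English"
#   data_dir: "./docs"
#   num_questions: 100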