Browse Source

make the config file as an arg

Hamid Shojanazeri 1 year ago
parent
commit
b9fce932a1

+ 2 - 3
tutorials/chatbot/data_pipelines/config.py

@@ -4,10 +4,9 @@
 import yaml
 import os
 
-def load_config():
+def load_config(config_path: str = "./config.yaml"):
     # Read the YAML configuration file
-    file_path = "./config.yaml"
-    with open(file_path, "r") as file:
+    with open(config_path, "r") as file:
         config = yaml.safe_load(file)
     # Set the API key from the environment variable
     config["api_key"] = os.environ["OCTOAI_API_TOKEN"]

+ 9 - 4
tutorials/chatbot/data_pipelines/generate_question_answers.py

@@ -67,7 +67,7 @@ async def main(context):
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}")
 
-def parse_arguments(context):
+def parse_arguments():
     # Define command line arguments for the script
     parser = argparse.ArgumentParser(
         description="Generate question/answer pairs from documentation."
@@ -75,7 +75,7 @@ def parse_arguments(context):
     parser.add_argument(
         "-t", "--total_questions",
         type=int,
-        default=context["total_questions"],
+        default=10,
         help="Specify the number of question/answer pairs to generate."
     )
     parser.add_argument(
@@ -84,13 +84,18 @@ def parse_arguments(context):
         default="llama-2-70b-chat-fp16",
         help="Select the model to use for generation."
     )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="config.yaml",
+        help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
+    )
     return parser.parse_args()
 
 if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
-    context = load_config()
-    args = parse_arguments(context)
+    args = parse_arguments()
 
+    context = load_config(args.config_path)
     context["total_questions"] = args.total_questions
     context["model"] = args.model