|
@@ -6,7 +6,6 @@ import sys
|
|
|
import time
|
|
|
|
|
|
import fire
|
|
|
-import gradio as gr
|
|
|
|
|
|
import torch
|
|
|
|
|
@@ -146,6 +145,11 @@ def main(
|
|
|
user_prompt = "\n".join(sys.stdin.readlines())
|
|
|
inference(user_prompt, temperature, top_p, top_k, max_new_tokens)
|
|
|
else:
|
|
|
+ try:
|
|
|
+ import gradio as gr
|
|
|
+ except ImportError:
|
|
|
+ raise ImportError("This part of the recipe requires gradio. Please run `pip install gradio`")
|
|
|
+
|
|
|
gr.Interface(
|
|
|
fn=inference,
|
|
|
inputs=[
|