|
@@ -0,0 +1,125 @@
|
|
|
|
+use dotenv::dotenv;
|
|
|
|
+use std::error::Error;
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+mod gptapi {
|
|
|
|
+ use async_openai::{
|
|
|
|
+ types::{
|
|
|
|
+ ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
|
|
|
|
+ CreateChatCompletionRequestArgs, CreateChatCompletionResponse, Role,
|
|
|
|
+ },
|
|
|
|
+ Client,
|
|
|
|
+ };
|
|
|
|
+ use csv::Reader;
|
|
|
|
+ use std::fs::OpenOptions;
|
|
|
|
+
|
|
|
|
+ pub struct ChatSession {
|
|
|
|
+ request: CreateChatCompletionRequest,
|
|
|
|
+ client: Client,
|
|
|
|
+ history_path: String,
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ impl ChatSession {
|
|
|
|
+ pub fn new(model: String, max_tokens: u16, history_path: String) -> Self {
|
|
|
|
+ let request = CreateChatCompletionRequestArgs::default()
|
|
|
|
+ .max_tokens(max_tokens)
|
|
|
|
+ .model(model)
|
|
|
|
+ .messages([])
|
|
|
|
+ .build();
|
|
|
|
+
|
|
|
|
+ if request.is_err() {
|
|
|
|
+ panic!("Error creating request");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ let request = request.unwrap();
|
|
|
|
+ let client = Client::new();
|
|
|
|
+
|
|
|
|
+ let mut instance = Self { request, client, history_path };
|
|
|
|
+ instance.load_history_messages();
|
|
|
|
+ instance
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ pub fn add_message(
|
|
|
|
+ &mut self,
|
|
|
|
+ message: String,
|
|
|
|
+ role: Role,
|
|
|
|
+ ) -> Result<(), Box<dyn std::error::Error>> {
|
|
|
|
+ let message = ChatCompletionRequestMessageArgs::default()
|
|
|
|
+ .role(role)
|
|
|
|
+ .content(message)
|
|
|
|
+ .build()?;
|
|
|
|
+ self.request.messages.push(message);
|
|
|
|
+ Ok(())
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ pub async fn chat(&mut self, message: &str
|
|
|
|
+ ) -> Result<CreateChatCompletionResponse, Box<dyn std::error::Error>> {
|
|
|
|
+ self.add_message(message.to_string(), Role::User)?;
|
|
|
|
+ self.save_history_message(message, Role::User);
|
|
|
|
+ let response = self.client.chat().create(self.request.clone()).await?;
|
|
|
|
+ Ok(response)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ pub fn load_history_messages(&mut self) {
|
|
|
|
+ let mut rdr = Reader::from_path(self.history_path.clone()).unwrap();
|
|
|
|
+
|
|
|
|
+ for result in rdr.records() {
|
|
|
|
+ let record = result.unwrap();
|
|
|
|
+ let role = match &record[0] as &str {
|
|
|
|
+ "user" => Role::User,
|
|
|
|
+ "system" => Role::System,
|
|
|
|
+ "assistant" => Role::Assistant,
|
|
|
|
+ _ => Role::User,
|
|
|
|
+ };
|
|
|
|
+ self.add_message(record[1].to_string(), role).unwrap();
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ pub fn save_history_message(&mut self, message: &str, role: Role) {
|
|
|
|
+ let file = OpenOptions::new()
|
|
|
|
+ .write(true)
|
|
|
|
+ .append(true)
|
|
|
|
+ .open(self.history_path.clone())
|
|
|
|
+ .unwrap();
|
|
|
|
+
|
|
|
|
+ let mut wtr = csv::Writer::from_writer(file);
|
|
|
|
+ let role = match role {
|
|
|
|
+ Role::User => "user",
|
|
|
|
+ Role::System => "system",
|
|
|
|
+ Role::Assistant => "assistant",
|
|
|
|
+ };
|
|
|
|
+ wtr.write_record(&[role, message]).unwrap();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+#[tokio::main]
|
|
|
|
+async fn main() -> Result<(), Box<dyn Error>> {
|
|
|
|
+ use gptapi::ChatSession;
|
|
|
|
+
|
|
|
|
+ dotenv().ok();
|
|
|
|
+
|
|
|
|
+ let history_path = "history.csv".to_string();
|
|
|
|
+
|
|
|
|
+ let mut session = ChatSession::new(
|
|
|
|
+ "gpt-3.5-turbo".to_string(),
|
|
|
|
+ 512u16,
|
|
|
|
+ history_path);
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ //loop and take input from the user
|
|
|
|
+ loop {
|
|
|
|
+ let mut input = String::new();
|
|
|
|
+ println!("PROMPT: ");
|
|
|
|
+ std::io::stdin().read_line(&mut input)?;
|
|
|
|
+ let response = session.chat(input.as_str()).await?;
|
|
|
|
+ let message = response.choices[0].message.content.clone();
|
|
|
|
+
|
|
|
|
+ print!("<{}>: {}\n", "Bot", message);
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+}
|