Browse Source

initial commit

Radu Boncea 2 years ago
commit
00872654be
6 changed files with 1584 additions and 0 deletions
  1. 1 0
      .env
  2. 1 0
      .gitignore
  3. 1435 0
      Cargo.lock
  4. 14 0
      Cargo.toml
  5. 8 0
      history.csv
  6. 125 0
      src/main.rs

+ 1 - 0
.env

@@ -0,0 +1 @@
+# NOTE(review): never commit a real API key — keep .env out of version control (add it to .gitignore)
+OPENAI_API_KEY=sk-.....

+ 1 - 0
.gitignore

@@ -0,0 +1 @@
+/target

File diff suppressed because it is too large
+ 1435 - 0
Cargo.lock


+ 14 - 0
Cargo.toml

@@ -0,0 +1,14 @@
+[package]
+name = "openai-rusty"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+async-openai = "0.10.3"
+csv = "1.2.1"
+dotenv = "0.15.0"
+tokio = { version = "1.27.0", features = ["full"] }
+tokio-test = "0.4.2" # NOTE(review): test-only helper — belongs under [dev-dependencies]
+

+ 8 - 0
history.csv

@@ -0,0 +1,8 @@
+user,"what is your name and role?
+"
+user,"What are the planets that orbit our sun?
+"
+user,"Which one is the biggest?
+"
+user,"what was my second question?
+"

+ 125 - 0
src/main.rs

@@ -0,0 +1,125 @@
+use dotenv::dotenv;
+use std::error::Error;
+
+
+mod gptapi {
+    use async_openai::{
+        types::{
+            ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
+            CreateChatCompletionRequestArgs, CreateChatCompletionResponse, Role,
+        },
+        Client,
+    };
+    use csv::Reader;
+    use std::fs::OpenOptions;
+
+    pub struct ChatSession {
+        request: CreateChatCompletionRequest,
+        client: Client,
+        history_path: String,
+    }
+
+    impl ChatSession {
+        pub fn new(model: String, max_tokens: u16, history_path: String) -> Self {
+            let request = CreateChatCompletionRequestArgs::default()
+                .max_tokens(max_tokens)
+                .model(model)
+                .messages([])
+                .build();
+
+            if request.is_err() {
+                panic!("Error creating request");
+            }
+
+            let request = request.unwrap();
+            let client = Client::new();
+
+            let mut instance = Self { request, client, history_path };
+            instance.load_history_messages();
+            instance
+        }
+
+        pub fn add_message(
+            &mut self,
+            message: String,
+            role: Role,
+            ) -> Result<(), Box<dyn std::error::Error>> {
+            let message = ChatCompletionRequestMessageArgs::default()
+                .role(role)
+                .content(message)
+                .build()?;
+            self.request.messages.push(message);
+            Ok(())
+        }
+
+        pub async fn chat(&mut self, message: &str
+            ) -> Result<CreateChatCompletionResponse, Box<dyn std::error::Error>> {
+            self.add_message(message.to_string(), Role::User)?;
+            self.save_history_message(message, Role::User);
+            let response = self.client.chat().create(self.request.clone()).await?;
+            Ok(response)
+        }
+
+        pub fn load_history_messages(&mut self) {
+            let mut rdr = Reader::from_path(self.history_path.clone()).unwrap();
+            
+            for result in rdr.records() {
+                let record = result.unwrap();
+                let role = match &record[0] as &str {
+                    "user" => Role::User,
+                    "system" => Role::System,
+                    "assistant" => Role::Assistant,
+                    _ => Role::User,
+                };
+                self.add_message(record[1].to_string(), role).unwrap();
+
+            }
+        }
+    
+        pub fn save_history_message(&mut self, message: &str, role: Role) {
+            let file = OpenOptions::new()
+                            .write(true)
+                            .append(true)
+                            .open(self.history_path.clone())
+                            .unwrap();
+    
+            let mut wtr = csv::Writer::from_writer(file);
+            let role = match role {
+                Role::User => "user",
+                Role::System => "system",
+                Role::Assistant => "assistant",
+            };
+            wtr.write_record(&[role, message]).unwrap();
+        }
+
+    }
+}
+
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn Error>> {
+    use gptapi::ChatSession;
+
+    dotenv().ok();
+
+    let history_path =  "history.csv".to_string();
+
+    let mut session = ChatSession::new(
+        "gpt-3.5-turbo".to_string(), 
+        512u16, 
+        history_path);
+
+
+    //loop and take input from the user
+    loop {
+        let mut input = String::new();
+        println!("PROMPT: ");
+        std::io::stdin().read_line(&mut input)?;
+        let response = session.chat(input.as_str()).await?;
+        let message = response.choices[0].message.content.clone();
+
+        print!("<{}>: {}\n", "Bot", message);
+        
+    }
+
+}