- use dotenv::dotenv;
- use std::error::Error;
- mod gptapi {
- use async_openai::{
- types::{
- ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
- CreateChatCompletionRequestArgs, CreateChatCompletionResponse, Role,
- },
- Client,
- };
- use csv::Reader;
- use std::fs::OpenOptions;
- pub struct ChatSession {
- request: CreateChatCompletionRequest,
- client: Client,
- history_path: String,
- }
- impl ChatSession {
- pub fn new(model: String, max_tokens: u16, history_path: String) -> Self {
- let request = CreateChatCompletionRequestArgs::default()
- .max_tokens(max_tokens)
- .model(model)
- .messages([])
- .build();
- if request.is_err() {
- panic!("Error creating request");
- }
- let request = request.unwrap();
- let client = Client::new();
- let mut instance = Self { request, client, history_path };
- instance.load_history_messages();
- instance
- }
- pub fn add_message(
- &mut self,
- message: String,
- role: Role,
- ) -> Result<(), Box<dyn std::error::Error>> {
- let message = ChatCompletionRequestMessageArgs::default()
- .role(role)
- .content(message)
- .build()?;
- self.request.messages.push(message);
- Ok(())
- }
- pub async fn chat(&mut self, message: &str
- ) -> Result<CreateChatCompletionResponse, Box<dyn std::error::Error>> {
- self.add_message(message.to_string(), Role::User)?;
- self.save_history_message(message, Role::User);
- let response = self.client.chat().create(self.request.clone()).await?;
- Ok(response)
- }
- pub fn load_history_messages(&mut self) {
- let mut rdr = Reader::from_path(self.history_path.clone()).unwrap();
-
- for result in rdr.records() {
- let record = result.unwrap();
- let role = match &record[0] as &str {
- "user" => Role::User,
- "system" => Role::System,
- "assistant" => Role::Assistant,
- _ => Role::User,
- };
- self.add_message(record[1].to_string(), role).unwrap();
- }
- }
-
- pub fn save_history_message(&mut self, message: &str, role: Role) {
- let file = OpenOptions::new()
- .write(true)
- .append(true)
- .open(self.history_path.clone())
- .unwrap();
-
- let mut wtr = csv::Writer::from_writer(file);
- let role = match role {
- Role::User => "user",
- Role::System => "system",
- Role::Assistant => "assistant",
- };
- wtr.write_record(&[role, message]).unwrap();
- }
- }
- }
- #[tokio::main]
- async fn main() -> Result<(), Box<dyn Error>> {
- use gptapi::ChatSession;
- dotenv().ok();
- let history_path = "history.csv".to_string();
- let mut session = ChatSession::new(
- "gpt-3.5-turbo".to_string(),
- 512u16,
- history_path);
- //loop and take input from the user
- loop {
- let mut input = String::new();
- println!("PROMPT: ");
- std::io::stdin().read_line(&mut input)?;
- let response = session.chat(input.as_str()).await?;
- let message = response.choices[0].message.content.clone();
- print!("<{}>: {}\n", "Bot", message);
-
- }
- }