// main.rs — interactive CLI chat against the OpenAI chat-completions API
use dotenv::dotenv;
use std::error::Error;
  3. mod gptapi {
  4. use async_openai::{
  5. types::{
  6. ChatCompletionRequestMessageArgs, CreateChatCompletionRequest,
  7. CreateChatCompletionRequestArgs, CreateChatCompletionResponse, Role,
  8. },
  9. Client,
  10. };
  11. use csv::Reader;
  12. use std::fs::OpenOptions;
  13. pub struct ChatSession {
  14. request: CreateChatCompletionRequest,
  15. client: Client,
  16. history_path: String,
  17. }
  18. impl ChatSession {
  19. pub fn new(model: String, max_tokens: u16, history_path: String) -> Self {
  20. let request = CreateChatCompletionRequestArgs::default()
  21. .max_tokens(max_tokens)
  22. .model(model)
  23. .messages([])
  24. .build();
  25. if request.is_err() {
  26. panic!("Error creating request");
  27. }
  28. let request = request.unwrap();
  29. let client = Client::new();
  30. let mut instance = Self { request, client, history_path };
  31. instance.load_history_messages();
  32. instance
  33. }
  34. pub fn add_message(
  35. &mut self,
  36. message: String,
  37. role: Role,
  38. ) -> Result<(), Box<dyn std::error::Error>> {
  39. let message = ChatCompletionRequestMessageArgs::default()
  40. .role(role)
  41. .content(message)
  42. .build()?;
  43. self.request.messages.push(message);
  44. Ok(())
  45. }
  46. pub async fn chat(&mut self, message: &str
  47. ) -> Result<CreateChatCompletionResponse, Box<dyn std::error::Error>> {
  48. self.add_message(message.to_string(), Role::User)?;
  49. self.save_history_message(message, Role::User);
  50. let response = self.client.chat().create(self.request.clone()).await?;
  51. Ok(response)
  52. }
  53. pub fn load_history_messages(&mut self) {
  54. let mut rdr = Reader::from_path(self.history_path.clone()).unwrap();
  55. for result in rdr.records() {
  56. let record = result.unwrap();
  57. let role = match &record[0] as &str {
  58. "user" => Role::User,
  59. "system" => Role::System,
  60. "assistant" => Role::Assistant,
  61. _ => Role::User,
  62. };
  63. self.add_message(record[1].to_string(), role).unwrap();
  64. }
  65. }
  66. pub fn save_history_message(&mut self, message: &str, role: Role) {
  67. let file = OpenOptions::new()
  68. .write(true)
  69. .append(true)
  70. .open(self.history_path.clone())
  71. .unwrap();
  72. let mut wtr = csv::Writer::from_writer(file);
  73. let role = match role {
  74. Role::User => "user",
  75. Role::System => "system",
  76. Role::Assistant => "assistant",
  77. };
  78. wtr.write_record(&[role, message]).unwrap();
  79. }
  80. }
  81. }
  82. #[tokio::main]
  83. async fn main() -> Result<(), Box<dyn Error>> {
  84. use gptapi::ChatSession;
  85. dotenv().ok();
  86. let history_path = "history.csv".to_string();
  87. let mut session = ChatSession::new(
  88. "gpt-3.5-turbo".to_string(),
  89. 512u16,
  90. history_path);
  91. //loop and take input from the user
  92. loop {
  93. let mut input = String::new();
  94. println!("PROMPT: ");
  95. std::io::stdin().read_line(&mut input)?;
  96. let response = session.chat(input.as_str()).await?;
  97. let message = response.choices[0].message.content.clone();
  98. print!("<{}>: {}\n", "Bot", message);
  99. }
  100. }