email_agent.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. from google.auth.transport.requests import Request
  2. from google_auth_oauthlib.flow import InstalledAppFlow
  3. from googleapiclient.discovery import build
  4. from email.mime.text import MIMEText
  5. from email.mime.multipart import MIMEMultipart
  6. from email.mime.base import MIMEBase
  7. from email import encoders
  8. from bs4 import BeautifulSoup
  9. import os
  10. import pytz
  11. import base64
  12. import pickle
  13. from datetime import datetime, timezone
  14. import json
  15. import ollama
  16. from pypdf import PdfReader
  17. from pathlib import Path
  18. SCOPES = ['https://www.googleapis.com/auth/gmail.readonly', 'https://www.googleapis.com/auth/gmail.compose']
  19. def authenticate_gmail(user_email):
  20. creds = None
  21. token_file = f'token_{user_email}.pickle' # Unique token file for each user
  22. # Load the user's token if it exists
  23. if os.path.exists(token_file):
  24. with open(token_file, 'rb') as token:
  25. creds = pickle.load(token)
  26. # If no valid credentials, prompt the user to log in
  27. if not creds or not creds.valid:
  28. if creds and creds.expired and creds.refresh_token:
  29. creds.refresh(Request())
  30. else:
  31. flow = InstalledAppFlow.from_client_secrets_file('credentials.json', SCOPES)
  32. creds = flow.run_console()
  33. # Save the new credentials to a user-specific token file
  34. with open(token_file, 'wb') as token:
  35. pickle.dump(creds, token)
  36. # Build the Gmail API service
  37. service = build('gmail', 'v1', credentials=creds)
  38. return service
  39. def num_of_emails(query=''):
  40. response = service.users().messages().list(
  41. userId='me',
  42. q=query,
  43. maxResults=1).execute()
  44. return response.get('resultSizeEstimate', 0)
  45. def list_emails(query='', max_results=100):
  46. emails = []
  47. next_page_token = None
  48. while True:
  49. response = service.users().messages().list(
  50. userId=user_id,
  51. maxResults=max_results,
  52. pageToken=next_page_token,
  53. q=query
  54. ).execute()
  55. if 'messages' in response:
  56. for msg in response['messages']:
  57. sender, subject, received_time = get_email_info(msg['id'])
  58. emails.append(
  59. {
  60. "message_id": msg['id'],
  61. "sender": sender,
  62. "subject": subject,
  63. "received_time": received_time
  64. }
  65. )
  66. next_page_token = response.get('nextPageToken')
  67. if not next_page_token:
  68. break
  69. return emails
  70. def get_email_detail(detail, which=''):
  71. if detail == 'body':
  72. return get_email_body(which)
  73. elif detail == 'attachment':
  74. return get_email_attachments(which)
  75. def get_email_body(message_id):
  76. try:
  77. message = service.users().messages().get(
  78. userId=user_id,
  79. id=message_id,
  80. format='full').execute()
  81. # Recursive function to extract the parts
  82. def extract_parts(payload):
  83. text_body = ""
  84. if 'parts' in payload:
  85. for part in payload['parts']:
  86. return extract_parts(part)
  87. else:
  88. mime_type = payload.get('mimeType')
  89. body = payload.get('body', {}).get('data')
  90. if mime_type == 'text/html':
  91. decoded_body = base64.urlsafe_b64decode(body).decode('utf-8')
  92. soup = BeautifulSoup(decoded_body, 'html.parser')
  93. text_body = soup.get_text().strip()
  94. elif mime_type == 'text/plain':
  95. decoded_body = base64.urlsafe_b64decode(body).decode('utf-8')
  96. text_body = decoded_body
  97. return text_body
  98. return extract_parts(message['payload'])
  99. except Exception as e:
  100. print(f"An error occurred: {e}")
  101. return None
  102. def parse_message(message):
  103. payload = message['payload']
  104. headers = payload.get("headers")
  105. subject = None
  106. sender = None
  107. for header in headers:
  108. if header['name'] == 'Subject':
  109. subject = header['value']
  110. elif header['name'] == 'From':
  111. sender = header['value']
  112. internal_date = message.get('internalDate')
  113. utc_time = datetime.fromtimestamp(int(internal_date) / 1000, tz=timezone.utc)
  114. # Convert UTC to the specified timezone
  115. local_timezone = pytz.timezone("America/Los_Angeles")
  116. local_time = utc_time.astimezone(local_timezone)
  117. # Format the local time as a string
  118. received_time = local_time.strftime('%Y-%m-%d %H:%M:%S %Z')
  119. # Check if the email is plain text or multipart
  120. if 'parts' in payload:
  121. # Multipart message - find the text/plain or text/html part
  122. for part in payload['parts']:
  123. if part['mimeType'] == 'text/plain' or part['mimeType'] == 'text/html': # You can also look for 'text/html'
  124. data = part['body']['data']
  125. body = base64.urlsafe_b64decode(data).decode('utf-8')
  126. return sender, subject, received_time, body
  127. elif part['mimeType'] in ['multipart/related', 'multipart/mixed', 'multipart/alternative']:
  128. return sender, subject, received_time, get_email_body(message.get('id'))
  129. else:
  130. # Single part message
  131. data = payload['body']['data']
  132. body = base64.urlsafe_b64decode(data).decode('utf-8')
  133. return sender, subject, received_time, body
  134. def get_email_info(msg_id):
  135. message = service.users().messages().get(
  136. userId=user_id,
  137. id=msg_id,
  138. format='full').execute()
  139. sender, subject, received_time, body = parse_message(message)
  140. return sender, subject, received_time
  141. def reply_email(message_id, reply_text):
  142. # Fetch the original message
  143. original_message = service.users().messages().get(
  144. userId=user_id,
  145. id=message_id,
  146. format='full').execute()
  147. # Get headers
  148. headers = original_message['payload']['headers']
  149. subject = None
  150. to = None
  151. for header in headers:
  152. if header['name'] == 'Subject':
  153. subject = header['value']
  154. if header['name'] == 'From':
  155. to = header['value']
  156. # Create the reply subject
  157. if not subject.startswith("Re: "):
  158. subject = "Re: " + subject
  159. # Compose the reply message
  160. reply_message = MIMEText(reply_text)
  161. reply_message['to'] = to
  162. reply_message['from'] = user_id
  163. reply_message['subject'] = subject
  164. reply_message['In-Reply-To'] = message_id
  165. # Encode and send the message
  166. raw_message = base64.urlsafe_b64encode(reply_message.as_bytes()).decode("utf-8")
  167. body = {'raw': raw_message,
  168. 'threadId': original_message['threadId']}
  169. sent_message = service.users().messages().send(
  170. userId=user_id,
  171. body=body).execute()
  172. print("Reply sent. Message ID:", sent_message['id'])
  173. def forward_email(message_id, forward_to, email_body=None):
  174. """
  175. Forwards an email, preserving the original MIME type, including multipart/related.
  176. """
  177. # Get the original message in 'full' format
  178. original_message = service.users().messages().get(
  179. userId=user_id,
  180. id=message_id,
  181. format='full').execute()
  182. # Extract the payload and headers
  183. payload = original_message.get('payload', {})
  184. headers = payload.get('headers', [])
  185. parts = payload.get('parts', [])
  186. # Get the Subject
  187. subject = next((header['value'] for header in headers if header['name'].lower() == 'subject'), 'No Subject')
  188. # Create a new MIME message for forwarding
  189. mime_message = MIMEMultipart(payload.get('mimeType', 'mixed').split('/')[-1])
  190. mime_message['To'] = forward_to
  191. mime_message['Subject'] = f"Fwd: {subject}"
  192. # Add the optional custom email body
  193. if email_body:
  194. mime_message.attach(MIMEText(email_body, 'plain'))
  195. # Function to fetch attachment data by attachmentId
  196. def fetch_attachment_data(attachment_id, message_id):
  197. attachment = service.users().messages().attachments().get(
  198. userId=user_id, messageId=message_id, id=attachment_id
  199. ).execute()
  200. return base64.urlsafe_b64decode(attachment['data'])
  201. # Rebuild MIME structure
  202. def rebuild_parts(parts):
  203. """
  204. Recursively rebuild MIME parts.
  205. """
  206. if not parts:
  207. return None
  208. for part in parts:
  209. part_mime_type = part.get('mimeType', 'text/plain')
  210. part_body = part.get('body', {})
  211. part_data = part_body.get('data', '')
  212. part_parts = part.get('parts', []) # Sub-parts for multipart types
  213. filename = part.get('filename')
  214. attachment_id = part_body.get('attachmentId')
  215. if part_mime_type.startswith('multipart/'):
  216. # Rebuild nested multipart
  217. sub_multipart = MIMEMultipart(part_mime_type.split('/')[-1])
  218. sub_parts = rebuild_parts(part_parts)
  219. if sub_parts:
  220. for sub_part in sub_parts:
  221. sub_multipart.attach(sub_part)
  222. yield sub_multipart
  223. elif filename and attachment_id:
  224. # Handle attachments
  225. decoded_data = fetch_attachment_data(attachment_id, message_id)
  226. attachment = MIMEBase(*part_mime_type.split('/'))
  227. attachment.set_payload(decoded_data)
  228. encoders.encode_base64(attachment)
  229. attachment.add_header('Content-Disposition', f'attachment; filename="{filename}"')
  230. yield attachment
  231. else:
  232. if part_data:
  233. # Decode and attach non-multipart parts
  234. decoded_data = base64.urlsafe_b64decode(part_data)
  235. if part_mime_type == 'text/plain':
  236. yield MIMEText(decoded_data.decode('utf-8'), 'plain')
  237. elif part_mime_type == 'text/html':
  238. yield MIMEText(decoded_data.decode('utf-8'), 'html')
  239. # Rebuild the main MIME structure
  240. rebuilt_parts = rebuild_parts(parts)
  241. if rebuilt_parts:
  242. for rebuilt_part in rebuilt_parts:
  243. mime_message.attach(rebuilt_part)
  244. # Encode the MIME message to base64
  245. raw = base64.urlsafe_b64encode(mime_message.as_bytes()).decode('utf-8')
  246. # Send the email
  247. forward_body = {'raw': raw}
  248. sent_message = service.users().messages().send(userId=user_id, body=forward_body).execute()
  249. print(f"Message forwarded successfully! Message ID: {sent_message['id']}")
  250. def send_email(action, to, subject, body="", email_id=""):
  251. if action == "compose":
  252. message = MIMEText(body)
  253. message['to'] = to
  254. message['from'] = user_id
  255. message['subject'] = subject
  256. # Encode and send the message
  257. raw_message = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
  258. body = {'raw': raw_message}
  259. sent_message = service.users().messages().send(
  260. userId=user_id,
  261. body=body).execute()
  262. return sent_message['id']
  263. elif action == "reply": # reply or forward; a message id is needed
  264. reply_email(email_id, body)
  265. elif action == "forward":
  266. forward_email(email_id, to, body)
  267. def create_draft(action, to, subject, body="", email_id=""):
  268. if action == "new":
  269. message = MIMEText(body)
  270. message['to'] = to
  271. message['from'] = user_id
  272. message['subject'] = subject
  273. encoded_message = base64.urlsafe_b64encode(message.as_bytes()).decode()
  274. draft_body = {'message': {'raw': encoded_message}}
  275. draft = service.users().drafts().create(
  276. userId=user_id,
  277. body=draft_body).execute()
  278. print(f"Draft created with ID: {draft['id']}")
  279. return draft['id']
  280. elif action == "reply":
  281. return create_reply_draft(email_id, body)
  282. elif action == "forward":
  283. return create_forward_draft(email_id, to, body)
  284. else:
  285. return
  286. def create_reply_draft(message_id, reply_text):
  287. # Fetch the original message
  288. original_message = service.users().messages().get(
  289. userId=user_id,
  290. id=message_id,
  291. format='full').execute()
  292. # Get headers
  293. headers = original_message['payload']['headers']
  294. subject = None
  295. to = None
  296. for header in headers:
  297. if header['name'] == 'Subject':
  298. subject = header['value']
  299. if header['name'] == 'From':
  300. to = header['value']
  301. # Create the reply subject
  302. if not subject.startswith("Re: "):
  303. subject = "Re: " + subject
  304. # Compose the reply message
  305. reply_message = MIMEText(reply_text)
  306. reply_message['to'] = to
  307. reply_message['from'] = user_id
  308. reply_message['subject'] = subject
  309. reply_message['In-Reply-To'] = message_id
  310. encoded_message = base64.urlsafe_b64encode(reply_message.as_bytes()).decode()
  311. draft_body = {'message': {'raw': encoded_message, 'threadId': original_message['threadId']}}
  312. draft = service.users().drafts().create(userId=user_id, body=draft_body).execute()
  313. return draft['id']
  314. def create_forward_draft(message_id, recipient_email, custom_message=None):
  315. # Get the original message
  316. original_message = service.users().messages().get(
  317. userId=user_id,
  318. id=message_id,
  319. format='raw').execute()
  320. # Decode the raw message
  321. raw_message = base64.urlsafe_b64decode(original_message['raw'].encode('utf-8'))
  322. # Prepare the forward header and optional custom message
  323. forward_header = f"----- Forwarded message -----\nFrom: {recipient_email}\n\n"
  324. if custom_message:
  325. forward_header += f"{custom_message}\n\n"
  326. # Combine the forward header with the original message
  327. new_message = forward_header + raw_message.decode('utf-8')
  328. # Encode the combined message into base64 format
  329. encoded_message = base64.urlsafe_b64encode(new_message.encode('utf-8')).decode('utf-8')
  330. draft_body = {'message': {'raw': encoded_message, 'threadId': original_message['threadId']}}
  331. draft = service.users().drafts().create(userId=user_id, body=draft_body).execute()
  332. print(f"Forward draft created with ID: {draft['id']}")
  333. return draft['id']
  334. def send_draft(id):
  335. sent_message = service.users().drafts().send(
  336. userId=user_id,
  337. body={'id': id}
  338. ).execute()
  339. return f"Draft sent with email ID: {sent_message['id']}"
  340. def get_pdf_summary(file_name):
  341. text = pdf2text(file_name)
  342. print("Calling Llama to generate a summary...")
  343. response = llama31(text, "Generate a summary of the input text in 5 sentences.")
  344. return response
  345. def get_email_attachments(message_id, mime_type='application/pdf'):
  346. attachments = []
  347. # Helper function to process email parts
  348. def process_parts(parts):
  349. for part in parts:
  350. if part['mimeType'] in ['multipart/related', 'multipart/mixed', 'multipart/alternative']:
  351. # Recursively process nested parts
  352. if 'parts' in part:
  353. process_parts(part['parts'])
  354. elif 'filename' in part and part['filename']:
  355. if part['mimeType'] == mime_type: # Check for the desired MIME type
  356. attachment_id = part['body'].get('attachmentId')
  357. if attachment_id:
  358. # Get the attachment data
  359. attachment = service.users().messages().attachments().get(
  360. userId=user_id,
  361. messageId=message_id,
  362. id=attachment_id
  363. ).execute()
  364. # Decode the attachment content
  365. file_data = base64.urlsafe_b64decode(attachment['data'].encode('UTF-8'))
  366. with open(part['filename'], "wb") as f:
  367. f.write(file_data)
  368. # Save the attachment information
  369. attachments.append(
  370. {'filename': part['filename'],
  371. 'data': file_data,
  372. 'size': attachment.get('size', 0)
  373. })
  374. # Retrieve the email message
  375. message = service.users().messages().get(
  376. userId=user_id,
  377. id=message_id,
  378. format='full').execute()
  379. payload = message['payload']
  380. # Start processing the parts
  381. if 'parts' in payload:
  382. process_parts(payload['parts'])
  383. rslt = ""
  384. for a in attachments:
  385. rslt += f"{a['filename']} - {a['size']} bytes\n"
  386. return rslt #attachments
  387. def pdf2text(file):
  388. text = ''
  389. try:
  390. with Path(file).open("rb") as f:
  391. reader = PdfReader(f)
  392. text = "\n\n".join([page.extract_text() for page in reader.pages])
  393. except Exception as e:
  394. raise f"Error reading the PDF file: {str(e)}"
  395. print(f"\nPDF text length: {len(text)}\n")
  396. return text
  397. user_email = None
  398. service = None
  399. user_id = 'me'
  400. def set_email_service(gmail):
  401. global user_email
  402. global service
  403. user_email = gmail
  404. service = authenticate_gmail(user_email)
  405. class Agent:
  406. def __init__(self, system_prompt=""):
  407. self.system_prompt = system_prompt
  408. self.messages = []
  409. # agent-specific short term memory, used to answer follow up questions AFTER a list of emails is found matching user's query
  410. self.emails = []
  411. self.draft_id = None
  412. if self.system_prompt:
  413. self.messages.append({"role": "system", "content": system_prompt})
  414. def __call__(self, user_prompt_or_tool_result, is_tool_call=False):
  415. # if it's tool call result, use "ipython" instead of "user" for the role
  416. self.messages.append({"role": ("ipython" if is_tool_call else "user"), "content": user_prompt_or_tool_result})
  417. result = self.llama()
  418. print(f"\nLlama returned: {result}.")
  419. if type(result) == dict: # result is a dict only if it's a tool call spec
  420. function_name = result["function_name"]
  421. func = globals()[function_name]
  422. parameters = result["parameters"]
  423. if function_name == "get_email_detail":
  424. # TODO: parse which - valid values are first, second,
  425. # third, fourth, last, from xxx
  426. if 'id' in parameters.keys():
  427. parameters['which'] = parameters['id']
  428. del parameters['id'] # per the function spec
  429. elif 'which' in parameters.keys():
  430. if 'from ' in parameters['which']:
  431. sender = parameters['which'].split('from ')[-1]
  432. for email in self.emails:
  433. if email['sender'].find(sender) != -1:
  434. parameters['which'] = email['message_id']
  435. break
  436. if 'subject ' in parameters['which']:
  437. subject = parameters['which'].split('subject ')[-1]
  438. # exact match beats substring
  439. for email in self.emails:
  440. if email['subject'].upper() == subject.upper():
  441. parameters['which'] = email['message_id']
  442. break
  443. elif email['subject'].upper().find(subject.upper()) != -1:
  444. parameters['which'] = email['message_id']
  445. elif 'id_' in parameters['which']:
  446. parameters['which'] = parameters['which'].split('id_')[-1]
  447. else:
  448. parameters['which'] = self.emails[-1]['message_id']
  449. elif function_name == "send_draft":
  450. parameters['id'] = self.draft_id
  451. print(f"\nCalling tool to access Gmail API: {function_name}, {parameters}...")
  452. result = func(**parameters)
  453. print(f"\nTool calling returned: {result}")
  454. # convert function calling result to concise summary, offering interactive follow ups,
  455. # for smooth and user friendly experience
  456. if function_name == 'list_emails':
  457. self.emails = result
  458. num = len(result)
  459. if num == 0:
  460. output = "I couldn't find any such emails. What else would you like to do?"
  461. elif num <= 5:
  462. output = f"I found {num} email{'s' if num > 1 else ''} matching your query:\n"
  463. for i, email in enumerate(result, start=1):
  464. output += f"{i}. From: {email['sender']}, Subject: {email['subject']}, Received on: {email['received_time']}\n"
  465. else:
  466. output = f"I found {num} emails matching your query. Here are the first 5 emails:\n"
  467. for i in range(1, 6):
  468. output += f"{i}. From: {result[i-1]['sender']}, Subject: {result[i-1]['subject']}, Received on: {result[i-1]['received_time']}\n"
  469. elif function_name == "get_email_detail":
  470. output = result
  471. elif function_name == "get_pdf_summary":
  472. output = result
  473. elif function_name == "send_email":
  474. output = "Email sent."
  475. elif function_name == "create_draft":
  476. output = "Draft created."
  477. self.draft_id = result
  478. elif function_name == "send_draft":
  479. output = result
  480. print(f"\n-------------------------\n\nAgent: {output}\n")
  481. else:
  482. output = result # direct text, not JSON, response by Llama
  483. # adding this may cause Llama to hallucinate when answering
  484. # follow up questions. e.g. "do i have emails with attachments
  485. # larger than 20mb" got right tool calling response, then
  486. # follow up "larger than 10mb" got hallucinated response.
  487. # self.messages.append({"role": "assistant", "content": output})
  488. # this mitigates the hallucination
  489. self.messages.append({"role": "assistant", "content": str(result)})
  490. return output
  491. def llama(self):
  492. response = ollama.chat(model='llama3.1',
  493. messages = self.messages,
  494. options = {
  495. "temperature": 0.0
  496. }
  497. )
  498. result = response['message']['content']
  499. try:
  500. res = json.loads(result.split("<|python_tag|>")[-1])
  501. function_name = res['name']
  502. parameters = res['parameters']
  503. return {"function_name": function_name,
  504. "parameters": parameters}
  505. except:
  506. return result
  507. def llama31(user_prompt: str, system_prompt = ""):
  508. response = ollama.chat(model='llama3.1',
  509. messages=[
  510. {"role": "system", "content": system_prompt},
  511. {"role": "user", "content": user_prompt},
  512. ],
  513. )
  514. return response['message']['content']