research_analyzer.py

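"""Research Analyzer: download an arXiv paper and the arXiv papers it cites,
extract the reference list with Llama 4 Maverick via the Together API, and
chat about the combined text through a Gradio interface.

Assumes a TOGETHER_API_KEY environment variable (or an API key pasted in
place of the placeholder used below).
"""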

import os
import io
import re
import json
import time

import requests
import PyPDF2
import gradio as gr
from together import Together
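
# Third-party packages (pip install): requests PyPDF2 gradio together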

def download_pdf(url, save_path=None):
    # Only fetch arXiv-hosted URLs
    if url is None or 'arxiv.org' not in url:
        return None
    response = requests.get(url)
    response.raise_for_status()  # fail loudly instead of saving an HTML error page
    if save_path:
        with open(save_path, 'wb') as f:
            f.write(response.content)
    return response.content

def extract_arxiv_pdf_url(arxiv_url):
    # Check if the URL is already in PDF format
    if 'arxiv.org/pdf/' in arxiv_url:
        return arxiv_url
    # Extract the arXiv ID from the abs/ and html/ URL formats
    arxiv_id = None
    if 'arxiv.org/abs/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/abs/')[1].split()[0]
    elif 'arxiv.org/html/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/html/')[1].split()[0]
    if arxiv_id:
        return f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    return None  # no valid arXiv ID found
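
# Example: extract_arxiv_pdf_url("https://arxiv.org/abs/1706.03762v7")
#          -> "https://arxiv.org/pdf/1706.03762v7.pdf"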

def extract_text_from_pdf(pdf_content):
    pdf_file = io.BytesIO(pdf_content)
    reader = PyPDF2.PdfReader(pdf_file)
    text = ""
    for page in reader.pages:
        text += page.extract_text() + "\n"
    return text

def extract_references_with_llm(pdf_content):
    # Extract text from the PDF, truncating so the prompt stays within context
    text = extract_text_from_pdf(pdf_content)
    max_length = 50000
    if len(text) > max_length:
        text = text[:max_length] + "..."
    client = Together(api_key=os.environ.get("TOGETHER_API_KEY", "Your API key here"))
    # First pass: pull the arXiv citations out of the References section
    citations_response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=[
            {
                "role": "user",
                "content": f"Extract all the arXiv citations from the References section of the paper, including their title, authors, and origins. Paper: {text}"
            }
        ],
        temperature=0.3,
    )
    citations = citations_response.choices[0].message.content
    # Second pass: normalize the citation list into JSON with titles and arXiv IDs
    prompt = f"""
Extract the arXiv ID from each citation in the list provided, including preprint arXiv IDs. If a citation has no arXiv ID, skip it.
Here are some examples of the arXiv ID format:
1. arXiv preprint arXiv:1607.06450, where 1607.06450 is the arXiv ID.
2. CoRR, abs/1409.0473, where 1409.0473 is the arXiv ID.
Then, return a JSON array of objects with 'title' and 'ID' fields strictly in the following format; only include a paper title if its arXiv ID was extracted:
Output format: [{{"title": "Paper Title", "ID": "arXiv ID"}}]
DO NOT return any other text.
List of citations:
{citations}
"""
    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    response_json = response.choices[0].message.content
    # Parse the model's JSON output into a Python list
    references = []
    try:
        references = json.loads(response_json)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    return references
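
# Expected shape of the parsed result, per the prompt's output format:
# [{"title": "Paper Title", "ID": "1607.06450"}, ...]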

# Check whether ref_id is a valid arXiv ID
def is_valid_arxiv_id(ref_id):
    # New-style IDs look like "1234.56789"; old-style numeric parts like "1234567"
    return bool(re.match(r'^\d{4}\.\d{4,5}$', ref_id) or re.match(r'^\d{7}$', ref_id))
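
# Examples: is_valid_arxiv_id("1607.06450") -> True (new-style ID)
#           is_valid_arxiv_id("hep-th/9901001") -> False (old-style IDs keep a
#           category prefix, which this pattern does not match)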

def download_arxiv_paper_and_citations(arxiv_url, download_dir, progress=None):
    os.makedirs(download_dir, exist_ok=True)
    if progress:
        progress("Downloading main paper PDF...")
    # Download the main paper PDF
    pdf_url = extract_arxiv_pdf_url(arxiv_url)
    main_pdf_path = os.path.join(download_dir, 'main_paper.pdf')
    try:
        main_pdf_content = download_pdf(pdf_url, main_pdf_path)
    except requests.RequestException:
        main_pdf_content = None
    if main_pdf_content is None:
        if progress:
            progress("Invalid URL. Valid example: https://arxiv.org/abs/1706.03762v7")
        return None, 0
    if progress:
        progress("Main paper downloaded. Extracting references...")
    # Extract references using the LLM
    references = extract_references_with_llm(main_pdf_content)
    if progress:
        progress(f"Found {len(references)} references. Downloading...")
        time.sleep(1)
    # Download the referenced PDFs
    all_pdf_paths = [main_pdf_path]
    for i, reference in enumerate(references):
        ref_title = reference.get("title")
        ref_id = reference.get("ID")
        if ref_id and is_valid_arxiv_id(ref_id):
            ref_url = f'https://arxiv.org/pdf/{ref_id}'
            # Strip characters that are unsafe in file names
            safe_title = re.sub(r'[\\/:*?"<>|]', '_', ref_title or ref_id)
            ref_pdf_path = os.path.join(download_dir, f'{safe_title}.pdf')
            if progress:
                progress(f"Downloading reference {i+1}/{len(references)}... {ref_title}")
                time.sleep(0.2)
            try:
                download_pdf(ref_url, ref_pdf_path)
                all_pdf_paths.append(ref_pdf_path)
            except Exception as e:
                if progress:
                    progress(f"Error downloading {ref_url}: {str(e)}")
                    time.sleep(0.2)
    # Persist the list of all downloaded PDF paths
    paths_file = os.path.join(download_dir, 'pdf_paths.txt')
    with open(paths_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(all_pdf_paths))
    if progress:
        progress(f"All papers downloaded. Total references: {len(references)}")
        time.sleep(1)
    return paths_file, len(references)

def ingest_paper_with_llama(paths_file, progress=None):
    total_text = ""
    total_word_count = 0
    if progress:
        progress("Ingesting paper content...")
    with open(paths_file, 'r', encoding='utf-8') as f:
        pdf_paths = f.read().splitlines()
    for i, pdf_path in enumerate(pdf_paths):
        if progress:
            progress(f"Ingesting PDF {i+1}/{len(pdf_paths)}...")
        try:
            with open(pdf_path, 'rb') as pdf_file:
                pdf_content = pdf_file.read()
            text = extract_text_from_pdf(pdf_content)
        except Exception as e:
            # Skip files that failed to download or are not parseable PDFs
            if progress:
                progress(f"Skipping unreadable PDF {pdf_path}: {e}")
            continue
        total_text += text + "\n\n"
        total_word_count += len(text.split())
    if progress:
        progress("Paper ingestion complete!")
    return total_text, total_word_count
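
# Note: the chat handler below inlines this entire concatenated corpus into its
# system prompt, which assumes the combined text fits in the model's context window.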

def gradio_interface():
    # Holds the ingested corpus so the chat handler can reference it across calls
    paper_content = {"text": ""}

    def process(arxiv_url, progress=gr.Progress()):
        download_dir = 'downloads'
        progress(0, "Starting download...")
        paper_path, num_references = download_arxiv_paper_and_citations(
            arxiv_url, download_dir, lambda msg: progress(0.3, msg))
        if paper_path is None:
            return "Invalid URL. Valid example: https://arxiv.org/abs/1706.03762v7"
        paper_content["text"], total_word_count = ingest_paper_with_llama(
            paper_path, lambda msg: progress(0.7, msg))
        progress(1.0, "Ready for chat!")
        return (f"Total {total_word_count} words and {num_references} references ingested. "
                "You can now chat about the paper and its citations.")

    def respond(message, history):
        user_message = message
        if not user_message:
            yield history, ""
            return
        # Append the user message immediately so it appears while the reply streams
        history.append([user_message, ""])
        client = Together(api_key=os.environ.get("TOGETHER_API_KEY", "Your API key here"))
        # The system prompt carries the entire ingested corpus as context
        system_prompt = f"""
You are a research assistant that has access to the paper references below.
Answer questions based on your knowledge of these references.
If you do not know the answer, say you don't know.
Paper references: {paper_content["text"]}
"""
        stream = client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=0.3,
            stream=True  # Enable streaming
        )
        # Accumulate the streamed chunks and update the chat incrementally
        full_response = ""
        for chunk in stream:
            if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                full_response += chunk.choices[0].delta.content
                # Update the last message in history with the partial response
                history[-1][1] = full_response
                yield history, ""

    # Reset helper for the chat (not wired to a control in this layout)
    def clear_chat_history():
        return [], ""

    with gr.Blocks(css=".orange-button {background-color: #FF7C00 !important; color: white;}") as demo:
        gr.Markdown("# Research Analyzer")
        with gr.Column():
            input_text = gr.Textbox(label="ArXiv URL")
            status_text = gr.Textbox(label="Status", interactive=False)
            submit_btn = gr.Button("Ingest", elem_classes="orange-button")
            submit_btn.click(fn=process, inputs=input_text, outputs=status_text)
            gr.Markdown("## Chat with Llama")
            chatbot = gr.Chatbot()
            with gr.Row():
                msg = gr.Textbox(label="Ask about the paper", scale=5)
                submit_chat_btn = gr.Button("➤", elem_classes="orange-button", scale=1)
            submit_chat_btn.click(respond, [msg, chatbot], [chatbot, msg])
            msg.submit(respond, [msg, chatbot], [chatbot, msg])

    # Copy helper for the last assistant reply (not wired to a control in this layout)
    def copy_last_response(history):
        if history and len(history) > 0:
            last_response = history[-1][1]
            return gr.update(value=last_response)
        return gr.update(value="No response to copy")

    demo.launch()

if __name__ == "__main__":
    gradio_interface()