import io
import json
import os
import re
import time

import gradio as gr
import PyPDF2
import requests
from together import Together
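
# Setup note (an assumption of this sketch, not part of the original listing):
# install the third-party dependencies with `pip install requests PyPDF2 gradio together`,
# and export your Together API key as TOGETHER_API_KEY so it can be read from the
# environment below instead of being hard-coded.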


def download_pdf(url, save_path=None):
    """Download a PDF from an arXiv URL, optionally saving it to disk."""
    if url is None or 'arxiv.org' not in url:
        return None
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # fail fast on 404s instead of saving an error page
    if save_path:
        with open(save_path, 'wb') as f:
            f.write(response.content)
    return response.content
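
# Illustrative usage (the output file name is arbitrary):
#   pdf_bytes = download_pdf("https://arxiv.org/pdf/1706.03762", save_path="attention.pdf")
#   # returns the raw PDF bytes and also writes attention.pdf next to the script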


def extract_arxiv_pdf_url(arxiv_url):
    # Check if the URL is already in PDF format
    if 'arxiv.org/pdf/' in arxiv_url:
        return arxiv_url

    # Extract the arXiv ID from the other URL formats
    arxiv_id = None
    if 'arxiv.org/abs/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/abs/')[1]
    elif 'arxiv.org/html/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/html/')[1]

    if arxiv_id:
        # Drop any query string, fragment, or trailing slash after the ID
        arxiv_id = arxiv_id.split('?')[0].split('#')[0].strip('/')
        return f"https://arxiv.org/pdf/{arxiv_id}.pdf"

    return None  # No valid arXiv ID found
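
# Illustrative mappings (the second ID is a placeholder, not a real paper):
#   https://arxiv.org/abs/1706.03762v7 -> https://arxiv.org/pdf/1706.03762v7.pdf
#   https://arxiv.org/html/2403.01234  -> https://arxiv.org/pdf/2403.01234.pdf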


def extract_text_from_pdf(pdf_content):
    pdf_file = io.BytesIO(pdf_content)
    reader = PyPDF2.PdfReader(pdf_file)
    text = ""
    for page in reader.pages:
        # extract_text() can return None for image-only pages
        text += (page.extract_text() or "") + "\n"
    return text


def extract_references_with_llm(pdf_content):
    # Extract text from the PDF
    text = extract_text_from_pdf(pdf_content)

    # Truncate if too long
    max_length = 50000
    if len(text) > max_length:
        text = text[:max_length] + "..."

    client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))

    # First pass: pull the raw citation list out of the References section
    citations_response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=[
            {
                "role": "user",
                "content": f"Extract all the arXiv citations from the References section of the paper, including their titles, authors, and origins. Paper: {text}"
            }
        ],
        temperature=0.3,
    )
    citations = citations_response.choices[0].message.content

    # Second pass: prompt Llama 4 to turn the citation list into structured JSON
    prompt = f"""
    Extract the arXiv ID from each citation in the list provided, including preprint arXiv IDs. If a citation has no arXiv ID, skip it.

    Here are some examples of the arXiv ID format:
    1. arXiv preprint arXiv:1607.06450, where 1607.06450 is the arXiv ID.
    2. CoRR, abs/1409.0473, where 1409.0473 is the arXiv ID.
    Then, return a JSON array of objects with 'title' and 'ID' fields strictly in the following format; only return a paper title if its arXiv ID was extracted:
    Output format: [{{"title": "Paper Title", "ID": "arXiv ID"}}]
    DO NOT return any other text.
    List of citations:
    {citations}
    """

    response = client.chat.completions.create(
        model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
        messages=[
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=0.3,
    )
    response_json = response.choices[0].message.content

    # Parse the JSON string into a Python list
    references = []
    try:
        references = json.loads(response_json)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    return references
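
# Illustrative shape of the parsed result (not real model output):
#   [{"title": "Attention Is All You Need", "ID": "1706.03762"},
#    {"title": "Layer Normalization", "ID": "1607.06450"}]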


# Check whether ref_id looks like a valid arXiv ID
def is_valid_arxiv_id(ref_id):
    # New-style IDs are "YYMM.NNNNN" with an optional version suffix (e.g. "1706.03762v7");
    # the seven-digit form covers the numeric part of old-style IDs
    return bool(re.match(r'^\d{4}\.\d{4,5}(v\d+)?$', ref_id) or re.match(r'^\d{7}$', ref_id))
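
# Quick sanity checks against the regex above:
#   is_valid_arxiv_id("1706.03762")     -> True
#   is_valid_arxiv_id("1706.03762v7")   -> True
#   is_valid_arxiv_id("hep-th/9901001") -> False (archive-prefixed IDs are not matched)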


def download_arxiv_paper_and_citations(arxiv_url, download_dir, progress=None):
    os.makedirs(download_dir, exist_ok=True)

    if progress:
        progress("Downloading main paper PDF...")

    # Download the main paper PDF
    pdf_url = extract_arxiv_pdf_url(arxiv_url)
    main_pdf_path = os.path.join(download_dir, 'main_paper.pdf')
    main_pdf_content = download_pdf(pdf_url, main_pdf_path)

    if main_pdf_content is None:
        if progress:
            progress("Invalid URL. Valid example: https://arxiv.org/abs/1706.03762v7")
        return None, 0

    if progress:
        progress("Main paper downloaded. Extracting references...")

    # Extract references using the LLM
    references = extract_references_with_llm(main_pdf_content)
    if progress:
        progress(f"Found {len(references)} references. Downloading...")
        time.sleep(1)

    # Download the reference PDFs
    all_pdf_paths = [main_pdf_path]
    for i, reference in enumerate(references):
        ref_title = reference.get("title")
        ref_id = reference.get("ID")
        if ref_id and is_valid_arxiv_id(ref_id):
            ref_url = f'https://arxiv.org/pdf/{ref_id}'
            # Strip characters that are unsafe in filenames
            safe_title = re.sub(r'[\\/*?:"<>|]', '_', ref_title or ref_id)
            ref_pdf_path = os.path.join(download_dir, f'{safe_title}.pdf')
            if progress:
                progress(f"Downloading reference {i+1}/{len(references)}... {ref_title}")
                time.sleep(0.2)
            try:
                download_pdf(ref_url, ref_pdf_path)
                all_pdf_paths.append(ref_pdf_path)
            except Exception as e:
                if progress:
                    progress(f"Error downloading {ref_url}: {str(e)}")
                    time.sleep(0.2)

    # Write out the list of all downloaded PDF paths
    paths_file = os.path.join(download_dir, 'pdf_paths.txt')
    with open(paths_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(all_pdf_paths))

    if progress:
        progress(f"All papers downloaded. Total references: {len(references)}")
        time.sleep(1)
    return paths_file, len(references)
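
# Illustrative contents of downloads/pdf_paths.txt after a run (titles will vary):
#   downloads/main_paper.pdf
#   downloads/Layer Normalization.pdf
#   downloads/Neural Machine Translation by Jointly Learning to Align and Translate.pdf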


def ingest_paper_with_llama(paths_file, progress=None):
    total_text = ""
    total_word_count = 0
    if progress:
        progress("Ingesting paper content...")
    with open(paths_file, 'r', encoding='utf-8') as f:
        pdf_paths = f.read().splitlines()

    for i, pdf_path in enumerate(pdf_paths):
        if progress:
            progress(f"Ingesting PDF {i+1}/{len(pdf_paths)}...")
        with open(pdf_path, 'rb') as pdf_file:
            pdf_content = pdf_file.read()
        text = extract_text_from_pdf(pdf_content)
        total_text += text + "\n\n"
        total_word_count += len(text.split())
    if progress:
        progress("Paper ingestion complete!")
    return total_text, total_word_count


def gradio_interface():
    paper_content = {"text": ""}

    def process(arxiv_url, progress=gr.Progress()):
        download_dir = 'downloads'
        progress(0, "Starting download...")
        paper_path, num_references = download_arxiv_paper_and_citations(
            arxiv_url, download_dir, lambda msg: progress(0.3, msg))
        if paper_path is None:
            return "Invalid URL. Valid example: https://arxiv.org/abs/1706.03762v7"

        paper_content["text"], total_word_count = ingest_paper_with_llama(
            paper_path, lambda msg: progress(0.7, msg))
        progress(1.0, "Ready for chat!")
        return f"Ingested {total_word_count} words and {num_references} references. You can now chat about the paper and its citations."

    def respond(message, history):
        user_message = message
        if not user_message:
            yield history, ""
            return

        # Append the user message immediately so it shows up while streaming
        history.append([user_message, ""])

        client = Together(api_key=os.environ.get("TOGETHER_API_KEY"))
        # Prepare the system prompt with the ingested paper text
        system_prompt = f"""
        You are a research assistant that has access to the paper references below.
        Answer questions based on your knowledge of these references.
        If you do not know the answer, say you don't know.
        Paper references: {paper_content["text"]}
        """

        stream = client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=0.3,
            stream=True  # Enable streaming
        )

        # Accumulate the response as it streams in
        full_response = ""

        # Stream the response chunks
        for chunk in stream:
            if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                content = chunk.choices[0].delta.content
                full_response += content
                # Update the last message in history with the response so far
                history[-1][1] = full_response
                yield history, ""

    def clear_chat_history():
        return [], ""

    with gr.Blocks(css=".orange-button {background-color: #FF7C00 !important; color: white;}") as demo:
        gr.Markdown("# Research Analyzer")
        with gr.Column():
            input_text = gr.Textbox(label="ArXiv URL")
            status_text = gr.Textbox(label="Status", interactive=False)
            submit_btn = gr.Button("Ingest", elem_classes="orange-button")
            submit_btn.click(fn=process, inputs=input_text, outputs=status_text)

        gr.Markdown("## Chat with Llama")
        chatbot = gr.Chatbot()
        with gr.Row():
            msg = gr.Textbox(label="Ask about the paper", scale=5)
            submit_chat_btn = gr.Button("➤", elem_classes="orange-button", scale=1)

        submit_chat_btn.click(respond, [msg, chatbot], [chatbot, msg])
        msg.submit(respond, [msg, chatbot], [chatbot, msg])

        def copy_last_response(history):
            # Helper for copying the assistant's last reply; not wired to a component here
            if history and len(history) > 0:
                last_response = history[-1][1]
                return gr.update(value=last_response)
            return gr.update(value="No response to copy")

    demo.launch()


if __name__ == "__main__":
    gradio_interface()
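
# Usage sketch (assuming this file is saved as research_analyzer.py and
# TOGETHER_API_KEY is exported): run `python research_analyzer.py`, open the
# local Gradio URL it prints, paste an arXiv link such as
# https://arxiv.org/abs/1706.03762v7, click "Ingest", then chat about the paper.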