"""Research Analyzer: download an arXiv paper and its arXiv-cited references,
extract their text, and chat about them with a Llama model via Together AI,
all behind a small Gradio UI."""

import io
import json
import os
import re
import time

import gradio as gr
import PyPDF2
import requests
from together import Together

# Single model used for every Together chat completion in this app.
LLM_MODEL = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

# Prefer the environment variable; the original hard-coded placeholder is kept
# only as a fallback so behavior without the env var is unchanged.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY", "Your API key here")


def download_pdf(url, save_path=None):
    """Download a PDF from an arxiv.org URL.

    Args:
        url: URL to fetch; must contain 'arxiv.org' or None is returned.
        save_path: optional path; when given, the bytes are also written there.

    Returns:
        The raw PDF bytes, or None for a missing/non-arXiv URL or a failed
        HTTP request (non-200), so callers never mistake an error page for
        PDF content.
    """
    if url is None or 'arxiv.org' not in url:
        return None
    response = requests.get(url)
    if response.status_code != 200:
        return None
    if save_path:
        with open(save_path, 'wb') as f:
            f.write(response.content)
    return response.content


def extract_arxiv_pdf_url(arxiv_url):
    """Normalize an arXiv abs/html URL to its PDF URL.

    Returns the input unchanged if it already points at /pdf/, the derived
    https://arxiv.org/pdf/<id>.pdf URL for /abs/ and /html/ links, or None
    when no arXiv id can be extracted.
    """
    # Already a direct PDF link — nothing to do.
    if 'arxiv.org/pdf/' in arxiv_url:
        return arxiv_url

    # Extract the arXiv id from the supported URL shapes.
    # .split()[0] trims any trailing whitespace-separated junk pasted after
    # the URL (it keeps version suffixes like '1706.03762v7' intact).
    arxiv_id = None
    if 'arxiv.org/abs/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/abs/')[1].split()[0]
    elif 'arxiv.org/html/' in arxiv_url:
        arxiv_id = arxiv_url.split('arxiv.org/html/')[1].split()[0]

    if arxiv_id:
        return f"https://arxiv.org/pdf/{arxiv_id}.pdf"
    return None  # No valid arXiv id found.


def extract_text_from_pdf(pdf_content):
    """Extract plain text from PDF bytes, one newline-joined string per doc."""
    pdf_file = io.BytesIO(pdf_content)
    reader = PyPDF2.PdfReader(pdf_file)
    text = ""
    for page in reader.pages:
        text += page.extract_text() + "\n"
    return text


def extract_references_with_llm(pdf_content):
    """Use two LLM calls to pull arXiv-identified references out of a paper.

    Pass 1 extracts the citation list from the paper text; pass 2 distills
    that list into a JSON array of {"title", "ID"} objects.

    Returns:
        A list of dicts with 'title' and 'ID' keys (empty on parse failure).
    """
    # Extract text from PDF and truncate so the prompt stays within limits.
    text = extract_text_from_pdf(pdf_content)
    max_length = 50000
    if len(text) > max_length:
        text = text[:max_length] + "..."

    client = Together(api_key=TOGETHER_API_KEY)

    citations = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {
                "role": "user",
                "content": f"Extract all the arXiv citations from Reference section of the paper including their title, authors and origins. Paper: {text} "
            }
        ],
        temperature=0.3,
    )
    # Bug fix: interpolate the model's text, not the response object's repr.
    citations_text = citations.choices[0].message.content

    # Prepare prompt for the second pass (ID extraction into strict JSON).
    prompt = f"""
                   Extract the arXiv ID from the list of citations provided, including preprint arXiv ID. If there is no arXiv ID presented with the list, skip that citations.
                Here are some examples on arXiv ID format:
                1. arXiv preprint arXiv:1607.06450, where 1607.06450 is the arXiv ID.
                2. CoRR, abs/1409.0473, where 1409.0473 is the arXiv ID.
                Then, return a JSON array of objects with 'title' and 'ID' fields strictly in the following format, only return the paper title if it's arXiv ID is extracted:
                Output format: [{{\"title\": \"Paper Title\", \"ID\": \"arXiv ID\"}}]
                DO NOT return any other text.
                List of citations:
                {citations_text}
                """

    response = client.chat.completions.create(
        model=LLM_MODEL,
        messages=[
            {
                "role": "user",
                "content": prompt
            }
        ],
        temperature=0.3,
    )
    response_json = response.choices[0].message.content.strip()

    # LLMs often wrap JSON output in markdown code fences; strip them so
    # json.loads sees bare JSON.
    if response_json.startswith("```"):
        response_json = re.sub(r'^```(?:json)?\s*|\s*```$', '', response_json)

    references = []
    try:
        references = json.loads(response_json)
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON: {e}")
    return references


def is_valid_arxiv_id(ref_id):
    """Return True if *ref_id* looks like a modern ('1234.56789') or a bare
    7-digit old-style arXiv id.

    NOTE(review): full old-style ids include a category prefix
    (e.g. 'hep-th/9901001'); those are not matched here — confirm whether
    the extractor ever emits them.
    """
    return bool(re.match(r'^\d{4}\.\d{4,5}$', ref_id)
                or re.match(r'^\d{7}$', ref_id))


def download_arxiv_paper_and_citations(arxiv_url, download_dir, progress=None):
    """Download the main paper and every valid arXiv reference it cites.

    Args:
        arxiv_url: any supported arXiv URL (abs/html/pdf).
        download_dir: directory to create/fill with PDFs.
        progress: optional callable taking a status-message string.

    Returns:
        (path to a text manifest listing all downloaded PDF paths,
         number of references found), or (None, 0) for an invalid URL.
    """
    if not os.path.exists(download_dir):
        os.makedirs(download_dir)

    if progress:
        progress("Downloading main paper PDF...")

    # Download main paper PDF.
    pdf_url = extract_arxiv_pdf_url(arxiv_url)
    main_pdf_path = os.path.join(download_dir, 'main_paper.pdf')
    main_pdf_content = download_pdf(pdf_url, main_pdf_path)

    if main_pdf_content is None:
        if progress:
            progress("Invalid Url. Valid example: https://arxiv.org/abs/1706.03762v7")
        return None, 0

    if progress:
        progress("Main paper downloaded. Extracting references...")

    # Extract references using the LLM.
    references = extract_references_with_llm(main_pdf_content)
    if progress:
        progress(f"Found {len(references)} references. Downloading...")
        time.sleep(1)

    # Download reference PDFs.
    all_pdf_paths = [main_pdf_path]
    for i, reference in enumerate(references):
        ref_title = reference.get("title")
        ref_id = reference.get("ID")
        if ref_id and is_valid_arxiv_id(ref_id):
            ref_url = f'https://arxiv.org/pdf/{ref_id}'
            # Titles may contain characters illegal in filenames ('/', ':',
            # etc.) — sanitize before building the path.
            safe_title = re.sub(r'[\\/*?:"<>|]', '_', ref_title or ref_id)
            ref_pdf_path = os.path.join(download_dir, f'{safe_title}.pdf')
            if progress:
                progress(f"Downloading reference {i+1}/{len(references)}...{ref_title}")
                time.sleep(0.2)
            try:
                content = download_pdf(ref_url, ref_pdf_path)
                # Only record paths that were actually written, so ingestion
                # never opens a missing file.
                if content is not None:
                    all_pdf_paths.append(ref_pdf_path)
            except Exception as e:
                if progress:
                    progress(f"Error downloading {ref_url}: {str(e)}")
                    time.sleep(0.2)

    # Persist the manifest of all PDF paths.
    paths_file = os.path.join(download_dir, 'pdf_paths.txt')
    with open(paths_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(all_pdf_paths))

    if progress:
        progress(f"All papers downloaded. Total references: {len(references)}")
        time.sleep(1)
    return paths_file, len(references)


def ingest_paper_with_llama(paths_file, progress=None):
    """Read every PDF listed in *paths_file* and concatenate their text.

    Unreadable/corrupt PDFs are skipped (with a progress message) instead of
    aborting the whole ingestion.

    Returns:
        (combined text, total word count).
    """
    total_text = ""
    total_word_count = 0
    if progress:
        progress("Ingesting paper content...")
    with open(paths_file, 'r', encoding='utf-8') as f:
        pdf_paths = f.read().splitlines()

    for i, pdf_path in enumerate(pdf_paths):
        if progress:
            progress(f"Ingesting PDF {i+1}/{len(pdf_paths)}...")
        try:
            with open(pdf_path, 'rb') as pdf_file:
                pdf_content = pdf_file.read()
            text = extract_text_from_pdf(pdf_content)
        except Exception as e:
            # One bad PDF should not sink the rest of the corpus.
            if progress:
                progress(f"Skipping unreadable PDF {pdf_path}: {e}")
            continue
        total_text += text + "\n\n"
        total_word_count += len(text.split())

    if progress:
        progress("Paper ingestion complete!")
    return total_text, total_word_count


def gradio_interface():
    """Build and launch the Gradio UI: ingest an arXiv URL, then chat."""
    # Shared mutable holder so the chat handler sees the ingested text.
    paper_content = {"text": ""}

    def process(arxiv_url, progress=gr.Progress()):
        """Download + ingest pipeline wired to the 'Ingest' button."""
        download_dir = 'downloads'
        progress(0, "Starting download...")
        paper_path, num_references = download_arxiv_paper_and_citations(
            arxiv_url, download_dir, lambda msg: progress(0.3, msg))
        if paper_path is None:
            return "Invalid Url. Valid example: https://arxiv.org/abs/1706.03762v7"

        paper_content["text"], total_word_count = ingest_paper_with_llama(
            paper_path, lambda msg: progress(0.7, msg))
        progress(1.0, "Ready for chat!")
        return f"Total {total_word_count} words and {num_references} reference ingested. You can now chat about the paper and citations."

    def respond(message, history):
        """Streaming chat handler: grounds answers in the ingested text."""
        user_message = message
        if not user_message:
            return history, ""

        # Append the user message immediately so the UI shows it while the
        # model streams its answer into the second slot.
        history.append([user_message, ""])

        client = Together(api_key=TOGETHER_API_KEY)
        system_prompt = f"""
                    You are a research assistant that have access to the paper reference below.
                    Answer questions based on your knowledge on these references.
                    If you do not know the answer, say you don't know.
                    paper reference: {paper_content["text"]}
                    """

        stream = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            temperature=0.3,
            stream=True  # Enable streaming
        )

        # Accumulate streamed chunks into the last history entry.
        full_response = ""
        for chunk in stream:
            if len(chunk.choices) > 0 and chunk.choices[0].delta.content is not None:
                full_response += chunk.choices[0].delta.content
                history[-1][1] = full_response
                yield history, ""

    def clear_chat_history():
        return [], ""

    with gr.Blocks(css=".orange-button {background-color: #FF7C00 !important; color: white;}") as demo:
        gr.Markdown("# Research Analyzer")
        with gr.Column():
            input_text = gr.Textbox(label="ArXiv URL")
            status_text = gr.Textbox(label="Status", interactive=False)
            submit_btn = gr.Button("Ingest", elem_classes="orange-button")
            submit_btn.click(fn=process, inputs=input_text, outputs=status_text)

            gr.Markdown("## Chat with Llama")
            chatbot = gr.Chatbot()
        with gr.Row():
            msg = gr.Textbox(label="Ask about the paper", scale=5)
            submit_chat_btn = gr.Button("➤", elem_classes="orange-button", scale=1)

        submit_chat_btn.click(respond, [msg, chatbot], [chatbot, msg])
        msg.submit(respond, [msg, chatbot], [chatbot, msg])

        # NOTE(review): defined but never wired to a UI event — kept for
        # backward compatibility; confirm whether a copy button was intended.
        def copy_last_response(history):
            if history and len(history) > 0:
                last_response = history[-1][1]
                return gr.update(value=last_response)
            return gr.update(value="No response to copy")

    demo.launch()


if __name__ == "__main__":
    gradio_interface()