{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rewrite reports with Llama 3.3\n",
    "\n",
    "Reads every `report_*.txt` file in `llama_data/final_reports`, asks the model to rewrite it according to `ORIGINAL_USER_INPUT` and `SYS_PROMPT`, and saves each result to `llama_data/rewritten_reports`.\n",
    "\n",
    "**Before running:** fill in the `ORIGINAL_USER_INPUT` and `SYS_PROMPT` cells. The final cell refuses to start while either still holds its placeholder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gc\n",
    "import json\n",
    "import os\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from transformers import pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use Llama 3.3 for high-quality rewriting\n",
    "DEFAULT_MODEL = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "# Generation settings passed to the pipeline for every report\n",
    "MAX_NEW_TOKENS = 4000\n",
    "TEMPERATURE = 0.7\n",
    "# Sampling is enabled, so the torch RNG is seeded before the run for repeatability\n",
    "SEED = 42\n",
    "\n",
    "# Maximum number of title characters kept in an output filename\n",
    "MAX_TITLE_CHARS = 30\n",
    "\n",
    "# Placeholders that mark a prompt cell as \"not filled in yet\"\n",
    "USER_INPUT_PLACEHOLDER = \"[Your input here - describe what you want Llama to do with these reports]\"\n",
    "SYS_PROMPT_PLACEHOLDER = \"[Your system prompt here - describe how Llama should behave]\"\n",
    "\n",
    "# Set up directories\n",
    "base_dir = Path(\"llama_data\")\n",
    "reports_dir = base_dir / \"final_reports\"\n",
    "rewritten_dir = base_dir / \"rewritten_reports\"\n",
    "# parents=True so a missing llama_data/ directory does not raise FileNotFoundError\n",
    "rewritten_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the placeholder with the request that is sent along with every report\n",
    "ORIGINAL_USER_INPUT = USER_INPUT_PLACEHOLDER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the placeholder with the system prompt used for every conversation\n",
    "SYS_PROMPT = SYS_PROMPT_PLACEHOLDER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_pipeline = pipeline(\n",
    "    \"text-generation\",\n",
    "    model=DEFAULT_MODEL,\n",
    "    model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
    "    device_map=\"auto\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _report_id(path):\n",
    "    \"\"\"Return the digits that follow 'report_' in the filename, or None if there are none.\"\"\"\n",
    "    match = re.search(r\"report_(\\d+)\", path.name)\n",
    "    return match.group(1) if match else None\n",
    "\n",
    "\n",
    "def _report_sort_key(path):\n",
    "    \"\"\"Sort key that orders reports by their number, then by filename.\"\"\"\n",
    "    report_id = _report_id(path)\n",
    "    # Files without a number are placed after all numbered reports\n",
    "    number = int(report_id) if report_id is not None else float(\"inf\")\n",
    "    return (number, path.name)\n",
    "\n",
    "\n",
    "def get_report_files():\n",
    "    \"\"\"Get all report_*.txt files from the final_reports directory.\n",
    "\n",
    "    Returns:\n",
    "        A list of paths ordered numerically by report number (so report_2\n",
    "        comes before report_10), or an empty list when the reports\n",
    "        directory does not exist.\n",
    "    \"\"\"\n",
    "    if not reports_dir.exists():\n",
    "        print(f\"Error: Reports directory '{reports_dir}' not found.\")\n",
    "        return []\n",
    "\n",
    "    # Sort by report number to process in consistent order\n",
    "    report_files = sorted(reports_dir.glob(\"report_*.txt\"), key=_report_sort_key)\n",
    "\n",
    "    print(f\"Found {len(report_files)} report files.\")\n",
    "    return report_files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _safe_title_fragment(title, max_chars=MAX_TITLE_CHARS):\n",
    "    \"\"\"Turn a report title into a short fragment that is safe inside a filename.\"\"\"\n",
    "    # Whitespace, path separators and characters that are invalid on common\n",
    "    # filesystems become underscores, so a title such as \"Q1/Q2: Results\"\n",
    "    # can neither escape the output directory nor make the write fail.\n",
    "    return re.sub(r'[\\s/\\\\:*?\"<>|]', \"_\", title)[:max_chars]\n",
    "\n",
    "\n",
    "def rewrite_report(report_path, original_user_input):\n",
    "    \"\"\"Rewrite a single report using Llama 3.3.\n",
    "\n",
    "    Args:\n",
    "        report_path: Path object pointing at the report text file.\n",
    "        original_user_input: The user's request; it is placed in front of\n",
    "            the report text in the user message.\n",
    "\n",
    "    Returns:\n",
    "        The path of the saved rewritten report, or None when the report\n",
    "        file could not be read.\n",
    "    \"\"\"\n",
    "    # Extract the report number from the filename\n",
    "    report_filename = report_path.name\n",
    "    report_id = _report_id(report_path) or \"unknown\"\n",
    "\n",
    "    print(f\"Processing report {report_id}: {report_filename}\")\n",
    "\n",
    "    # Read the report content\n",
    "    try:\n",
    "        report_content = report_path.read_text(encoding=\"utf-8\")\n",
    "    except (OSError, UnicodeError) as e:\n",
    "        print(f\"Error reading report file: {e}\")\n",
    "        return None\n",
    "\n",
    "    # Use the first '# ' heading line as the title. MULTILINE lets it match\n",
    "    # even when the heading is not the very first line of the file.\n",
    "    title_match = re.search(r\"^# (.+?)\\s*$\", report_content, re.MULTILINE)\n",
    "    report_title = title_match.group(1) if title_match else \"Unknown Report\"\n",
    "\n",
    "    # Create the prompt for rewriting\n",
    "    user_prompt = f\"\"\"\n",
    "{original_user_input}\n",
    "\n",
    "Here's the report to rewrite:\n",
    "\n",
    "{report_content}\n",
    "\"\"\"\n",
    "\n",
    "    # A fresh conversation is built for every report, so nothing from a\n",
    "    # previous report is included in the prompt.\n",
    "    conversation = [\n",
    "        {\"role\": \"system\", \"content\": SYS_PROMPT},\n",
    "        {\"role\": \"user\", \"content\": user_prompt},\n",
    "    ]\n",
    "\n",
    "    # Generate the rewritten report\n",
    "    output = text_pipeline(\n",
    "        conversation,\n",
    "        max_new_tokens=MAX_NEW_TOKENS,\n",
    "        temperature=TEMPERATURE,\n",
    "        do_sample=True,\n",
    "    )\n",
    "\n",
    "    # Extract the assistant's response (the last message in generated_text)\n",
    "    assistant_response = output[0][\"generated_text\"][-1]\n",
    "    rewritten_content = assistant_response[\"content\"]\n",
    "\n",
    "    # Save the rewritten report\n",
    "    rewritten_filename = f\"rewritten_{report_id}_{_safe_title_fragment(report_title)}.txt\"\n",
    "    rewritten_path = rewritten_dir / rewritten_filename\n",
    "    rewritten_path.write_text(rewritten_content, encoding=\"utf-8\")\n",
    "\n",
    "    print(f\" Saved rewritten report to: {rewritten_path}\")\n",
    "    return rewritten_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rewrite_all_reports():\n",
    "    \"\"\"Rewrite all reports in the final_reports directory.\n",
    "\n",
    "    Returns:\n",
    "        A list with the path of every rewritten report that was saved.\n",
    "        Reports that could not be read are skipped.\n",
    "    \"\"\"\n",
    "    # Get all report files\n",
    "    report_files = get_report_files()\n",
    "\n",
    "    if not report_files:\n",
    "        print(\"No reports to rewrite.\")\n",
    "        return []\n",
    "\n",
    "    rewritten_paths = []\n",
    "\n",
    "    # Process each report file\n",
    "    for report_path in report_files:\n",
    "        rewritten_path = rewrite_report(report_path, ORIGINAL_USER_INPUT)\n",
    "\n",
    "        if rewritten_path:\n",
    "            rewritten_paths.append(rewritten_path)\n",
    "\n",
    "        # rewrite_report builds a new conversation per call, so no history\n",
    "        # is shared between reports. Collecting here only asks Python to\n",
    "        # release the previous report's prompt and output objects promptly.\n",
    "        gc.collect()\n",
    "\n",
    "        print(\"\\n--------\\n\")\n",
    "\n",
    "    return rewritten_paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting report rewriting process...\")\n",
    "\n",
    "# Validate that both prompt cells were filled in before starting the run\n",
    "if not ORIGINAL_USER_INPUT.strip() or ORIGINAL_USER_INPUT == USER_INPUT_PLACEHOLDER:\n",
    "    print(\"Please edit the ORIGINAL_USER_INPUT cell to specify your request before running this cell.\")\n",
    "elif SYS_PROMPT == SYS_PROMPT_PLACEHOLDER:\n",
    "    print(\"Please edit the SYS_PROMPT cell to specify the system prompt before running this cell.\")\n",
    "else:\n",
    "    # Seed right before generating so re-running this cell repeats the same\n",
    "    # sampling sequence (exact repeatability across hardware is not guaranteed)\n",
    "    torch.manual_seed(SEED)\n",
    "\n",
    "    # Process all reports\n",
    "    rewritten_paths = rewrite_all_reports()\n",
    "\n",
    "    print(\"\\nReport rewriting complete!\")\n",
    "    print(f\"Rewritten {len(rewritten_paths)} reports.\")\n",
    "\n",
    "    # List all rewritten reports\n",
    "    if rewritten_paths:\n",
    "        print(\"\\nRewritten reports:\")\n",
    "        for path in rewritten_paths:\n",
    "            print(f\"- {path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}