|
@@ -0,0 +1,237 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import json\n",
|
|
|
+ "import os\n",
|
|
|
+ "import torch\n",
|
|
|
+ "from pathlib import Path\n",
|
|
|
+ "import re\n",
|
|
|
+ "from transformers import pipeline"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Use Llama 3.3 for high-quality rewriting\n",
|
|
|
+ "DEFAULT_MODEL = \"meta-llama/Llama-3.3-70B-Instruct\" \n",
|
|
|
+ "\n",
|
|
|
+ "# Set up directories\n",
|
|
|
+ "base_dir = Path(\"llama_data\")\n",
|
|
|
+ "reports_dir = base_dir / \"final_reports\"\n",
|
|
|
+ "rewritten_dir = base_dir / \"rewritten_reports\"\n",
|
|
|
+ "rewritten_dir.mkdir(exist_ok=True)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "ORIGINAL_USER_INPUT = "
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "SYS_PROMPT ="
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "text_pipeline = pipeline(\n",
|
|
|
+ " \"text-generation\",\n",
|
|
|
+ " model=DEFAULT_MODEL,\n",
|
|
|
+ " model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
|
|
|
+ " device_map=\"auto\",\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def get_report_files():\n",
|
|
|
+ " \"\"\"Get all report files from the final_reports directory\"\"\"\n",
|
|
|
+ " report_files = []\n",
|
|
|
+ " \n",
|
|
|
+ " if not reports_dir.exists():\n",
|
|
|
+ " print(f\"Error: Reports directory '{reports_dir}' not found.\")\n",
|
|
|
+ " return []\n",
|
|
|
+ " \n",
|
|
|
+ " for file in reports_dir.glob(\"report_*.txt\"):\n",
|
|
|
+ " report_files.append(file)\n",
|
|
|
+ " \n",
|
|
|
+ " # Sort by report number to process in consistent order\n",
|
|
|
+ " report_files.sort()\n",
|
|
|
+ " \n",
|
|
|
+ " print(f\"Found {len(report_files)} report files.\")\n",
|
|
|
+ " return report_files"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def rewrite_report(report_path, original_user_input):\n",
|
|
|
+ " \"\"\"Rewrite a single report using Llama 3.3\"\"\"\n",
|
|
|
+ " \n",
|
|
|
+ " # Extract report number and title from filename\n",
|
|
|
+ " report_filename = report_path.name\n",
|
|
|
+ " report_id = re.search(r'report_(\\d+)_', report_filename).group(1) if re.search(r'report_(\\d+)_', report_filename) else \"unknown\"\n",
|
|
|
+ " \n",
|
|
|
+ " print(f\"Processing report {report_id}: {report_filename}\")\n",
|
|
|
+ " \n",
|
|
|
+ " # Read the report content\n",
|
|
|
+ " try:\n",
|
|
|
+ " with open(report_path, \"r\", encoding=\"utf-8\") as f:\n",
|
|
|
+ " report_content = f.read()\n",
|
|
|
+ " except Exception as e:\n",
|
|
|
+ " print(f\"Error reading report file: {str(e)}\")\n",
|
|
|
+ " return None\n",
|
|
|
+ " \n",
|
|
|
+ " # Extract the report title from the content\n",
|
|
|
+ " title_match = re.search(r'^# (.+?)\\n', report_content)\n",
|
|
|
+ " report_title = title_match.group(1) if title_match else \"Unknown Report\"\n",
|
|
|
+ " \n",
|
|
|
+ " # Create the prompt for rewriting\n",
|
|
|
+ " user_prompt = f\"\"\"\n",
|
|
|
+ "{original_user_input}\n",
|
|
|
+ "\n",
|
|
|
+ "Here's the report to rewrite:\n",
|
|
|
+ "\n",
|
|
|
+ "{report_content}\n",
|
|
|
+ "\"\"\"\n",
|
|
|
+ " \n",
|
|
|
+ " # Set up the conversation\n",
|
|
|
+ " conversation = [\n",
|
|
|
+ " {\"role\": \"system\", \"content\": SYS_PROMPT},\n",
|
|
|
+ " {\"role\": \"user\", \"content\": user_prompt}\n",
|
|
|
+ " ]\n",
|
|
|
+ " \n",
|
|
|
+ " # Generate the rewritten report\n",
|
|
|
+ " output = text_pipeline(\n",
|
|
|
+ " conversation,\n",
|
|
|
+ " max_new_tokens=4000,\n",
|
|
|
+ " temperature=0.7,\n",
|
|
|
+ " do_sample=True,\n",
|
|
|
+ " )\n",
|
|
|
+ " \n",
|
|
|
+ " # Extract the assistant's response\n",
|
|
|
+ " assistant_response = output[0][\"generated_text\"][-1]\n",
|
|
|
+ " rewritten_content = assistant_response[\"content\"]\n",
|
|
|
+ " \n",
|
|
|
+ " # Save the rewritten report\n",
|
|
|
+ " rewritten_filename = f\"rewritten_{report_id}_{report_title.replace(' ', '_')[:30]}.txt\"\n",
|
|
|
+ " rewritten_path = rewritten_dir / rewritten_filename\n",
|
|
|
+ " \n",
|
|
|
+ " with open(rewritten_path, \"w\", encoding=\"utf-8\") as f:\n",
|
|
|
+ " f.write(rewritten_content)\n",
|
|
|
+ " \n",
|
|
|
+ " print(f\" Saved rewritten report to: {rewritten_path}\")\n",
|
|
|
+ " return rewritten_path"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def rewrite_all_reports():\n",
|
|
|
+ " \"\"\"Rewrite all reports in the final_reports directory\"\"\"\n",
|
|
|
+ " \n",
|
|
|
+ " # Get all report files\n",
|
|
|
+ " report_files = get_report_files()\n",
|
|
|
+ " \n",
|
|
|
+ " if not report_files:\n",
|
|
|
+ " print(\"No reports to rewrite.\")\n",
|
|
|
+ " return []\n",
|
|
|
+ " \n",
|
|
|
+ " rewritten_paths = []\n",
|
|
|
+ " \n",
|
|
|
+ " # Process each report file\n",
|
|
|
+ " for report_path in report_files:\n",
|
|
|
+ " # Rewrite the report\n",
|
|
|
+ " rewritten_path = rewrite_report(report_path, ORIGINAL_USER_INPUT)\n",
|
|
|
+ " \n",
|
|
|
+ " if rewritten_path:\n",
|
|
|
+ " rewritten_paths.append(rewritten_path)\n",
|
|
|
+ " \n",
|
|
|
+ " # IMPORTANT: Force garbage collection to ensure no history is kept between reports\n",
|
|
|
+ " # This simulates removing the old report from input history\n",
|
|
|
+ " import gc\n",
|
|
|
+ " gc.collect()\n",
|
|
|
+ " \n",
|
|
|
+ " print(\"\\n--------\\n\")\n",
|
|
|
+ " \n",
|
|
|
+ " return rewritten_paths"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "print(\"Starting report rewriting process...\")\n",
|
|
|
+ "\n",
|
|
|
+ "# Validate user input\n",
|
|
|
+ "if ORIGINAL_USER_INPUT == \"[Your input here - describe what you want Llama to do with these reports]\":\n",
|
|
|
+ " print(\"Please edit the ORIGINAL_USER_INPUT cell to specify your request before running this cell.\")\n",
|
|
|
+ "else:\n",
|
|
|
+ " # Process all reports\n",
|
|
|
+ " rewritten_paths = rewrite_all_reports()\n",
|
|
|
+ " \n",
|
|
|
+ " print(\"\\nReport rewriting complete!\")\n",
|
|
|
+ " print(f\"Rewritten {len(rewritten_paths)} reports.\")\n",
|
|
|
+ " \n",
|
|
|
+ " # List all rewritten reports\n",
|
|
|
+ " if rewritten_paths:\n",
|
|
|
+ " print(\"\\nRewritten reports:\")\n",
|
|
|
+ " for path in rewritten_paths:\n",
|
|
|
+ " print(f\"- {path}\")"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.5"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 4
|
|
|
+}
|