Browse Source

Create Step-6-Report-Rewriting.ipynb

Sanyam Bhutani 3 months ago
parent
commit
f75a748f79
1 changed files with 237 additions and 0 deletions
  1. 237 0
      end-to-end-use-cases/researcher/Step-6-Report-Rewriting.ipynb

+ 237 - 0
end-to-end-use-cases/researcher/Step-6-Report-Rewriting.ipynb

@@ -0,0 +1,237 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import os\n",
+    "import torch\n",
+    "from pathlib import Path\n",
+    "import re\n",
+    "from transformers import pipeline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Use Llama 3.3 for high-quality rewriting\n",
+    "DEFAULT_MODEL = \"meta-llama/Llama-3.3-70B-Instruct\" \n",
+    "\n",
+    "# Set up directories\n",
+    "base_dir = Path(\"llama_data\")\n",
+    "reports_dir = base_dir / \"final_reports\"\n",
+    "rewritten_dir = base_dir / \"rewritten_reports\"\n",
+    "rewritten_dir.mkdir(exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ORIGINAL_USER_INPUT = "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SYS_PROMPT ="
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text_pipeline = pipeline(\n",
+    "    \"text-generation\",\n",
+    "    model=DEFAULT_MODEL,\n",
+    "    model_kwargs={\"torch_dtype\": torch.bfloat16},\n",
+    "    device_map=\"auto\",\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_report_files():\n",
+    "    \"\"\"Get all report files from the final_reports directory\"\"\"\n",
+    "    report_files = []\n",
+    "    \n",
+    "    if not reports_dir.exists():\n",
+    "        print(f\"Error: Reports directory '{reports_dir}' not found.\")\n",
+    "        return []\n",
+    "    \n",
+    "    for file in reports_dir.glob(\"report_*.txt\"):\n",
+    "        report_files.append(file)\n",
+    "    \n",
+    "    # Sort by report number to process in consistent order\n",
+    "    report_files.sort()\n",
+    "    \n",
+    "    print(f\"Found {len(report_files)} report files.\")\n",
+    "    return report_files"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def rewrite_report(report_path, original_user_input):\n",
+    "    \"\"\"Rewrite a single report using Llama 3.3\"\"\"\n",
+    "    \n",
+    "    # Extract report number and title from filename\n",
+    "    report_filename = report_path.name\n",
+    "    report_id = re.search(r'report_(\\d+)_', report_filename).group(1) if re.search(r'report_(\\d+)_', report_filename) else \"unknown\"\n",
+    "    \n",
+    "    print(f\"Processing report {report_id}: {report_filename}\")\n",
+    "    \n",
+    "    # Read the report content\n",
+    "    try:\n",
+    "        with open(report_path, \"r\", encoding=\"utf-8\") as f:\n",
+    "            report_content = f.read()\n",
+    "    except Exception as e:\n",
+    "        print(f\"Error reading report file: {str(e)}\")\n",
+    "        return None\n",
+    "    \n",
+    "    # Extract the report title from the content\n",
+    "    title_match = re.search(r'^# (.+?)\\n', report_content)\n",
+    "    report_title = title_match.group(1) if title_match else \"Unknown Report\"\n",
+    "    \n",
+    "    # Create the prompt for rewriting\n",
+    "    user_prompt = f\"\"\"\n",
+    "{original_user_input}\n",
+    "\n",
+    "Here's the report to rewrite:\n",
+    "\n",
+    "{report_content}\n",
+    "\"\"\"\n",
+    "    \n",
+    "    # Set up the conversation\n",
+    "    conversation = [\n",
+    "        {\"role\": \"system\", \"content\": SYS_PROMPT},\n",
+    "        {\"role\": \"user\", \"content\": user_prompt}\n",
+    "    ]\n",
+    "    \n",
+    "    # Generate the rewritten report\n",
+    "    output = text_pipeline(\n",
+    "        conversation,\n",
+    "        max_new_tokens=4000,\n",
+    "        temperature=0.7,\n",
+    "        do_sample=True,\n",
+    "    )\n",
+    "    \n",
+    "    # Extract the assistant's response\n",
+    "    assistant_response = output[0][\"generated_text\"][-1]\n",
+    "    rewritten_content = assistant_response[\"content\"]\n",
+    "    \n",
+    "    # Save the rewritten report\n",
+    "    rewritten_filename = f\"rewritten_{report_id}_{report_title.replace(' ', '_')[:30]}.txt\"\n",
+    "    rewritten_path = rewritten_dir / rewritten_filename\n",
+    "    \n",
+    "    with open(rewritten_path, \"w\", encoding=\"utf-8\") as f:\n",
+    "        f.write(rewritten_content)\n",
+    "    \n",
+    "    print(f\"  Saved rewritten report to: {rewritten_path}\")\n",
+    "    return rewritten_path"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def rewrite_all_reports():\n",
+    "    \"\"\"Rewrite all reports in the final_reports directory\"\"\"\n",
+    "    \n",
+    "    # Get all report files\n",
+    "    report_files = get_report_files()\n",
+    "    \n",
+    "    if not report_files:\n",
+    "        print(\"No reports to rewrite.\")\n",
+    "        return []\n",
+    "    \n",
+    "    rewritten_paths = []\n",
+    "    \n",
+    "    # Process each report file\n",
+    "    for report_path in report_files:\n",
+    "        # Rewrite the report\n",
+    "        rewritten_path = rewrite_report(report_path, ORIGINAL_USER_INPUT)\n",
+    "        \n",
+    "        if rewritten_path:\n",
+    "            rewritten_paths.append(rewritten_path)\n",
+    "        \n",
+    "        # IMPORTANT: Force garbage collection to ensure no history is kept between reports\n",
+    "        # This simulates removing the old report from input history\n",
+    "        import gc\n",
+    "        gc.collect()\n",
+    "        \n",
+    "        print(\"\\n--------\\n\")\n",
+    "    \n",
+    "    return rewritten_paths"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(\"Starting report rewriting process...\")\n",
+    "\n",
+    "# Validate user input\n",
+    "if ORIGINAL_USER_INPUT == \"[Your input here - describe what you want Llama to do with these reports]\":\n",
+    "    print(\"Please edit the ORIGINAL_USER_INPUT cell to specify your request before running this cell.\")\n",
+    "else:\n",
+    "    # Process all reports\n",
+    "    rewritten_paths = rewrite_all_reports()\n",
+    "    \n",
+    "    print(\"\\nReport rewriting complete!\")\n",
+    "    print(f\"Rewritten {len(rewritten_paths)} reports.\")\n",
+    "    \n",
+    "    # List all rewritten reports\n",
+    "    if rewritten_paths:\n",
+    "        print(\"\\nRewritten reports:\")\n",
+    "        for path in rewritten_paths:\n",
+    "            print(f\"- {path}\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}