1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- '''
- Generate prompts for the LLM Needle Haystack.
- Source code from:
- https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main
- https://github.com/THUDM/LongAlign/tree/main/Needle_test
- '''
- import pandas as pd
- import seaborn as sns
- import matplotlib.pyplot as plt
- from matplotlib.colors import LinearSegmentedColormap
- import json
- import os
- import re
- import glob
- import argparse
- if __name__ == '__main__':
- # Using glob to find all json files in the directory
- parser = argparse.ArgumentParser()
- parser.add_argument("--input-path", type=str, default="")
- parser.add_argument("--exp-name", type=str, default="")
- args = parser.parse_args()
- json_files = glob.glob(f"{args.input_path}/*.json")
- if not os.path.exists('vis'):
- os.makedirs('vis')
- print(json_files)
- # Iterating through each file and extract the 3 columns we need
- for file in json_files:
- print(file)
- # List to hold the data
- data = []
- with open(file, 'r') as f:
- json_data = json.load(f)
-
- for k in json_data:
- pattern = r"_len_(\d+)_"
- match = re.search(pattern, k)
- context_length = int(match.group(1)) if match else None
- pattern = r"depth_(\d+)"
- match = re.search(pattern, k)
- document_depth = eval(match.group(1))/100 if match else None
- score = json_data[k]['score']
- # Appending to the list
- data.append({
- "Document Depth": document_depth,
- "Context Length": context_length,
- "Score": score
- })
- # Creating a DataFrame
- df = pd.DataFrame(data)
- pivot_table = pd.pivot_table(df, values='Score', index=['Document Depth', 'Context Length'], aggfunc='mean').reset_index() # This will aggregate
- pivot_table = pivot_table.pivot(index="Document Depth", columns="Context Length", values="Score") # This will turn into a proper pivot
-
- # Create a custom colormap. Go to https://coolors.co/ and pick cool colors
- cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#0CD79F"])
- # Create the heatmap with better aesthetics
- plt.figure(figsize=(17.5, 8)) # Can adjust these dimensions as needed
- sns.heatmap(
- pivot_table,
- # annot=True,
- fmt="g",
- cmap=cmap,
- cbar_kws={'label': 'Score'},
- vmin=1,
- vmax=10,
- )
- # More aesthetics
- plt.title(f'Pressure Testing\nFact Retrieval Across Context Lengths ("Needle In A HayStack")\n{args.exp_name}') # Adds a title
- plt.xlabel('Token Limit') # X-axis label
- plt.ylabel('Depth Percent') # Y-axis label
- plt.xticks(rotation=45) # Rotates the x-axis labels to prevent overlap
- plt.yticks(rotation=0) # Ensures the y-axis labels are horizontal
- plt.tight_layout() # Fits everything neatly into the figure area
- # Show the plot
- plt.savefig(f"vis/{file.split('/')[-1].replace('.json', '')}.png")
|