# vis.py
'''
Visualize LLM Needle-In-A-Haystack test scores as heatmaps.
Source code adapted from:
https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main
https://github.com/THUDM/LongAlign/tree/main/Needle_test
'''
import argparse
import glob
import json
import os
import re

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
  16. if __name__ == '__main__':
  17. # Using glob to find all json files in the directory
  18. parser = argparse.ArgumentParser()
  19. parser.add_argument("--input-path", type=str, default="")
  20. parser.add_argument("--exp-name", type=str, default="")
  21. args = parser.parse_args()
  22. json_files = glob.glob(f"{args.input_path}/*.json")
  23. if not os.path.exists('vis'):
  24. os.makedirs('vis')
  25. print(json_files)
  26. # Iterating through each file and extract the 3 columns we need
  27. for file in json_files:
  28. print(file)
  29. # List to hold the data
  30. data = []
  31. with open(file, 'r') as f:
  32. json_data = json.load(f)
  33. for k in json_data:
  34. pattern = r"_len_(\d+)_"
  35. match = re.search(pattern, k)
  36. context_length = int(match.group(1)) if match else None
  37. pattern = r"depth_(\d+)"
  38. match = re.search(pattern, k)
  39. document_depth = eval(match.group(1))/100 if match else None
  40. score = json_data[k]['score']
  41. # Appending to the list
  42. data.append({
  43. "Document Depth": document_depth,
  44. "Context Length": context_length,
  45. "Score": score
  46. })
  47. # Creating a DataFrame
  48. df = pd.DataFrame(data)
  49. pivot_table = pd.pivot_table(df, values='Score', index=['Document Depth', 'Context Length'], aggfunc='mean').reset_index() # This will aggregate
  50. pivot_table = pivot_table.pivot(index="Document Depth", columns="Context Length", values="Score") # This will turn into a proper pivot
  51. # Create a custom colormap. Go to https://coolors.co/ and pick cool colors
  52. cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#F0496E", "#EBB839", "#0CD79F"])
  53. # Create the heatmap with better aesthetics
  54. plt.figure(figsize=(17.5, 8)) # Can adjust these dimensions as needed
  55. sns.heatmap(
  56. pivot_table,
  57. # annot=True,
  58. fmt="g",
  59. cmap=cmap,
  60. cbar_kws={'label': 'Score'},
  61. vmin=1,
  62. vmax=10,
  63. )
  64. # More aesthetics
  65. plt.title(f'Pressure Testing\nFact Retrieval Across Context Lengths ("Needle In A HayStack")\n{args.exp_name}') # Adds a title
  66. plt.xlabel('Token Limit') # X-axis label
  67. plt.ylabel('Depth Percent') # Y-axis label
  68. plt.xticks(rotation=45) # Rotates the x-axis labels to prevent overlap
  69. plt.yticks(rotation=0) # Ensures the y-axis labels are horizontal
  70. plt.tight_layout() # Fits everything neatly into the figure area
  71. # Show the plot
  72. plt.savefig(f"vis/{file.split('/')[-1].replace('.json', '')}.png")