# plot.py (3.0 KB)

  1. import json
  2. import os
  3. import re
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. from collections import defaultdict
  7. def extract_info_from_filename(filename):
  8. pattern = r'(?P<backend>[^-]+)-(?P<qps>\d+\.\d+)qps-(?P<model>.+)-(?P<date>\d{8}-\d{6})\.json'
  9. match = re.match(pattern, filename)
  10. if match:
  11. return {
  12. 'qps': float(match.group('qps')),
  13. 'model': match.group('model')
  14. }
  15. return None
  16. def read_json_files(directory):
  17. data_tpot = defaultdict(list)
  18. data_ttft = defaultdict(list)
  19. for filename in os.listdir(directory):
  20. if filename.endswith('.json'):
  21. filepath = os.path.join(directory, filename)
  22. file_info = extract_info_from_filename(filename)
  23. if file_info:
  24. with open(filepath, 'r') as file:
  25. json_data = json.load(file)
  26. median_tpot = json_data.get('median_tpot_ms')
  27. std_tpot = json_data.get('std_tpot_ms')
  28. median_ttft = json_data.get('median_ttft_ms')
  29. std_ttft = json_data.get('std_ttft_ms')
  30. if all(v is not None for v in [median_tpot, std_tpot, median_ttft, std_ttft]):
  31. data_tpot[file_info['model']].append((file_info['qps'], median_tpot, std_tpot))
  32. data_ttft[file_info['model']].append((file_info['qps'], median_ttft, std_ttft))
  33. return {
  34. 'tpot': {model: sorted(points) for model, points in data_tpot.items()},
  35. 'ttft': {model: sorted(points) for model, points in data_ttft.items()}
  36. }
  37. def create_chart(data, metric, filename):
  38. plt.figure(figsize=(12, 6))
  39. colors = plt.cm.rainbow(np.linspace(0, 1, len(data)))
  40. for (model, points), color in zip(data.items(), colors):
  41. qps_values, median_values, std_values = zip(*points)
  42. plt.errorbar(qps_values, median_values, yerr=std_values, fmt='o-', capsize=5, capthick=2, label=model, color=color)
  43. plt.fill_between(qps_values,
  44. np.array(median_values) - np.array(std_values),
  45. np.array(median_values) + np.array(std_values),
  46. alpha=0.2, color=color)
  47. plt.xlabel('QPS (Queries Per Second)')
  48. plt.ylabel(f'Median {metric.upper()} (ms)')
  49. plt.title(f'Median {metric.upper()} vs QPS with Standard Deviation')
  50. plt.grid(True)
  51. plt.legend(title='Model', bbox_to_anchor=(1.05, 1), loc='upper left')
  52. plt.tight_layout()
  53. plt.savefig(filename, dpi=300, bbox_inches='tight')
  54. plt.close()
  55. def main():
  56. directory = './'
  57. data = read_json_files(directory)
  58. if data['tpot'] and data['ttft']:
  59. create_chart(data['tpot'], 'tpot', 'tpot_vs_qps_chart.png')
  60. create_chart(data['ttft'], 'ttft', 'ttft_vs_qps_chart.png')
  61. print("Charts have been saved as 'tpot_vs_qps_chart.png' and 'ttft_vs_qps_chart.png'")
  62. else:
  63. print("No valid data found in the specified directory.")
  64. if __name__ == "__main__":
  65. main()