compare_llama_weights.py

import gc
import glob
import os
import sys

import torch
import tqdm


def main() -> None:
    """Compare two llama checkpoint directories given as command-line arguments."""
    # Collect the consolidated shard files from both directories, sorted so that
    # shards are paired up by index.
    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
    assert len(one_files) == len(
        two_files
    ), "One directory has {} files while another has {} files.".format(
        len(one_files), len(two_files)
    )

    deltas = []
    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
        one = torch.load(one_files[i])
        two = torch.load(two_files[i])
        assert len(one) == len(
            two
        ), "shard should have the same length: {} != {}".format(len(one), len(two))

        # Walk both state dicts in parallel; keys and tensor shapes must match exactly.
        for v, w in zip(one.items(), two.items()):
            assert v[0] == w[0], "{} != {}".format(v[0], w[0])
            assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
                v[0], v[1].shape, w[1].shape
            )

            # Record the largest element-wise absolute difference for this tensor.
            delta = (v[1] - w[1]).abs().max().item()
            deltas.append((i, v[0], delta))

        # Free the loaded shards before moving on to keep memory usage bounded.
        del one
        del two
        gc.collect()

    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
    print("Top 10 largest deltas:")
    for i, k, v in deltas[:10]:
        print(f" shard {i} {k}: {v}")


if __name__ == "__main__":
    main()
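
Usage sketch (the two directory paths below are placeholders): run `python compare_llama_weights.py /path/to/ckpt_a /path/to/ckpt_b`, where each directory contains the `consolidated.*.pth` shards for the same model size. The script then prints the ten parameter tensors with the largest element-wise absolute difference between the two checkpoints.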