import gc
import glob
import os
import sys

import torch
import tqdm


def main() -> None:
    """Compare two llama checkpoint directories shard by shard.

    Usage: ``script.py DIR_A DIR_B`` where each directory contains
    ``consolidated.*.pth`` shard files. For every tensor in every shard
    pair, computes the max absolute elementwise difference and prints
    the 10 largest deltas.

    Raises:
        SystemExit: if the two directory arguments are not supplied.
        AssertionError: if shard counts, key order, or tensor shapes differ.
    """
    if len(sys.argv) != 3:
        raise SystemExit(f"usage: {sys.argv[0]} CHECKPOINT_DIR_A CHECKPOINT_DIR_B")

    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
    assert len(one_files) == len(
        two_files
    ), "One directory has {} files while another has {} files.".format(
        len(one_files), len(two_files)
    )

    # (shard index, tensor key, max abs delta) for every tensor compared.
    deltas: list[tuple[int, str, float]] = []
    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
        # map_location="cpu" lets checkpoints saved from GPU tensors load on
        # CPU-only hosts.
        # NOTE: torch.load unpickles arbitrary data — only run this on
        # checkpoints you trust.
        one = torch.load(one_files[i], map_location="cpu")
        two = torch.load(two_files[i], map_location="cpu")
        assert len(one) == len(
            two
        ), "shard should have the same length: {} != {}".format(len(one), len(two))
        # Compare positionally; the key-equality assert below catches any
        # ordering mismatch between the two state dicts.
        for (k1, t1), (k2, t2) in zip(one.items(), two.items()):
            assert k1 == k2, "{} != {}".format(k1, k2)
            assert t1.shape == t2.shape, "tensor {} shape {} != {}".format(
                k1, t1.shape, t2.shape
            )
            delta = (t1 - t2).abs().max().item()
            deltas.append((i, k1, delta))
        # Drop both shards before loading the next pair so peak memory stays
        # at ~one pair of shards rather than growing with the loop.
        del one
        del two
        gc.collect()

    deltas.sort(key=lambda entry: entry[-1], reverse=True)
    print("Top 10 largest deltas:")
    for shard, key, delta in deltas[:10]:
        print(f"  shard {shard} {key}: {delta}")


if __name__ == "__main__":
    main()
 |