# convert_hf_weights_to_llama.py

import json
import os
from typing import List, Union

import fire
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM  # @manual
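
# Number of consolidated checkpoint shards to write for each model size
# (as in the original Llama releases).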
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
}


def write_model(model_path, model_size, output_base_path):
    dtype = torch.bfloat16

    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
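    # Rotary position embedding inverse frequencies; written out later under the
    # original key name "rope.freqs".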
    base = 10000.0
    inv_freq = (
        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
    ).to(dtype)
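
    # Grouped-query / multi-query attention: params.json carries "n_kv_heads" when
    # the key/value projections use fewer heads than the query projection.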
    if "n_kv_heads" in params:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
        key_value_dim = dim // num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim
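
    # Load the Hugging Face checkpoint and work directly from its state dict.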
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    loaded = model.state_dict()
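
    # The HF conversion permutes the rows of each attention head in the q/k
    # projections for its rotary-embedding layout; this applies the inverse
    # permutation to recover the original layout.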
    # permute for sliced rotary
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return (
            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )
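
    # One output state dict per shard. `insert` places a tensor (or one slice per
    # shard from a list) into every shard's dict; `insert_chunk` splits a tensor
    # along `dim` and gives each shard one chunk.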
    state_dict = [{} for _ in range(num_shards)]

    def insert(name: str, tensor: Union[List, torch.Tensor]):
        for i in range(num_shards):
            state_dict[i][name] = (
                tensor[i].clone() if isinstance(tensor, list) else tensor
            )

    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
        tensors = tensor.chunk(num_shards, dim=dim)
        for i, tensor in enumerate(tensors):
            state_dict[i][name] = tensor.clone()
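
    # Token embeddings are sharded along the embedding dim, the output projection
    # along the vocab dim, and the final norm weight is replicated on every shard.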
    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
    insert("norm.weight", loaded["model.norm.weight"])
    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
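
    # Rename and re-shard each transformer layer: attention and MLP weights are
    # split across shards, norm weights are replicated.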
    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
        ts = (
            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])

        ts = (
            permute(
                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
                num_key_value_heads,
                key_value_dim,
                dim,
            )
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])

        ts = (
            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
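
        # The output projection is sharded along its input dim; the feed-forward
        # gate/up projections along their output dim and the down projection
        # along its input dim.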
        insert_chunk(
            f"layers.{layer_i}.attention.wo.weight",
            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
            1,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w1.weight",
            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
            0,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w2.weight",
            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
            1,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w3.weight",
            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
            0,
        )
        insert(
            f"layers.{layer_i}.attention_norm.weight",
            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
        )
        insert(
            f"layers.{layer_i}.ffn_norm.weight",
            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
        )
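
    # RoPE frequencies are identical on every shard.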
    insert("rope.freqs", inv_freq)
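
    # Write each shard as consolidated.00.pth, consolidated.01.pth, ...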
    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
        torch.save(
            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
        )


def main(
    model_path: str,
    model_size: str,
    output_dir: str,
):
    """Convert Llama weights from Hugging Face format to the consolidated format.

    params:
        model_path: Hugging Face model name or path to the local model directory.
        model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
        output_dir: directory to save the Llama weights; it must already contain params.json.
    """
    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
    params_path = os.path.join(output_dir, "params.json")
    assert os.path.isfile(params_path), f"{params_path} does not exist"
    write_model(model_path, model_size, output_dir)
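

# Example invocation via fire (hypothetical paths; output_dir must already hold
# the params.json shipped with the matching original Llama checkpoint):
#
#   python convert_hf_weights_to_llama.py \
#       --model_path meta-llama/Llama-2-7b-hf \
#       --model_size 7B \
#       --output_dir ./llama-2-7b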
if __name__ == "__main__":
    fire.Fire(main)