convert_hf_weights_to_llama.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import json
import os
from typing import List, Union

import fire
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM  # @manual
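
# Shard counts match the number of consolidated.XX.pth files in Meta's original
# model-parallel releases (e.g. 7B/8B ship as a single shard, 70B as eight), so the
# converted checkpoint has the layout the reference loader expects.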
NUM_SHARDS = {
    "7B": 1,
    "8B": 1,
    "13B": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
}


def write_model(model_path, model_size, output_base_path):
    dtype = torch.bfloat16

    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    # Llama 3 checkpoints are identified by their 128256-token vocabulary.
    llama_version = 3 if params.get("vocab_size") == 128256 else 2

    if "n_kv_heads" in params:
        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
        num_local_key_value_heads = num_key_value_heads // num_shards
        key_value_dim = dims_per_head * num_key_value_heads
    else:  # compatibility with other checkpoints
        num_key_value_heads = n_heads
        num_local_key_value_heads = n_heads_per_shard
        key_value_dim = dim

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    loaded = model.state_dict()

    # permute for sliced rotary: undo the head reordering applied by the Hugging Face
    # conversion script, which stores the two rotary halves of each head contiguously
    # rather than interleaved as in the original checkpoints.
    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
        return (
            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
            .transpose(1, 2)
            .reshape(dim1, dim2)
        )

    state_dict = [{} for _ in range(num_shards)]

    def insert(name: str, tensor: Union[List, torch.Tensor]):
        # Replicate a tensor across all shards, or distribute a pre-chunked list.
        for i in range(num_shards):
            state_dict[i][name] = (
                tensor[i].clone() if isinstance(tensor, list) else tensor
            )

    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
        # Split a tensor into num_shards chunks along `dim`, one chunk per shard.
        tensors = tensor.chunk(num_shards, dim=dim)
        for i, tensor in enumerate(tensors):
            state_dict[i][name] = tensor.clone()

    # Llama 3 shards the token embedding along the vocab dimension (0),
    # Llama 2 along the model dimension (1).
    concat_dim = 0 if llama_version == 3 else 1
    insert_chunk(
        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
    )
    insert("norm.weight", loaded["model.norm.weight"])
    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
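
    # The per-layer splits below mirror the model-parallel layout of Meta's reference
    # implementation: column-parallel weights (wq, wk, wv, w1, w3) are chunked along
    # dim 0 and row-parallel weights (wo, w2) along dim 1.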
    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
        ts = (
            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])

        ts = (
            permute(
                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
                num_key_value_heads,
                key_value_dim,
                dim,
            )
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])

        ts = (
            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
            .chunk(num_shards, dim=0)
        )
        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])

        insert_chunk(
            f"layers.{layer_i}.attention.wo.weight",
            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
            1,
        )

        insert_chunk(
            f"layers.{layer_i}.feed_forward.w1.weight",
            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
            0,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w2.weight",
            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
            1,
        )
        insert_chunk(
            f"layers.{layer_i}.feed_forward.w3.weight",
            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
            0,
        )

        insert(
            f"layers.{layer_i}.attention_norm.weight",
            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
        )
        insert(
            f"layers.{layer_i}.ffn_norm.weight",
            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
        )

    if llama_version != 3:
        # Llama 1/2 consolidated checkpoints also carry the precomputed RoPE inverse
        # frequencies under rope.freqs; Llama 3 checkpoints omit them.
        base = 10000.0
        inv_freq = (
            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
        ).to(dtype)
        insert("rope.freqs", inv_freq)

    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
        torch.save(
            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
        )


def main(
    model_path: str,
    model_size: str,
    output_dir: str,
):
    """Convert Llama weights from Hugging Face format to the consolidated format.

    params:
        model_path: model name or path to the model directory.
        model_size: Llama model size, one of 7B, 8B, 13B, 34B, 30B, 65B, 70B.
        output_dir: directory to save the Llama weights; it must already contain params.json.
    """
    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
    params_path = os.path.join(output_dir, "params.json")
    assert os.path.isfile(params_path), f"{params_path} does not exist"
    write_model(model_path, model_size, output_dir)


if __name__ == "__main__":
    fire.Fire(main)
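
# Example invocation (model id and paths are illustrative; output_dir must already
# contain the params.json matching the chosen model size):
#
#   python convert_hf_weights_to_llama.py \
#       --model_path meta-llama/Llama-2-7b-hf \
#       --model_size 7B \
#       --output_dir ./llama-2-7b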