@@ -12,6 +12,7 @@ from transformers import LlamaForCausalLM # @manual
 
 NUM_SHARDS = {
     "7B": 1,
+    "8B": 1,
     "13B": 2,
     "34B": 4,
     "30B": 4,
@@ -30,15 +31,12 @@ def write_model(model_path, model_size, output_base_path):
     n_heads_per_shard = n_heads // num_shards
     dim = params["dim"]
     dims_per_head = dim // n_heads
-    base = 10000.0
-    inv_freq = (
-        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
-    ).to(dtype)
+    llama_version = 3 if params.get("vocab_size") == 128256 else 2
 
     if "n_kv_heads" in params:
         num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
-        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
-        key_value_dim = dim // num_key_value_heads
+        num_local_key_value_heads = num_key_value_heads // num_shards
+        key_value_dim = dims_per_head * num_key_value_heads
     else:  # compatibility with other checkpoints
         num_key_value_heads = n_heads
         num_local_key_value_heads = n_heads_per_shard
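
The hunk above does two things: it tags the checkpoint as Llama 3 when params.json carries the 128256-entry vocabulary, and it corrects the grouped-query-attention sharding math (the old expressions happen to give the same numbers for Llama 2 70B's 8-way layout, but not for a single-shard 8B checkpoint). As a sanity check, here is a minimal standalone sketch of that arithmetic, assuming the publicly documented Llama 3 8B shapes (dim 4096, 32 heads, 8 KV heads); these values are illustrative and are not read from this patch.

# Sketch only: rerun the patched GQA sharding arithmetic on assumed
# Llama 3 8B params.json values.
params = {"dim": 4096, "n_heads": 32, "n_kv_heads": 8, "vocab_size": 128256}
num_shards = 1  # NUM_SHARDS["8B"]

n_heads = params["n_heads"]
dims_per_head = params["dim"] // n_heads                        # 128
llama_version = 3 if params.get("vocab_size") == 128256 else 2  # -> 3

num_key_value_heads = params["n_kv_heads"]                      # 8 (GQA)
num_local_key_value_heads = num_key_value_heads // num_shards   # 8 KV heads per shard
key_value_dim = dims_per_head * num_key_value_heads             # 1024-wide k/v projection

print(llama_version, num_local_key_value_heads, key_value_dim)  # 3 8 1024
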
@@ -72,7 +70,10 @@ def write_model(model_path, model_size, output_base_path):
         for i, tensor in enumerate(tensors):
             state_dict[i][name] = tensor.clone()
 
-    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    concat_dim = 0 if llama_version == 3 else 1
+    insert_chunk(
+        "tok_embeddings.weight", loaded["model.embed_tokens.weight"], concat_dim
+    )
     insert("norm.weight", loaded["model.norm.weight"])
     insert_chunk("output.weight", loaded["lm_head.weight"], 0)
 
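The embedding weight is now chunked along a version-dependent axis: dimension 0 (the vocabulary axis) for Llama 3, and dimension 1 (the hidden axis) as before for Llama 2. A minimal sketch of the difference, using a toy tensor in place of model.embed_tokens.weight (shapes are illustrative only):

import torch

# Sketch only: how the two concat_dim choices split an embedding matrix.
# A toy (vocab=8, dim=6) tensor stands in for the real weights.
embed = torch.arange(48, dtype=torch.float32).reshape(8, 6)

vocab_split = embed.chunk(2, dim=0)   # concat_dim = 0 (Llama 3 path)
hidden_split = embed.chunk(2, dim=1)  # concat_dim = 1 (previous behaviour)

print(vocab_split[0].shape)   # torch.Size([4, 6]) - each shard keeps the full hidden dim
print(hidden_split[0].shape)  # torch.Size([8, 3]) - each shard keeps the full vocab
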
@@ -136,7 +137,12 @@ def write_model(model_path, model_size, output_base_path):
             f"layers.{layer_i}.ffn_norm.weight",
             loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
         )
-    insert("rope.freqs", inv_freq)
+    if llama_version != 3:
+        base = 10000.0
+        inv_freq = (
+            1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+        ).to(dtype)
+        insert("rope.freqs", inv_freq)
 
     for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
         torch.save(
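
With the rotary table moved here, the precomputed rope.freqs entry is written only for non-Llama-3 checkpoints; for Llama 3 the key is simply omitted. Below is a minimal standalone sketch of that inverse-frequency computation, assuming dims_per_head = 128 and bfloat16 purely for illustration (the patch derives both from the checkpoint instead).

import torch

# Sketch only: the RoPE inverse-frequency table emitted for Llama 2 checkpoints.
dims_per_head = 128        # assumed head dimension
dtype = torch.bfloat16     # assumed target dtype
base = 10000.0

inv_freq = (
    1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
).to(dtype)

print(inv_freq.shape)  # torch.Size([64]): one frequency per rotated coordinate pair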