
updating copyrights

Hamid Shojanazeri, 1 year ago
commit
aefc045e42

+ 39 - 0
examples/llama_dataset.py

@@ -0,0 +1,39 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# For dataset details visit: https://huggingface.co/datasets/samsum
+
+import copy
+import datasets
+from datasets import Dataset, load_dataset
+import itertools
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+
+def tokenize_dialog(q_a_pair, tokenizer):
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(question).strip()} {E_INST}", add_special_tokens=False) for question in q_a_pair["question"]]
+    answer_tokens = [tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in q_a_pair["answer"]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    # Add labels: mask prompt tokens with -100 so they are ignored by the loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
+
+
+def get_custom_dataset(dataset_config, tokenizer, split):
+    dataset = load_dataset('json', data_files=dataset_config.data_path)
+    dataset = dataset.map(lambda sample: {
+        "question": sample["question"],
+        "answer": sample["answer"],
+        },
+        batched=True,
+    )
+    dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
+    return dataset["train"]
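
For context, get_custom_dataset above is shaped like a llama-recipes custom dataset loader: it reads a local JSON file, keeps the "question" and "answer" fields, and maps tokenize_dialog over each record. The snippet below is a minimal usage sketch, not part of this commit; the module path examples.llama_dataset, the tokenizer checkpoint meta-llama/Llama-2-7b-hf, and the data/qa_pairs.json path are assumptions for illustration.

# Hypothetical usage sketch (not part of this commit). Assumes a local JSON file whose
# records contain "question" and "answer" fields, and a Llama 2 tokenizer with
# bos_token/eos_token set.
from types import SimpleNamespace

from transformers import AutoTokenizer

from examples.llama_dataset import get_custom_dataset  # module path assumed from the diff

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
dataset_config = SimpleNamespace(data_path="data/qa_pairs.json")       # assumed data path

train_dataset = get_custom_dataset(dataset_config, tokenizer, split="train")

# Each example now carries interleaved prompt/answer token ids, labels with the prompt
# positions masked to -100, and an all-ones attention mask of the same length.
example = train_dataset[0]
print(example["input_ids"][:10])
print(example["labels"][:10])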

+ 3 - 0
tutorials/chatbot/data_pipelines/config.py

@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import yaml
 import os
 

+ 3 - 0
tutorials/chatbot/data_pipelines/generate_question_answers.py

@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import argparse
 import asyncio
 import json

+ 3 - 0
tutorials/chatbot/data_pipelines/generator_utils.py

@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import os
 import openai
 import asyncio

+ 3 - 0
tutorials/chatbot/data_pipelines/token_processor.py

@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import tiktoken
 
 # Assuming result_average_token is a constant, use UPPER_CASE for its name to follow Python conventions