|
@@ -1,12 +1,22 @@
|
|
|
import string
|
|
|
+
|
|
|
import datasets
|
|
|
|
|
|
-def doc_to_text(doc: dict) -> str:
|
|
|
+
|
|
|
+def doc_to_text_pretrain(doc: dict) -> str:
|
|
|
# Strip out the last two characters, which is a space and the answer
|
|
|
# E.g., "Answer: B" -> "Answer:"
|
|
|
return doc["input_final_prompts"][0][:-2]
|
|
|
+ return text
|
|
|
+
|
|
|
+
|
|
|
+def doc_to_text_instruct(doc: dict) -> str:
|
|
|
+ # Return the first final prompt as-is; unlike doc_to_text_pretrain,
|
|
|
+ # nothing is stripped from the end of the prompt here.
|
|
|
+ return doc["input_final_prompts"][0]
|
|
|
+
|
|
|
|
|
|
-def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
|
|
|
+def process_docs_pretrain(dataset: datasets.Dataset) -> datasets.Dataset:
|
|
|
def _process_doc(doc: dict) -> dict:
|
|
|
# input_correct_responses is in format of: "Answer: B"
|
|
|
answer = doc["input_correct_responses"][0]
|
|
@@ -21,11 +31,43 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
|
|
|
return out_doc
|
|
|
|
|
|
dataset = dataset.select_columns(
|
|
|
- ["input_question", "input_correct_responses", "input_final_prompts", "is_correct", "input_question_hash",
|
|
|
- "input_choice_list"])
|
|
|
+ [
|
|
|
+ "input_question",
|
|
|
+ "input_correct_responses",
|
|
|
+ "input_final_prompts",
|
|
|
+ "is_correct",
|
|
|
+ "input_question_hash",
|
|
|
+ "input_choice_list",
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ dataset = dataset.rename_column("is_correct", "previously_is_correct")
|
|
|
+ dataset = dataset.map(_process_doc)
|
|
|
+ return dataset.map(_process_doc)
|
|
|
+
|
|
|
+
|
|
|
+def process_docs_instruct(dataset: datasets.Dataset) -> datasets.Dataset:
|
|
|
+ def _process_doc(doc: dict) -> dict:
|
|
|
+ out_doc = {
|
|
|
+ "problem": doc["input_question"],
|
|
|
+ "gold": doc["input_correct_responses"][0],
|
|
|
+ }
|
|
|
+ return out_doc
|
|
|
+
|
|
|
+ dataset = dataset.select_columns(
|
|
|
+ [
|
|
|
+ "input_question",
|
|
|
+ "input_correct_responses",
|
|
|
+ "input_final_prompts",
|
|
|
+ "is_correct",
|
|
|
+ "input_question_hash",
|
|
|
+ "input_choice_list",
|
|
|
+ "output_prediction_text",
|
|
|
+ ]
|
|
|
+ )
|
|
|
dataset = dataset.rename_column("is_correct", "previously_is_correct")
|
|
|
dataset = dataset.map(_process_doc)
|
|
|
return dataset.map(_process_doc)
|
|
|
|
|
|
+
|
|
|
def doc_to_target(doc: dict) -> str:
|
|
|
return doc["gold"]
|