| 
					
				 | 
			
			
				@@ -1,12 +1,22 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import string 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import datasets 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def doc_to_text(doc: dict) -> str: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def doc_to_text_pretrain(doc: dict) -> str: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # Strip out the last two characters, which is a space and the answer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     # E.g., "Answer: B" -> "Answer:" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return doc["input_final_prompts"][0][:-2] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return text 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def doc_to_text_instruct(doc: dict) -> str: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # Strip out the last two characters, which is a space and the answer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # E.g., "Answer: B" -> "Answer:" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return doc["input_final_prompts"][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def process_docs_pretrain(dataset: datasets.Dataset) -> datasets.Dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _process_doc(doc: dict) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # input_correct_responses is in format of: "Answer: B" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         answer = doc["input_correct_responses"][0] 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -21,11 +31,43 @@ def process_docs(dataset: datasets.Dataset) -> datasets.Dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         return out_doc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = dataset.select_columns( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ["input_question", "input_correct_responses", "input_final_prompts", "is_correct", "input_question_hash", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         "input_choice_list"]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_question", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_correct_responses", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_final_prompts", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "is_correct", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_question_hash", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_choice_list", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset = dataset.rename_column("is_correct", "previously_is_correct") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset = dataset.map(_process_doc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    return dataset.map(_process_doc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def process_docs_instruct(dataset: datasets.Dataset) -> datasets.Dataset: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def _process_doc(doc: dict) -> dict: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        out_doc = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "problem": doc["input_question"], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "gold": doc["input_correct_responses"][0], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        return out_doc 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    dataset = dataset.select_columns( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        [ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_question", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_correct_responses", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_final_prompts", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "is_correct", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_question_hash", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "input_choice_list", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            "output_prediction_text", 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = dataset.rename_column("is_correct", "previously_is_correct") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     dataset = dataset.map(_process_doc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return dataset.map(_process_doc) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 def doc_to_target(doc: dict) -> str: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return doc["gold"] 
			 |