| 12345678910111213141516171819202122232425 | import datasetsimport redef preprocess(text):    text = text.strip()    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.    text = text.replace(" [title]", ". ")    text = re.sub("\\[.*?\\]", "", text)    text = text.replace("  ", " ")    return textdef process_docs(dataset: datasets.Dataset) -> datasets.Dataset:    def _process_doc(doc):        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()        out_doc = {            "query": preprocess(doc["activity_label"] + ": " + ctx),            "choices": [preprocess(ending) for ending in doc["endings"]],            "gold": int(doc["label"]),        }        return out_doc    return dataset.map(_process_doc)
 |