"""
Source Code: https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
"""
import argparse
import json
import os
import os.path as osp
import ssl
import urllib.request

import torch


def load(model_name_or_path):
    """Load a causal LM and its tokenizer from a local path or HF hub name.

    Args:
        model_name_or_path (str): Model identifier or local checkpoint path.

    Returns:
        tuple: ``(model, tokenizer)`` with the model in eval mode, fp16,
        and placed via ``device_map="auto"``.
    """
    # Imported lazily so the rest of this module (download_url, load_jsonl)
    # is usable without the transformers package installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    print(f"Loading model from {model_name_or_path} ...")
    # however, tensor parallel for running falcon will occur bugs
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    # Guarantee a pad token id: fall back to EOS, then to 0 as a last resort.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer


def download_url(url: str, folder="folder"):
    """
    Downloads the content of an url to a folder. Modified from \
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The url of target file.
        folder (string): The target folder.

    Returns:
        string: File path of downloaded files.
    """
    # Filename = last path segment; strip a query string unless the whole
    # segment IS a query string (e.g. ".../download?file=x").
    file = url.rpartition("/")[2]
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, use existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(security): certificate verification is intentionally disabled
    # here (mirrors the upstream helper); only use with trusted URLs.
    ctx = ssl._create_unverified_context()
    # `with` closes the HTTP response as well as the output file
    # (the original leaked the response object).
    with urllib.request.urlopen(url, context=ctx) as response, open(path, "wb") as f:
        f.write(response.read())
    return path


def load_jsonl(
    file_path,
):
    """Read a JSON-Lines file into a list of parsed objects.

    Args:
        file_path (str): Path to a file with one JSON document per line.

    Returns:
        list: One decoded object per input line.
    """
    list_data_dict = []
    with open(file_path, "r") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedy-decode up to ``max_gen_len`` tokens, streaming text to stdout.

    Args:
        model: Causal LM returning ``logits`` and ``past_key_values``.
        tokenizer: Tokenizer used to decode generated ids.
        input_ids (torch.Tensor): Prompt ids of shape ``(1, seq_len)``.
        past_key_values: KV cache carried across calls (or ``None``).
        max_gen_len (int): Maximum number of tokens to generate.

    Returns:
        Updated ``past_key_values`` after generation.
    """
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    # Decode the first token eagerly so `generated_text` is defined even
    # when max_gen_len <= 1 and the loop below never runs (the original
    # raised NameError on the final print in that case).
    generated_text = (
        tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            spaces_between_special_tokens=False,
        )
        .strip()
        .split(" ")
    )
    for _ in range(max_gen_len - 1):
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = (
            tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )
        # Stream only the words that are now complete (all but the last,
        # which may still grow as more tokens arrive).
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

        if pred_token_idx == tokenizer.eos_token_id:
            break
    # Flush the remaining (possibly partial) tail of the decoded text.
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values