"""Utility helpers for loading HF causal LMs and running streaming greedy decoding.

Source Code: https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
"""
import torch
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
import os.path as osp
import ssl
import urllib.request
import os
import json


def load(model_name_or_path):
    """Load a causal LM and its tokenizer from a local path or a HF hub name.

    Args:
        model_name_or_path (str): Model identifier or filesystem path.

    Returns:
        tuple: ``(model, tokenizer)`` — the model is in eval mode, fp16, and
        sharded across available devices via ``device_map="auto"``.
    """
    print(f"Loading model from {model_name_or_path} ...")
    # however, tensor parallel for running falcon will occur bugs
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    # Some tokenizers ship without a pad token; fall back to EOS, or token
    # id 0 as a last resort, so downstream padding does not crash.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer


def download_url(url: str, folder="folder"):
    """Downloads the content of an url to a folder. Modified from \
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The url of target file.
        folder (string): The target folder.

    Returns:
        string: File path of downloaded files.
    """
    file = url.rpartition("/")[2]
    # Strip a query string unless the "filename" itself starts with "?".
    # startswith() also guards the empty-name case (URL ending in "/"),
    # where the original file[0] indexing raised IndexError.
    file = file if file.startswith("?") else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, use existing file.")
        return path
    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # SECURITY NOTE: this deliberately disables TLS certificate verification
    # (and uses a private ssl API) — only fetch from trusted URLs.
    ctx = ssl._create_unverified_context()
    # Fix: close the HTTP response instead of leaking the socket.
    with urllib.request.urlopen(url, context=ctx) as data, open(path, "wb") as f:
        f.write(data.read())
    return path


def load_jsonl(
    file_path,
):
    """Read a JSON-Lines file into a list of parsed objects.

    Args:
        file_path (str): Path to a file containing one JSON object per line.

    Returns:
        list: Parsed objects, in file order.
    """
    list_data_dict = []
    # JSON is UTF-8 by spec; don't rely on the platform default encoding.
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode up to ``max_gen_len`` tokens, streaming words to stdout.

    Args:
        model: Causal LM whose outputs expose ``logits`` and ``past_key_values``.
        tokenizer: Tokenizer used to decode the generated ids.
        input_ids (torch.LongTensor): Prompt token ids — assumes batch size 1
            (``.item()`` is called on the argmax).
        past_key_values: KV cache carried over from previous calls (or None).
        max_gen_len (int): Maximum number of tokens to generate (>= 1).

    Returns:
        The updated ``past_key_values`` after generation.
    """

    def _decode_words():
        # Decode everything generated so far and split on single spaces, so
        # words built from several sub-tokens are only printed once complete.
        return (
            tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )

    # Prefill: run the full prompt once, pick the first greedy token.
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    # Fix: decode before the loop so the final print below is defined even
    # when max_gen_len == 1 (previously an UnboundLocalError).
    generated_text = _decode_words()
    pos = 0  # number of whitespace-split words already printed
    for _ in range(max_gen_len - 1):
        # Decode step: feed only the last token, reusing the KV cache.
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = _decode_words()

        # Print every word except the last (it may still grow next step).
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

        if pred_token_idx == tokenizer.eos_token_id:
            break
    # Flush whatever words remain unprinted.
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values