- """
- Source Code: https://github.com/mit-han-lab/streaming-llm/blob/main/streaming_llm/utils.py
- """
- import torch
- import argparse
- from transformers import (
- AutoTokenizer,
- AutoModelForCausalLM,
- )
- import os.path as osp
- import ssl
- import urllib.request
- import os
- import json
def load(model_name_or_path):
    print(f"Loading model from {model_name_or_path} ...")
    # Note: running Falcon with tensor parallelism is known to trigger bugs,
    # so the model is sharded with device_map="auto" instead.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    # Some tokenizers (e.g. Llama) ship without a pad token; fall back to the
    # EOS token, or to id 0 when no EOS token is defined either.
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0
    model.eval()
    return model, tokenizer
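
# A minimal usage sketch. The checkpoint name is illustrative (streaming-llm's
# demos use Vicuna models); any causal LM on the Hugging Face Hub should work:
#
#   model, tokenizer = load("lmsys/vicuna-7b-v1.3")

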
def download_url(url: str, folder="folder"):
    """
    Downloads the content of a URL to a folder. Modified from
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The URL of the target file.
        folder (string): The target folder.

    Returns:
        string: File path of the downloaded file.
    """
    file = url.rpartition("/")[2]
    # Strip any query string from the file name, unless the name itself
    # starts with "?".
    file = file if file[0] == "?" else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, using existing file.")
        return path
    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # Note: certificate verification is disabled here, which the upstream
    # code does to avoid SSL errors on some hosts.
    ctx = ssl._create_unverified_context()
    data = urllib.request.urlopen(url, context=ctx)
    with open(path, "wb") as f:
        f.write(data.read())
    return path
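
# Example usage (a sketch; the URL is the MT-Bench question file that the
# streaming-llm demo fetches, substitute any direct file link):
#
#   path = download_url(
#       "https://raw.githubusercontent.com/lm-sys/FastChat/main/"
#       "fastchat/llm_judge/data/mt_bench/question.jsonl",
#       folder="data",
#   )

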
def load_jsonl(file_path):
    """Reads a JSON Lines file: one JSON object per line."""
    list_data_dict = []
    with open(file_path, "r") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict
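
# Example usage, continuing the sketch above (the "turns" field is
# hypothetical here; it matches the MT-Bench question schema):
#
#   samples = load_jsonl(path)
#   prompts = [sample["turns"][0] for sample in samples]

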
@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    # Prefill: run the full prompt through the model once to populate the KV
    # cache, then pick the first generated token greedily.
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    for _ in range(max_gen_len - 1):
        # Decode step: feed only the last predicted token; the KV cache
        # carries the rest of the context.
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = (
            tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
                spaces_between_special_tokens=False,
            )
            .strip()
            .split(" ")
        )
        # Stream the output word by word: print only the words completed
        # since the previous step.
        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now
        if pred_token_idx == tokenizer.eos_token_id:
            break
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values
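
# A minimal end-to-end sketch (the prompt format and max_gen_len are
# illustrative). Passing past_key_values=None starts from an empty KV cache;
# the returned cache can be fed back in to continue a multi-turn conversation:
#
#   model, tokenizer = load("lmsys/vicuna-7b-v1.3")
#   enc = tokenizer("USER: Hello! ASSISTANT:", return_tensors="pt")
#   past_key_values = greedy_generate(
#       model, tokenizer, enc.input_ids.to(model.device),
#       past_key_values=None, max_gen_len=256,
#   )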