@@ -3,55 +3,61 @@
 
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-import fire
 import json
 import os
+
 import sys
 
+import fire
+
 import torch
-from transformers import AutoTokenizer
+from accelerate.utils import is_xpu_available
 
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-from accelerate.utils import is_xpu_available
+from transformers import AutoTokenizer
+
 
 def main(
     model_name,
-    peft_model: str=None,
-    quantization: str = None, # Options: 4bit, 8bit
-    max_new_tokens =256, #The maximum numbers of tokens to generate
-    min_new_tokens:int=0, #The minimum numbers of tokens to generate
-    prompt_file: str=None,
-    seed: int=42, #seed value for reproducibility
-    safety_score_threshold: float=0.5,
-    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
-    use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
-    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
-    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
-    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
-    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
-    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
-    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
-    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
-    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    peft_model: str = None,
+    quantization: str = None,  # Options: 4bit, 8bit
+    max_new_tokens=256,  # The maximum number of tokens to generate
+    min_new_tokens: int = 0,  # The minimum number of tokens to generate
+    prompt_file: str = None,
+    seed: int = 42,  # seed value for reproducibility
+    safety_score_threshold: float = 0.5,
+    do_sample: bool = True,  # Whether or not to use sampling; use greedy decoding otherwise.
+    use_cache: bool = True,  # [optional] Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
+    top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
+    enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety api
+    enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
+    enable_saleforce_content_safety: bool = True,  # Enable safety check with Salesforce safety flan t5
+    use_fast_kernels: bool = False,  # Enable using SDPA from PyTorch Accelerated Transformers, to make use of Flash Attention and Xformer memory-efficient kernels
     enable_llamaguard_content_safety: bool = False,
-    **kwargs
+    **kwargs,
 ):
+
     if prompt_file is not None:
         assert os.path.exists(
             prompt_file
         ), f"Provided Prompt file does not exist {prompt_file}"
 
-        dialogs= read_dialogs_from_file(prompt_file)
+        dialogs = read_dialogs_from_file(prompt_file)
 
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
         try:
             dialogs = json.loads(dialogs)
         except:
-            print("Could not parse json from stdin. Please provide a json file with the user prompts. Exiting.")
+            print(
+                "Could not parse json from stdin. Please provide a json file with the user prompts. Exiting."
+            )
             sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
@@ -59,7 +65,7 @@ def main(
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
@@ -77,13 +83,16 @@ def main(
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
-            safety_checker = get_safety_checker(enable_azure_content_safety,
-                                        enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
-                                        enable_llamaguard_content_safety,
-                                        )
+            safety_checker = get_safety_checker(
+                enable_azure_content_safety,
+                enable_sensitive_topics,
+                enable_saleforce_content_safety,
+                enable_llamaguard_content_safety,
+            )
             # Safety check of the user prompt
-            safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
+            safety_results = [
+                check(dialogs[idx][0]["content"]) for check in safety_checker
+            ]
             are_safe = all([r[1] for r in safety_results])
             if are_safe:
                 print(f"User prompt deemed safe.")
@@ -97,13 +106,15 @@ def main(
                         print(report)
                 print("Skipping the inferece as the prompt is not safe.")
                 sys.exit(1)  # Exit the program with an error status
-            tokens= torch.tensor(chat).long()
-            tokens= tokens.unsqueeze(0)
+            tokens = torch.tensor(chat).long()
+            tokens = tokens.unsqueeze(0)
             attention_mask = torch.ones_like(tokens)
             if is_xpu_available():
-                tokens= tokens.to("xpu:0")
+                tokens = tokens.to("xpu")
+                attention_mask = attention_mask.to("xpu")
             else:
-                tokens= tokens.to("cuda:0")
+                tokens = tokens.to("cuda")
+                attention_mask = attention_mask.to("cuda")
             outputs = model.generate(
                 input_ids=tokens,
                 attention_mask=attention_mask,
@@ -115,7 +126,7 @@ def main(
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
-                **kwargs
+                **kwargs,
             )
 
             output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -136,6 +147,5 @@ def main(
                         print(report)
 
 
-
 if __name__ == "__main__":
     fire.Fire(main)