
enabling padding for inference scripts

Hamid Shojanazeri 1 year ago
parent commit
874bc8281e
3 changed files with 12 additions and 5 deletions
  1. inference/chat_completion.py (+5 -2)
  2. inference/chat_utils.py (+6 -3)
  3. inference/inference.py (+1 -0)

inference/chat_completion.py (+5 -2)

@@ -34,6 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    max_padding_length: int=0, # specifies the max padding length for the context/prompt
     **kwargs
 ):
     if prompt_file is not None:
@@ -66,9 +67,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    # make sure the embedding matrix is resized after the pad token is added as a special token
+    # Ref: https://huggingface.co/docs/transformers/main/model_doc/llama2
+    model.resize_token_embeddings(model.config.vocab_size + 1)
     
-    chats = format_tokens(dialogs, tokenizer)
-
+    chats = format_tokens(dialogs, tokenizer, max_padding_length)
     with torch.no_grad():
         for idx, chat in enumerate(chats):
             safety_checker = get_safety_checker(enable_azure_content_safety,
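
For context, the pad-token handling this hunk relies on can be sketched in isolation; a minimal, hypothetical setup (the checkpoint name is a placeholder, since the script takes the model as an argument):

    from transformers import LlamaForCausalLM, LlamaTokenizer

    # Hypothetical checkpoint name; not pinned by this commit.
    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    # Llama 2 ships without a pad token, so one is registered as a special token.
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    # The new token gets the next free id, so the embedding matrix must grow
    # by one row; this is what the resize_token_embeddings call added above does.
    model.resize_token_embeddings(model.config.vocab_size + 1)

Without the resize, padded inputs would index one past the end of the embedding table.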

inference/chat_utils.py (+6 -3)

@@ -21,7 +21,7 @@ You are a helpful, respectful and honest assistant. Always answer as helpfully a
 
 If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
 
-def format_tokens(dialogs, tokenizer):
+def format_tokens(dialogs, tokenizer, max_pad_length):
     prompt_tokens = []
     for dialog in dialogs:
         if dialog[0]["role"] != "system":
@@ -53,7 +53,7 @@ def format_tokens(dialogs, tokenizer):
         dialog_tokens: List[int] = sum(
             [
                 tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} "
                 )
                 for prompt, answer in zip(dialog[::2], dialog[1::2])
             ],
@@ -65,7 +65,10 @@ def format_tokens(dialogs, tokenizer):
         dialog_tokens += tokenizer.encode(
             f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
         )
-        prompt_tokens.append(dialog_tokens)
+        tokens = {"input_ids": dialog_tokens}
+        prompt_tokens_padded = tokenizer.pad(tokens, max_length=max_pad_length, padding="max_length", pad_to_multiple_of=None)
+        prompt_tokens.append(prompt_tokens_padded["input_ids"])
+
     return prompt_tokens
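
To make the new padding step concrete, here is a small standalone sketch of what tokenizer.pad does to one encoded dialog; the tokenizer name and the max_length of 64 are illustrative assumptions:

    from transformers import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})

    dialog_tokens = tokenizer.encode("[INST] What is the capital of France? [/INST]")

    # tokenizer.pad takes a dict of encoded inputs and fills input_ids with
    # the pad token id until max_length is reached; it never truncates.
    padded = tokenizer.pad(
        {"input_ids": dialog_tokens},
        max_length=64,
        padding="max_length",
        pad_to_multiple_of=None,
    )
    print(len(padded["input_ids"]))  # 64

Since pad never truncates, dialogs longer than max_pad_length pass through unchanged, and with the default max_padding_length of 0 the call is effectively a no-op.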
         
 

inference/inference.py (+1 -0)

@@ -61,6 +61,7 @@ def main(
         }
     )
    # make sure the embedding matrix is resized after the pad token is added as a special token
+    # Ref: https://huggingface.co/docs/transformers/main/model_doc/llama2
     model.resize_token_embeddings(model.config.vocab_size + 1)
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
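
The same pad-token/resize pairing appears in both scripts; a quick sanity check, assuming the standard 32,000-token Llama 2 vocabulary and the 7B checkpoint:

    # <PAD> is appended after the original vocabulary, so its id is 32000,
    # and after the resize the input embedding has 32,001 rows.
    print(tokenizer.pad_token_id)                     # 32000
    print(model.get_input_embeddings().weight.shape)  # torch.Size([32001, 4096])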