Jelajahi Sumber

fixing payload to include model as param

Radu Boncea 2 tahun lalu
induk
melakukan
cfe065cecb
1 mengubah file dengan 16 tambahan dan 1 penghapusan
  1. 16 1
      hfapi/models/payload.py

+ 16 - 1
hfapi/models/payload.py

@@ -1,7 +1,22 @@
-from pydantic import BaseModel
+from typing import List
+from pydantic import BaseModel, validator
 
 
+QA_MODELS = [
+    "distilbert-base-uncased-distilled-squad",
+    "deepset/roberta-base-squad2"
+]
+
 class QAPredictionPayload(BaseModel):
 
     context: str = "42 is the answer to life, the universe and everything."
     question: str = "What is the answer to life?"
+    model: str = "distilbert-base-uncased-distilled-squad"
+    
+
+    @validator("model")
+    def model_validator(cls, v):
+        if v not in QA_MODELS:
+            raise ValueError("Model not supported")
+        return v
+