|
@@ -1,7 +1,22 @@
|
|
|
-from pydantic import BaseModel
|
|
|
+from typing import List
|
|
|
+from pydantic import BaseModel, validator
|
|
|
|
|
|
|
|
|
+QA_MODELS = [
|
|
|
+ "distilbert-base-uncased-distilled-squad",
|
|
|
+ "deepset/roberta-base-squad2"
|
|
|
+]
|
|
|
+
|
|
|
class QAPredictionPayload(BaseModel):
|
|
|
|
|
|
context: str = "42 is the answer to life, the universe and everything."
|
|
|
question: str = "What is the answer to life?"
|
|
|
+ model: str = "distilbert-base-uncased-distilled-squad"
|
|
|
+
|
|
|
+
|
|
|
+ @validator("model")
|
|
|
+ def model_validator(cls, v):
|
|
|
+ if v not in QA_MODELS:
|
|
|
+ raise ValueError("Model not supported")
|
|
|
+ return v
|
|
|
+
|