|
@@ -0,0 +1,88 @@
|
|
|
+import os
|
|
|
+from datetime import timedelta
|
|
|
+from dotenv import load_dotenv
|
|
|
+from sqlalchemy import create_engine
|
|
|
+from sqlalchemy.orm import sessionmaker, Session
|
|
|
+from fastapi import FastAPI, HTTPException, Depends
|
|
|
+from werkzeug.security import generate_password_hash, check_password_hash
|
|
|
+
|
|
|
+from models.Base import Base
|
|
|
+from models.User import User
|
|
|
+from models.Message import Message
|
|
|
+from models.Conversation import Conversation
|
|
|
+from utils.validators import UserCreate, UserLogin, Token
|
|
|
+from utils.auth import create_access_token
|
|
|
+
|
|
|
+ACCESS_TOKEN_EXPIRE_MINUTES = 30 # Token expiration time
|
|
|
+
|
|
|
+
|
|
|
+load_dotenv()
|
|
|
+
|
|
|
+app = FastAPI()
|
|
|
+engine = create_engine(os.environ["DB_URL"])
|
|
|
+SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
+
|
|
|
+Base.metadata.create_all(bind=engine)
|
|
|
+
|
|
|
+
|
|
|
+# Dependency to get the SQLAlchemy session
|
|
|
+def get_db():
|
|
|
+ db = SessionLocal()
|
|
|
+ try:
|
|
|
+ yield db
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/register", response_model=dict)
|
|
|
+def register(user: UserCreate, db: Session = Depends(get_db)):
|
|
|
+ existing_user = db.query(User).filter(User.email == user.email).first()
|
|
|
+
|
|
|
+ if existing_user:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=400,
|
|
|
+ detail="The provided email is already taken.",
|
|
|
+ )
|
|
|
+
|
|
|
+ hashed_password = generate_password_hash(user.password)
|
|
|
+ new_user = User(
|
|
|
+ email=user.email,
|
|
|
+ password_hash=hashed_password,
|
|
|
+ )
|
|
|
+
|
|
|
+ db.add(new_user)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(new_user)
|
|
|
+
|
|
|
+ return {"message": "User registered successfully."}
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/login", response_model=Token)
|
|
|
+def login(user: UserLogin, db: Session = Depends(get_db)):
|
|
|
+ # Check if the user exists
|
|
|
+ db_user = db.query(User).filter(User.email == user.email).first()
|
|
|
+ if not db_user or not check_password_hash(db_user.password_hash, user.password):
|
|
|
+ raise HTTPException(status_code=401, detail="Invalid email or password")
|
|
|
+
|
|
|
+ # Create a token
|
|
|
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
|
|
+ access_token = create_access_token(
|
|
|
+ data={
|
|
|
+ "email": db_user.email,
|
|
|
+ "userId": db_user.id,
|
|
|
+ },
|
|
|
+ expires_delta=access_token_expires,
|
|
|
+ )
|
|
|
+
|
|
|
+ return {"access_token": access_token, "token_type": "bearer"}
|
|
|
+
|
|
|
+
|
|
|
+@app.get("/chats")
|
|
|
+def chats():
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ import uvicorn
|
|
|
+
|
|
|
+ uvicorn.run(app)
|