1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
- import os
- from datetime import timedelta
- from dotenv import load_dotenv
- from sqlalchemy import create_engine
- from sqlalchemy.orm import sessionmaker, Session
- from fastapi import FastAPI, HTTPException, Depends
- from werkzeug.security import generate_password_hash, check_password_hash
- from models.Base import Base
- from models.User import User
- from models.Message import Message
- from models.Conversation import Conversation
- from utils.validators import UserCreate, UserLogin, Token
- from utils.auth import create_access_token
# Load .env before reading any configuration so its values take effect below.
load_dotenv()

# Lifetime of issued access tokens, in minutes. Defaults to 30 and can be
# overridden with the ACCESS_TOKEN_EXPIRE_MINUTES environment variable.
# A non-integer value raises ValueError at import time.
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.environ.get("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))

app = FastAPI()

# Fail fast with an actionable message instead of a bare KeyError('DB_URL').
_db_url = os.environ.get("DB_URL")
if not _db_url:
    raise RuntimeError("The DB_URL environment variable must be set (e.g. in .env).")

engine = create_engine(_db_url)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create any missing tables for the models registered on Base, at import time.
Base.metadata.create_all(bind=engine)
def get_db():
    """Provide one SQLAlchemy session per request (FastAPI dependency).

    The session is yielded to the endpoint. It is closed once the request
    finishes, whether the handler returned normally or raised.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Always release the session, even when the endpoint raised.
        session.close()
def _email_registered(db: Session, email) -> bool:
    """Return True if a user row with exactly this email already exists."""
    return db.query(User).filter(User.email == email).first() is not None


@app.post("/register", response_model=dict)
def register(user: UserCreate, db: Session = Depends(get_db)):
    """Create a new user account.

    Args:
        user: Registration payload; its ``email`` and ``password`` are used.
        db: Request-scoped SQLAlchemy session provided by ``get_db``.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 400 if the email is already registered.
    """
    # Local import keeps this fix self-contained. SQLAlchemy raises
    # IntegrityError when the database rejects a row (e.g. a duplicate key).
    from sqlalchemy.exc import IntegrityError

    if _email_registered(db, user.email):
        raise HTTPException(
            status_code=400,
            detail="The provided email is already taken.",
        )

    new_user = User(
        email=user.email,
        password_hash=generate_password_hash(user.password),
    )
    db.add(new_user)
    try:
        db.commit()
    except IntegrityError:
        # Two concurrent requests can both pass the check above. If the DB
        # enforces uniqueness on email (presumably; confirm in models/User.py),
        # the second commit fails here. Report it as the same 400 instead of a
        # 500. Other integrity errors are re-raised unchanged.
        db.rollback()
        if _email_registered(db, user.email):
            raise HTTPException(
                status_code=400,
                detail="The provided email is already taken.",
            ) from None
        raise
    # No db.refresh(new_user): the object is not used after commit, so the
    # extra SELECT was wasted work.
    return {"message": "User registered successfully."}
@app.post("/login", response_model=Token)
def login(user: UserLogin, db: Session = Depends(get_db)):
    """Authenticate with email and password and issue a bearer access token.

    Args:
        user: Login payload; its ``email`` and ``password`` are checked.
        db: Request-scoped SQLAlchemy session provided by ``get_db``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 401 when no user has that email or the password
            does not match the stored hash.
    """
    account = db.query(User).filter(User.email == user.email).first()

    # The hash is only checked when an account was found (short-circuit).
    if not (account and check_password_hash(account.password_hash, user.password)):
        raise HTTPException(status_code=401, detail="Invalid email or password")

    lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    claims = {
        "email": account.email,
        "userId": account.id,
    }
    token = create_access_token(data=claims, expires_delta=lifetime)
    return {"access_token": token, "token_type": "bearer"}
@app.get("/chats")
def chats():
    """Placeholder for the chat-listing endpoint (not implemented yet).

    Raises:
        HTTPException: always 501. The previous ``pass`` body made FastAPI
            answer ``200`` with a ``null`` body, which looked like success.
    """
    # TODO: implement, presumably listing the caller's Conversation/Message rows
    # (both models are imported in this module but not yet used in what is
    # visible here).
    raise HTTPException(status_code=501, detail="Not implemented yet.")
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn's default settings.
    # Imported lazily so uvicorn is only needed when run as a script.
    from uvicorn import run

    run(app)
|