diff --git a/backend/app/core/db.py b/backend/app/core/db.py index ba991fb..b4be66d 100644 --- a/backend/app/core/db.py +++ b/backend/app/core/db.py @@ -1,8 +1,7 @@ from sqlmodel import Session, create_engine, select -from app import crud from app.core.config import settings -from app.models import User, UserCreate +from app.models.user import User, UserCreate engine = create_engine(str(settings.SQLALCHEMY_DATABASE_URI)) @@ -28,6 +27,5 @@ def init_db(session: Session) -> None: user_in = UserCreate( email=settings.FIRST_SUPERUSER, password=settings.FIRST_SUPERUSER_PASSWORD, - is_superuser=True, ) - user = crud.create_user(session=session, user_create=user_in) + user = User.create_user(session=session, user_create=user_in) diff --git a/backend/app/crud.py b/backend/app/crud.py deleted file mode 100644 index 905bf48..0000000 --- a/backend/app/crud.py +++ /dev/null @@ -1,54 +0,0 @@ -import uuid -from typing import Any - -from sqlmodel import Session, select - -from app.core.security import get_password_hash, verify_password -from app.models import Item, ItemCreate, User, UserCreate, UserUpdate - - -def create_user(*, session: Session, user_create: UserCreate) -> User: - db_obj = User.model_validate( - user_create, update={"hashed_password": get_password_hash(user_create.password)} - ) - session.add(db_obj) - session.commit() - session.refresh(db_obj) - return db_obj - - -def update_user(*, session: Session, db_user: User, user_in: UserUpdate) -> Any: - user_data = user_in.model_dump(exclude_unset=True) - extra_data = {} - if "password" in user_data: - password = user_data["password"] - hashed_password = get_password_hash(password) - extra_data["hashed_password"] = hashed_password - db_user.sqlmodel_update(user_data, update=extra_data) - session.add(db_user) - session.commit() - session.refresh(db_user) - return db_user - - -def get_user_by_email(*, session: Session, email: str) -> User | None: - statement = select(User).where(User.email == email) - session_user = session.exec(statement).first() - return session_user - - -def authenticate(*, session: Session, email: str, password: str) -> User | None: - db_user = get_user_by_email(session=session, email=email) - if not db_user: - return None - if not verify_password(password, db_user.hashed_password): - return None - return db_user - - -def create_item(*, session: Session, item_in: ItemCreate, owner_id: uuid.UUID) -> Item: - db_item = Item.model_validate(item_in, update={"owner_id": owner_id}) - session.add(db_item) - session.commit() - session.refresh(db_item) - return db_item diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2389b4a..e69de29 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,113 +0,0 @@ -import uuid - -from pydantic import EmailStr -from sqlmodel import Field, Relationship, SQLModel - - -# Shared properties -class UserBase(SQLModel): - email: EmailStr = Field(unique=True, index=True, max_length=255) - is_active: bool = True - is_superuser: bool = False - full_name: str | None = Field(default=None, max_length=255) - - -# Properties to receive via API on creation -class UserCreate(UserBase): - password: str = Field(min_length=8, max_length=40) - - -class UserRegister(SQLModel): - email: EmailStr = Field(max_length=255) - password: str = Field(min_length=8, max_length=40) - full_name: str | None = Field(default=None, max_length=255) - - -# Properties to receive via API on update, all are optional -class UserUpdate(UserBase): - email: EmailStr | None = Field(default=None, max_length=255) # type: ignore - password: str | None = Field(default=None, min_length=8, max_length=40) - - -class UserUpdateMe(SQLModel): - full_name: str | None = Field(default=None, max_length=255) - email: EmailStr | None = Field(default=None, max_length=255) - - -class UpdatePassword(SQLModel): - current_password: str = Field(min_length=8, max_length=40) - new_password: str = Field(min_length=8, max_length=40) - - -# Database model, database table inferred from class name -class User(UserBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - hashed_password: str - items: list["Item"] = Relationship(back_populates="owner", cascade_delete=True) - - -# Properties to return via API, id is always required -class UserPublic(UserBase): - id: uuid.UUID - - -class UsersPublic(SQLModel): - data: list[UserPublic] - count: int - - -# Shared properties -class ItemBase(SQLModel): - title: str = Field(min_length=1, max_length=255) - description: str | None = Field(default=None, max_length=255) - - -# Properties to receive on item creation -class ItemCreate(ItemBase): - pass - - -# Properties to receive on item update -class ItemUpdate(ItemBase): - title: str | None = Field(default=None, min_length=1, max_length=255) # type: ignore - - -# Database model, database table inferred from class name -class Item(ItemBase, table=True): - id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True) - owner_id: uuid.UUID = Field( - foreign_key="user.id", nullable=False, ondelete="CASCADE" - ) - owner: User | None = Relationship(back_populates="items") - - -# Properties to return via API, id is always required -class ItemPublic(ItemBase): - id: uuid.UUID - owner_id: uuid.UUID - - -class ItemsPublic(SQLModel): - data: list[ItemPublic] - count: int - - -# Generic message -class Message(SQLModel): - message: str - - -# JSON payload containing access token -class Token(SQLModel): - access_token: str - token_type: str = "bearer" - - -# Contents of JWT token -class TokenPayload(SQLModel): - sub: str | None = None - - -class NewPassword(SQLModel): - token: str - new_password: str = Field(min_length=8, max_length=40) diff --git a/backend/app/models/base.py b/backend/app/models/base.py new file mode 100644 index 0000000..a96a8ba --- /dev/null +++ b/backend/app/models/base.py @@ -0,0 +1,23 @@ +from sqlmodel import SQLModel + +from uuid import UUID as RowId + + +# region SQLModel base class ################################################### + + +class BaseSQLModel(SQLModel): + pass + +# endregion + +# region Generic message ####################################################### + + +class Message(SQLModel): + message: str + + +# ############################################################################# + +# endregion diff --git a/backend/app/models/mixin.py b/backend/app/models/mixin.py new file mode 100644 index 0000000..a7a303d --- /dev/null +++ b/backend/app/models/mixin.py @@ -0,0 +1,52 @@ +import uuid + +from pydantic import EmailStr, BaseModel +from sqlmodel import Field + +from .base import RowId as RowIdType + + +class FullName(BaseModel): + full_name: str | None = Field(default=None, nullable=True, max_length=255) + + +class IsActive(BaseModel): + is_active: bool | None = Field(default=False, nullable=False) + + +class IsVerified(BaseModel): + is_verified: bool | None = Field(default=False, nullable=False) + + +class UserName(BaseModel): + username: str | None = Field(default=None, nullable=True, unique=True, max_length=255) + + +class Email(BaseModel): + email: EmailStr | None = Field(default=None, nullable=True, unique=True, max_length=255) + + +class EmailUpdate(Email): + email: EmailStr | None = Field(default=None, max_length=255) + + +class ScoutingId(BaseModel): + scouting_id: str | None = Field(default=None, max_length=32) + + +class Password(BaseModel): + password: str = Field(min_length=8, max_length=100) + +class PasswordUpdate(Password): + password: str | None = Field(default=None, min_length=8, max_length=40) + + +class RowId(BaseModel): + id: RowIdType | None = Field( + primary_key=True, + nullable=False, + default_factory=uuid.uuid4, + ) + +class RowIdPublic(RowId): + id: RowIdType diff --git a/backend/app/models/user.py b/backend/app/models/user.py new file mode 100644 index 0000000..de57858 --- /dev/null +++ b/backend/app/models/user.py @@ -0,0 +1,121 @@ +import random +from typing import TYPE_CHECKING + +from pydantic import EmailStr +from sqlmodel import Session, Field, Relationship, select + +from app.core.config import settings +from app.core.security import get_password_hash, verify_password + +from .base import ( + BaseSQLModel, +) +from . import mixin + + +# region User ################################################################## + + +# Shared properties +class UserBase( + mixin.UserName, + mixin.Email, + mixin.FullName, + mixin.ScoutingId, + mixin.IsActive, + mixin.IsVerified, + BaseSQLModel +): + pass + + +# Properties to receive via API on creation +class UserCreate(mixin.Password, UserBase): + pass + + +class UserRegister(mixin.Password, BaseSQLModel): + email: EmailStr = Field(max_length=255) + + +# Properties to receive via API on update, all are optional +class UserUpdate(mixin.EmailUpdate, mixin.PasswordUpdate, UserBase): + pass + + +class UserUpdateMe(mixin.FullName, mixin.EmailUpdate, BaseSQLModel): + pass + + +class UpdatePassword(BaseSQLModel): + current_password: str = Field(min_length=8, max_length=40) + new_password: str = Field(min_length=8, max_length=40) + + +# Database model, database table inferred from class name +class User(mixin.RowId, UserBase, table=True): + # --- database only items -------------------------------------------------- + hashed_password: str + + # --- back_populates links ------------------------------------------------- + + # --- many-to-many links --------------------------------------------------- + + # --- CRUD actions --------------------------------------------------------- + @classmethod + def create(cls, *, session: Session, create_obj: UserCreate) -> "User": + data_obj = create_obj.model_dump(exclude_unset=True) + + extra_data = {"hashed_password": get_password_hash(create_obj.password)} + + db_obj = cls.model_validate(data_obj, update=extra_data) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + @classmethod + def update(cls, *, session: Session, db_obj: "User", in_obj: UserUpdate) -> "User": + data_obj = in_obj.model_dump(exclude_unset=True) + + extra_data = {} + if "password" in data_obj: + password = data_obj["password"] + hashed_password = get_password_hash(password) + extra_data["hashed_password"] = hashed_password + + db_obj.sqlmodel_update(data_obj, update=extra_data) + session.add(db_obj) + session.commit() + session.refresh(db_obj) + return db_obj + + @classmethod + def get_by_email(cls, *, session: Session, email: str) -> "User | None": + statement = select(cls).where(cls.email == email) + db_obj = session.exec(statement).first() + return db_obj + + @classmethod + def authenticate( + cls, *, session: Session, email: str, password: str + ) -> "User | None": + db_obj = cls.get_by_email(session=session, email=email) + if not db_obj: + return None + if not verify_password(password, db_obj.hashed_password): + return None + return db_obj + + +# Properties to return via API, id is always required +class UserPublic(mixin.RowIdPublic, UserBase): + pass + + +class UsersPublic(BaseSQLModel): + data: list[UserPublic] + count: int + + +# endregion