159 lines
4.0 KiB
Python
159 lines
4.0 KiB
Python
import re
|
|
|
|
from enum import Enum, IntFlag # Python 3.11 >= StrEnum
|
|
from enum import auto as auto_enum
|
|
from uuid import UUID as RowId
|
|
|
|
from sqlalchemy import TypeDecorator, Integer, VARCHAR
|
|
from sqlmodel import SQLModel
|
|
from sqlalchemy.orm import declared_attr
|
|
|
|
# Public API of this module: the names exported by ``from <module> import *``.
# The two SQLAlchemy column types are public helpers as well, so they are
# listed alongside the enum bases they serialize.
__all__ = [
    "RowId",
    "DocumentedStrEnum",
    "DocumentedIntFlag",
    "DocumentedIntFlagType",
    "DocumentedStrFlagType",
    "auto_enum",
    "ApiTags",
    "BaseSQLModel",
    "Message",
]
|
|
|
|
# region SQLModel base class ###################################################
|
|
|
|
|
|
class BaseSQLModel(SQLModel):
    """Common SQLModel base that derives ``__tablename__`` from the class name.

    The class name is converted to snake_case, keeping acronyms together:
    ``BaseSQLModel`` -> ``base_sql_model``, ``APIKey`` -> ``api_key``.
    """

    # noinspection PyMethodParameters
    @declared_attr  # type: ignore
    def __tablename__(cls) -> str:
        # An underscore goes before every uppercase letter (never the first
        # character) that either follows a non-uppercase character
        # ("fooBar" -> "foo_Bar") or begins a word after an acronym
        # ("SQLModel" -> "SQL_Model").
        word_boundary = r"(?<=.)(((?<![A-Z])[A-Z])|([A-Z](?=[a-z])))"
        snake_name = re.sub(
            word_boundary,
            lambda match: "_" + match.group(1),
            cls.__name__,
        )
        return snake_name.lower()
|
|
|
|
# endregion
|
|
|
|
|
|
# region Enum fields and column types ##########################################
|
|
|
|
|
|
class DocumentedStrEnum(str, Enum):
    """String-valued enum: a stand-in for ``enum.StrEnum`` (Python 3.11+).

    Members are ``str`` instances and compare equal to their values.
    Following ``StrEnum``, ``str()`` and ``format()`` (and therefore f-strings)
    return the member's value. A plain ``(str, Enum)`` mixin would instead give
    ``"ClassName.MEMBER"`` from ``str()``, and since Python 3.11 from
    ``format()`` as well.
    """

    def __str__(self) -> str:
        # str.__str__ returns the underlying string content, bypassing
        # Enum.__str__ (which renders "ClassName.MEMBER").
        return str.__str__(self)

    def __format__(self, format_spec: str) -> str:
        # Format the plain string value, so format specs such as ">10" apply
        # to the value text.
        return str.__format__(str.__str__(self), format_spec)
|
|
|
|
|
|
class DocumentedIntFlag(IntFlag):
    """Integer bit-flag enum whose string form lists the active member names.

    ``str(Perm.READ | Perm.WRITE)`` gives ``"READ,WRITE"``; a value with no
    non-zero member set gives ``""``.
    """

    @property
    def names(self) -> str:
        """Return the names of the non-zero members set in this value, comma-joined.

        Names appear in the iteration order of the enum class.
        """
        active: list[str] = []
        for member in type(self):
            # A zero-valued member is trivially "contained" in every value,
            # so it would always show up; leave it out.
            if member.value == 0:
                continue
            if member in self:
                active.append(member.name)
        return ",".join(active)

    def __str__(self) -> str:
        # Render as the comma-separated member names instead of the default
        # IntFlag string form.
        return self.names

    # Intended for Pydantic compatibility.
    # NOTE(review): pydantic's documented hooks are __get_validators__ (v1)
    # and __get_pydantic_core_schema__ / __get_pydantic_json_schema__ (v2);
    # confirm something actually calls this method.
    def __get_pydantic_json__(self) -> str:
        return self.names
|
|
|
|
|
|
|
|
class DocumentedIntFlagType(TypeDecorator):
    """SQLAlchemy column type that stores an ``IntFlag`` as a plain integer.

    The flag class is given to the constructor, e.g.
    ``DocumentedIntFlagType(MyFlags)``; any further arguments are forwarded
    to ``TypeDecorator``.
    """

    impl = Integer
    # SQLAlchemy builds the statement-cache key from attributes named like the
    # __init__ parameters; ``enum_class`` is stored under its own name and is
    # a hashable class, so caching is safe.
    cache_ok = True

    def __init__(self, enum_class: type[IntFlag], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.enum_class = enum_class

    def process_bind_param(self, value, dialect):
        """Turn a flag (or a plain int) into the integer written to the DB.

        ``None`` is passed through as ``None``.

        Raises:
            ValueError: if ``value`` is neither ``None``, an ``enum_class``
                instance, nor an ``int``.
        """
        if value is None:
            return None
        if isinstance(value, self.enum_class):
            return int(value)
        if not isinstance(value, int):
            raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}")
        # Plain ints are stored unchanged, without checking them against
        # enum_class.
        return value

    def process_result_value(self, value, dialect):
        """Rebuild an ``enum_class`` value from the stored integer; NULL stays ``None``."""
        return None if value is None else self.enum_class(value)
|
|
|
|
|
|
class DocumentedStrFlagType(TypeDecorator):
    """SQLAlchemy column type storing an ``IntFlag`` as comma-separated member names.

    ``MyFlags.A | MyFlags.B`` is stored as ``"A,B"`` and read back as the same
    flag value. Because names are stored rather than bit values, the stored
    text survives renumbering of the flag bits, but not renaming of members.
    On read, both NULL and ``""`` become ``enum_class(0)``.
    """

    impl = VARCHAR
    # SQLAlchemy builds the statement-cache key from attributes named like the
    # __init__ parameters; ``enum_class`` is stored under its own name and is
    # a hashable class, so caching is safe.
    cache_ok = True

    def __init__(self, enum_class: type[IntFlag], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.enum_class = enum_class

    def process_bind_param(self, value, dialect):
        """Convert an ``enum_class`` flag to its comma-separated member names.

        ``None`` is stored as NULL. A flag with no non-zero member set is
        stored as ``""``.

        Raises:
            ValueError: if ``value`` is neither ``None`` nor an ``enum_class``
                instance.
        """
        if value is None:
            return None
        if isinstance(value, self.enum_class):
            # Build the names here instead of relying on str(value). Only
            # DocumentedIntFlag overrides __str__ to list member names. A plain
            # IntFlag renders as e.g. "3" on Python 3.11+, which
            # process_result_value could not parse back. For a DocumentedIntFlag
            # with the default ``names`` the output is identical to str(value).
            return ",".join(
                member.name
                for member in self.enum_class
                if member.value != 0 and member in value
            )
        raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}")

    def process_result_value(self, value, dialect):
        """Convert comma-separated member names from the DB back to ``enum_class``.

        NULL and ``""`` both yield ``enum_class(0)``. Whitespace around names
        and empty segments (e.g. ``"A, B,"``) are ignored.

        Raises:
            ValueError: if a name is not a member of ``enum_class``.
        """
        result = self.enum_class(0)
        if not value:
            return result
        for raw_name in value.split(","):
            name = raw_name.strip()
            if not name:
                continue
            try:
                member = self.enum_class[name]
            except KeyError:
                raise ValueError(
                    f"Invalid flag name '{name}' for {self.enum_class.__name__}"
                ) from None
            result |= member
        return result
|
|
|
|
|
|
# #############################################################################
|
|
|
|
|
|
class ApiTags(DocumentedStrEnum):
    # Named tags for the API; each member's value is the tag's display string.
    # NOTE(review): presumably passed as router/endpoint ``tags`` (OpenAPI
    # grouping) — confirm against callers.
    # Member order is the enum's iteration order; append new tags rather than
    # reordering if anything depends on that order.
    LOGIN = "Login"
    USERS = "Users"
    UTILS = "Utils"
    PRIVATE = "Private"

    APIKEY = "APIKey"

    EVENTS = "Events"
    TEAMS = "Teams"
    ASSOCIATIONS = "Associations"
    DIVISIONS = "Divisions"
    MEMBERS = "Members"

    HIKES = "Hikes"
    ROUTES = "Routes"
    PLACES = "Places"
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Generic message #######################################################
|
|
|
|
|
|
class Message(SQLModel):
    # Generic model carrying a single ``message`` string.
    # Documented with a comment rather than a docstring: pydantic uses a model
    # docstring as its JSON-schema description.
    message: str
|
|
|
|
|
|
# #############################################################################
|
|
|
|
# endregion
|