diff --git a/backend/app/models/base.py b/backend/app/models/base.py index 0b73771..9842264 100644 --- a/backend/app/models/base.py +++ b/backend/app/models/base.py @@ -4,6 +4,7 @@ from enum import Enum, IntFlag # Python 3.11 >= StrEnum from enum import auto as auto_enum from uuid import UUID as RowId +from sqlalchemy import TypeDecorator, Integer, VARCHAR from sqlmodel import SQLModel from sqlalchemy.orm import declared_attr @@ -39,8 +40,85 @@ class DocumentedStrEnum(str, Enum): class DocumentedIntFlag(IntFlag): - # TODO: Build DB sport to proper store flags and make it possible to store all mutations - pass + @property + def names(self) -> str: + """ + Returns a comma-separated string of all active flag names. + """ + # Exclude 0-value flags + return ",".join([flag.name for flag in type(self) if flag in self and flag.value != 0]) + + def __str__(self) -> str: + # Default string conversion uses the names + return self.names + + # Optional: for Pydantic compatibility + def __get_pydantic_json__(self) -> str: + return self.names + + + +class DocumentedIntFlagType(TypeDecorator): + impl = Integer + cache_ok = True + + def __init__(self, enum_class: type[IntFlag], *args, **kwargs): + super().__init__(*args, **kwargs) + self.enum_class = enum_class + + def process_bind_param(self, value, dialect): + """ + Convert IntFlag to integer before storing + """ + if value is None: + return None + if isinstance(value, self.enum_class): + return int(value) + if isinstance(value, int): + return value + raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}") + + def process_result_value(self, value, dialect): + """ + Convert integer from DB back to IntFlag + """ + if value is None: + return None + return self.enum_class(value) + + +class DocumentedStrFlagType(TypeDecorator): + impl = VARCHAR + cache_ok = True + + def __init__(self, enum_class: type[IntFlag], *args, **kwargs): + super().__init__(*args, **kwargs) + self.enum_class = enum_class + + def process_bind_param(self, value, dialect): + """ + Convert IntFlag to comma-separated string of names for storing in DB. + """ + if value is None: + return None + if isinstance(value, self.enum_class): + return str(value) + raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}") + + def process_result_value(self, value, dialect): + """ + Convert comma-separated string of names from DB back to IntFlag. + """ + if value is None or value == "": + return self.enum_class(0) + names = value.split(",") + result = self.enum_class(0) + for name in names: + try: + result |= self.enum_class[name] + except KeyError: + raise ValueError(f"Invalid flag name '{name}' for {self.enum_class.__name__}") + return result # #############################################################################