Add base to store flags as strings and numbes
This commit is contained in:
@@ -4,6 +4,7 @@ from enum import Enum, IntFlag # Python 3.11 >= StrEnum
|
|||||||
from enum import auto as auto_enum
|
from enum import auto as auto_enum
|
||||||
from uuid import UUID as RowId
|
from uuid import UUID as RowId
|
||||||
|
|
||||||
|
from sqlalchemy import TypeDecorator, Integer, VARCHAR
|
||||||
from sqlmodel import SQLModel
|
from sqlmodel import SQLModel
|
||||||
from sqlalchemy.orm import declared_attr
|
from sqlalchemy.orm import declared_attr
|
||||||
|
|
||||||
@@ -39,8 +40,85 @@ class DocumentedStrEnum(str, Enum):
|
|||||||
|
|
||||||
|
|
||||||
class DocumentedIntFlag(IntFlag):
|
class DocumentedIntFlag(IntFlag):
|
||||||
# TODO: Build DB sport to proper store flags and make it possible to store all mutations
|
@property
|
||||||
pass
|
def names(self) -> str:
|
||||||
|
"""
|
||||||
|
Returns a comma-separated string of all active flag names.
|
||||||
|
"""
|
||||||
|
# Exclude 0-value flags
|
||||||
|
return ",".join([flag.name for flag in type(self) if flag in self and flag.value != 0])
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# Default string conversion uses the names
|
||||||
|
return self.names
|
||||||
|
|
||||||
|
# Optional: for Pydantic compatibility
|
||||||
|
def __get_pydantic_json__(self) -> str:
|
||||||
|
return self.names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentedIntFlagType(TypeDecorator):
|
||||||
|
impl = Integer
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def __init__(self, enum_class: type[IntFlag], *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.enum_class = enum_class
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
"""
|
||||||
|
Convert IntFlag to integer before storing
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, self.enum_class):
|
||||||
|
return int(value)
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}")
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
"""
|
||||||
|
Convert integer from DB back to IntFlag
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return self.enum_class(value)
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentedStrFlagType(TypeDecorator):
|
||||||
|
impl = VARCHAR
|
||||||
|
cache_ok = True
|
||||||
|
|
||||||
|
def __init__(self, enum_class: type[IntFlag], *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.enum_class = enum_class
|
||||||
|
|
||||||
|
def process_bind_param(self, value, dialect):
|
||||||
|
"""
|
||||||
|
Convert IntFlag to comma-separated string of names for storing in DB.
|
||||||
|
"""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, self.enum_class):
|
||||||
|
return str(value)
|
||||||
|
raise ValueError(f"Invalid value for {self.enum_class.__name__}: {value!r}")
|
||||||
|
|
||||||
|
def process_result_value(self, value, dialect):
|
||||||
|
"""
|
||||||
|
Convert comma-separated string of names from DB back to IntFlag.
|
||||||
|
"""
|
||||||
|
if value is None or value == "":
|
||||||
|
return self.enum_class(0)
|
||||||
|
names = value.split(",")
|
||||||
|
result = self.enum_class(0)
|
||||||
|
for name in names:
|
||||||
|
try:
|
||||||
|
result |= self.enum_class[name]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"Invalid flag name '{name}' for {self.enum_class.__name__}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# #############################################################################
|
# #############################################################################
|
||||||
|
|||||||
Reference in New Issue
Block a user