from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING, List
from sqlalchemy import TIMESTAMP, Boolean, Column, Float, ForeignKey, Integer, String, Table, exists, insert
from sqlalchemy.orm import Mapped, joinedload, mapped_column, relationship, selectinload
from harp.models.transactions import Transaction as TransactionModel
from .base import Base, Repository, with_session
from .flags import FLAGS_BY_TYPE, UserFlag
from .tags import TagValue
if TYPE_CHECKING:
from .messages import Message
# Association table linking transactions to tag values (many-to-many).
# Both foreign keys are declared with ON DELETE CASCADE and together form
# the composite primary key of the link row.
transaction_tag_values_association_table = Table(
    "trans_tag_values",
    Base.metadata,
    Column(
        "transaction_id",
        ForeignKey("transactions.id", ondelete="CASCADE"),
        primary_key=True,
    ),
    Column(
        "value_id",
        ForeignKey("tag_values.id", ondelete="CASCADE"),
        primary_key=True,
    ),
)
[docs]
class Transaction(Base):
    """ORM mapping of the ``transactions`` table.

    Rows can be converted to the storage-agnostic ``TransactionModel`` with
    :meth:`to_model`.
    """

    __tablename__ = "transactions"

    id = mapped_column(String(27), primary_key=True, unique=True)
    type = mapped_column(String(10), index=True)
    endpoint = mapped_column(String(32), nullable=True, index=True)
    started_at = mapped_column(TIMESTAMP(timezone=True), index=True)
    finished_at = mapped_column(TIMESTAMP(timezone=True), nullable=True)
    elapsed = mapped_column(Float(), nullable=True)
    tpdex = mapped_column(Integer(), nullable=True)

    # The "x_" columns are exposed under ``extras`` by to_model()
    # (method, status_class, cached, no_cache).
    x_method = mapped_column(String(16), nullable=True, index=True)
    x_status_class = mapped_column(String(3), nullable=True, index=True)
    x_cached = mapped_column(String(32), nullable=True)
    x_no_cache = mapped_column(Boolean(), nullable=True, default=False)

    # Messages attached to this transaction, ordered by message id.
    messages: Mapped[List["Message"]] = relationship(
        back_populates="transaction",
        order_by="Message.id",
        cascade="all, delete",
        passive_deletes=True,
    )

    # Flags attached to this transaction.
    flags: Mapped[List["UserFlag"]] = relationship(
        back_populates="transaction",
        cascade="all, delete",
        passive_deletes=True,
    )

    # Tag values linked through the ``trans_tag_values`` association table;
    # use the ``tags`` property for a plain name -> value mapping.
    _tag_values: Mapped[List["TagValue"]] = relationship(
        secondary=transaction_tag_values_association_table,
        cascade="all, delete",
        passive_deletes=True,
    )

    @staticmethod
    def _as_utc(value):
        """Return *value* as a UTC-aware datetime, passing ``None`` through.

        Naive datetimes are assumed to already be expressed in UTC and are
        only labelled. Aware datetimes are converted with ``astimezone`` so
        the instant is preserved (``replace(tzinfo=...)`` would silently
        shift it when the value carries a non-UTC offset).
        """
        if value is None:
            return None
        if value.tzinfo is None:
            return value.replace(tzinfo=UTC)
        return value.astimezone(UTC)

    def to_model(self, with_user_flags=False):
        """Convert this row into a ``TransactionModel``.

        :param with_user_flags: when truthy, ``extras`` gets a ``flags`` entry
            holding the de-duplicated, truthy values found in ``FLAGS_BY_TYPE``
            for the types of ``self.flags``.
        :return: a ``TransactionModel`` whose datetimes are UTC-aware.
        """
        extras = dict(
            method=self.x_method,
            status_class=self.x_status_class,
            cached=bool(self.x_cached),
            no_cache=bool(self.x_no_cache),
        )
        if with_user_flags:
            # Flag types missing from FLAGS_BY_TYPE (or mapped to a falsy
            # value) are dropped; the set removes duplicates.
            extras["flags"] = list(
                {known for flag in self.flags if (known := FLAGS_BY_TYPE.get(flag.type, None))}
            )

        return TransactionModel(
            id=self.id,
            type=self.type,
            endpoint=self.endpoint,
            started_at=self._as_utc(self.started_at),
            finished_at=self._as_utc(self.finished_at),
            elapsed=self.elapsed,
            tpdex=self.tpdex,
            extras=extras,
            messages=[message.to_model() for message in self.messages] if self.messages else [],
            tags=self.tags,
        )

    @property
    def tags(self):
        """Mapping of tag name to value, built from the linked tag values."""
        return {tag_value.tag.name: tag_value.value for tag_value in self._tag_values}
[docs]
class TransactionsRepository(Repository[Transaction]):
Type = Transaction
[docs]
def __init__(self, session_factory, /, tags=None, tag_values=None):
    """Initialise the repository.

    :param session_factory: forwarded positionally to the parent initialiser.
    :param tags: stored as ``self.tags`` (presumably the tags repository;
        confirm against callers).
    :param tag_values: stored as ``self.tag_values`` (presumably the tag
        values repository; confirm against callers).
    """
    super().__init__(session_factory)
    self.tags, self.tag_values = tags, tag_values
[docs]
def select(self, /, *, with_messages=False, with_user_flags=False, with_tags=False):
    """Build the base select statement, optionally adding eager loads.

    :param with_messages: when truthy, load ``messages`` with a joined load.
    :param with_user_flags: doubles as the user id; when truthy, load
        ``flags`` restricted to rows whose ``user_id`` equals this value.
    :param with_tags: when truthy, load the tag values and, for each of
        them, the related tag.
    :return: the statement from the parent ``select()`` with the requested
        loader options applied.
    """
    stmt = super().select()

    if with_messages:
        messages_loader = joinedload(self.Type.messages)
        stmt = stmt.options(messages_loader)

    if with_user_flags:
        # Only the flags belonging to the requested user are loaded.
        flags_of_user = self.Type.flags.and_(UserFlag.user_id == with_user_flags)
        stmt = stmt.options(selectinload(flags_of_user))

    if with_tags:
        # Tag values come from a separate SELECT ... IN query; each value's
        # tag is joined onto that query.
        tags_loader = selectinload(self.Type._tag_values).joinedload(TagValue.tag)
        stmt = stmt.options(tags_loader)

    return stmt
[docs]
def delete_old(self, old_after: timedelta):
    """Build the deletion of old, unflagged transactions.

    :param old_after: age beyond which a transaction counts as old; the
        cutoff is the current UTC time minus this duration.
    :return: the result of ``self.delete()`` filtered on transactions started
        before the cutoff that have no ``UserFlag`` row referencing them
        (presumably a statement left for the caller to execute; confirm).
    """
    cutoff = datetime.now(UTC) - old_after
    is_expired = self.Type.started_at < cutoff
    is_flagged = exists().where(UserFlag.transaction_id == self.Type.id)
    return self.delete().where(is_expired & ~is_flagged)
[docs]
@with_session
async def create(self, values: dict | TransactionModel, /, *, session=None):
    """Store a new transaction and attach its tags, if any.

    :param values: either a mapping of column values, optionally carrying a
        ``tags`` entry, or a ``TransactionModel`` to convert. A mapping passed
        by the caller is left unmodified.
    :param session: forwarded to the parent ``create`` and to ``set_tags``.
    :return: the object returned by the parent ``create``.
    """
    if isinstance(values, TransactionModel):
        # TODO: move this to a to_dict() method? (how to keep the parent's prototype?)
        values = dict(
            id=values.id,
            type=values.type,
            endpoint=values.endpoint,
            started_at=values.started_at,
            x_method=values.extras.get("method"),
            x_no_cache=bool(values.extras.get("no_cache")),
            tags=values.tags,
        )
    else:
        # Shallow copy so the pop() below does not mutate the caller's mapping.
        values = dict(values)

    # Tags are not a column of the transaction row: take them out before the insert.
    tags = values.pop("tags", None)
    transaction = await super().create(values, session=session)

    # Truthiness covers a missing, None or empty tags mapping alike.
    if tags:
        await self.set_tags(transaction, tags, session=session)
    return transaction