384 lines
16 KiB
Python
384 lines
16 KiB
Python
# [DEF:TaskPersistenceModule:Module]
# @SEMANTICS: persistence, sqlite, sqlalchemy, task, storage
# @PURPOSE: Handles the persistence of tasks and their log entries using SQLAlchemy and the tasks.db database.
# @LAYER: Core
# @RELATION: Used by TaskManager to save and load tasks.
# @INVARIANT: Database schema must match the TaskRecord model structure.
|
|
|
|
# [SECTION: IMPORTS]
|
|
from datetime import datetime
|
|
from typing import List, Optional
|
|
import json
|
|
|
|
from sqlalchemy.orm import Session
|
|
from ...models.task import TaskRecord, TaskLogRecord
|
|
from ..database import TasksSessionLocal
|
|
from .models import Task, TaskStatus, LogEntry, TaskLog, LogFilter, LogStats
|
|
from ..logger import logger, belief_scope
|
|
# [/SECTION]
|
|
|
|
# [DEF:TaskPersistenceService:Class]
# @SEMANTICS: persistence, service, database, sqlalchemy
# @PURPOSE: Provides methods to save, load and delete tasks in the tasks.db database using SQLAlchemy.
|
|
class TaskPersistenceService:
    """Save, load and delete ``Task`` objects as ``TaskRecord`` rows.

    The service keeps no state: every public method opens its own session
    from ``TasksSessionLocal`` and closes it before returning. Write methods
    roll back and log on failure instead of raising.
    """

    # [DEF:__init__:Function]
    # @PURPOSE: Initializes the persistence service.
    # @PRE: None.
    # @POST: Service is ready.
    def __init__(self):
        with belief_scope("TaskPersistenceService.__init__"):
            # Nothing to set up: sessions are created per call from
            # TasksSessionLocal (database.py).
            pass
    # [/DEF:__init__:Function]

    # [DEF:_json_safe:Function]
    # @PURPOSE: Recursively replaces datetime values with ISO-8601 strings.
    # @PRE: None.
    # @POST: dicts, lists and tuples are rebuilt; any other value is returned unchanged.
    # @PARAM: obj - Value to sanitize (task params, task result or a log dict).
    # @RETURN: Copy of obj in which every nested datetime is an ISO string.
    @staticmethod
    def _json_safe(obj):
        convert = TaskPersistenceService._json_safe
        if isinstance(obj, dict):
            return {key: convert(value) for key, value in obj.items()}
        # Tuples are emitted as lists, which is how JSON encodes them anyway;
        # walking them lets datetimes nested inside a tuple be converted too.
        if isinstance(obj, (list, tuple)):
            return [convert(item) for item in obj]
        if isinstance(obj, datetime):
            return obj.isoformat()
        return obj
    # [/DEF:_json_safe:Function]

    # [DEF:persist_task:Function]
    # @PURPOSE: Persists or updates a single task in the database.
    # @PRE: isinstance(task, Task)
    # @POST: Task record created or updated in database; on failure the
    #        transaction is rolled back and the error is logged (not raised).
    # @PARAM: task (Task) - The task object to persist.
    # @SIDE_EFFECT: Writes to task_records table in tasks.db
    def persist_task(self, task: Task) -> None:
        with belief_scope("TaskPersistenceService.persist_task", f"task_id={task.id}"):
            session: Session = TasksSessionLocal()
            try:
                # Upsert: reuse the existing row for this id or stage a new one.
                record = session.query(TaskRecord).filter(TaskRecord.id == task.id).first()
                if record is None:
                    record = TaskRecord(id=task.id)
                    session.add(record)

                record.type = task.plugin_id
                record.status = task.status.value
                # Fall back to source_env_id when environment_id is missing or empty.
                record.environment_id = (
                    task.params.get("environment_id") or task.params.get("source_env_id")
                )
                record.started_at = task.started_at
                record.finished_at = task.finished_at

                # Params, result and logs must be JSON serializable, so nested
                # datetimes are turned into ISO strings first.
                record.params = self._json_safe(task.params)
                record.result = self._json_safe(task.result)
                # Sanitize the whole log dict, not only 'timestamp' and
                # 'context': a datetime in any other field would otherwise
                # break serialization and abort the whole persist.
                record.logs = [self._json_safe(log.dict()) for log in task.logs]

                # For failed tasks, store the most recent ERROR log message.
                # NOTE(review): record.error is only ever written here and is
                # not cleared when a previously failed task is persisted with
                # another status - confirm that is intended.
                if task.status == TaskStatus.FAILED:
                    for log in reversed(task.logs):
                        if log.level == "ERROR":
                            record.error = log.message
                            break

                session.commit()
            except Exception as e:
                # Best effort: persistence failures must not propagate to the caller.
                session.rollback()
                logger.error(f"Failed to persist task {task.id}: {e}")
            finally:
                session.close()
    # [/DEF:persist_task:Function]

    # [DEF:persist_tasks:Function]
    # @PURPOSE: Persists multiple tasks.
    # @PRE: isinstance(tasks, list)
    # @POST: persist_task has been called once for every task in the list.
    # @PARAM: tasks (List[Task]) - The list of tasks to persist.
    def persist_tasks(self, tasks: List[Task]) -> None:
        with belief_scope("TaskPersistenceService.persist_tasks"):
            # Each task is written in its own session/transaction, so one
            # failing task does not prevent the others from being saved.
            for task in tasks:
                self.persist_task(task)
    # [/DEF:persist_tasks:Function]

    # [DEF:load_tasks:Function]
    # @PURPOSE: Loads tasks from the database.
    # @PRE: limit is an integer.
    # @POST: Returns list of Task objects, newest created_at first; records
    #        that cannot be reconstructed are logged and skipped.
    # @PARAM: limit (int) - Max tasks to load.
    # @PARAM: status (Optional[TaskStatus]) - Filter by status.
    # @RETURN: List[Task] - The loaded tasks.
    def load_tasks(self, limit: int = 100, status: Optional[TaskStatus] = None) -> List[Task]:
        with belief_scope("TaskPersistenceService.load_tasks"):
            session: Session = TasksSessionLocal()
            try:
                query = session.query(TaskRecord)
                if status:
                    query = query.filter(TaskRecord.status == status.value)
                records = query.order_by(TaskRecord.created_at.desc()).limit(limit).all()

                loaded_tasks = []
                for record in records:
                    try:
                        logs = []
                        for stored in record.logs or []:
                            # Work on a shallow copy so the JSON payload held
                            # by the ORM record is not modified in place.
                            entry = dict(stored)
                            # Timestamps are stored as ISO strings; parse them back.
                            if isinstance(entry.get('timestamp'), str):
                                entry['timestamp'] = datetime.fromisoformat(entry['timestamp'])
                            logs.append(LogEntry(**entry))

                        loaded_tasks.append(Task(
                            id=record.id,
                            plugin_id=record.type,
                            status=TaskStatus(record.status),
                            started_at=record.started_at,
                            finished_at=record.finished_at,
                            params=record.params or {},
                            result=record.result,
                            logs=logs,
                        ))
                    except Exception as e:
                        # One unreadable record must not hide the remaining ones.
                        logger.error(f"Failed to reconstruct task {record.id}: {e}")

                return loaded_tasks
            finally:
                session.close()
    # [/DEF:load_tasks:Function]

    # [DEF:delete_tasks:Function]
    # @PURPOSE: Deletes specific tasks from the database.
    # @PRE: task_ids is a list of strings.
    # @POST: Specified task records deleted from database; on failure the
    #        transaction is rolled back and the error is logged (not raised).
    # @PARAM: task_ids (List[str]) - List of task IDs to delete.
    def delete_tasks(self, task_ids: List[str]) -> None:
        # Nothing to delete: return before opening a session.
        if not task_ids:
            return
        with belief_scope("TaskPersistenceService.delete_tasks"):
            session: Session = TasksSessionLocal()
            try:
                # Bulk DELETE; the session is closed right after, so there is
                # no in-session state worth synchronizing.
                session.query(TaskRecord).filter(
                    TaskRecord.id.in_(task_ids)
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete tasks: {e}")
            finally:
                session.close()
    # [/DEF:delete_tasks:Function]
|
|
|
|
# [/DEF:TaskPersistenceService:Class]
|
|
|
|
# [DEF:TaskLogPersistenceService:Class]
# @SEMANTICS: persistence, service, database, log, sqlalchemy
# @PURPOSE: Provides methods to save, query and delete task logs in the task_logs table.
# @TIER: CRITICAL
# @RELATION: DEPENDS_ON -> TaskLogRecord
# @INVARIANT: Log entries are batch-inserted for performance.
|
|
class TaskLogPersistenceService:
    """
    Service for persisting and querying task logs.

    Supports batch inserts, filtering, and statistics. Every method opens its
    own session from ``TasksSessionLocal`` and closes it before returning.
    Write methods roll back and log on failure instead of raising.
    """

    # [DEF:__init__:Function]
    # @PURPOSE: Initialize the log persistence service.
    # @POST: Service is ready.
    def __init__(self):
        pass
    # [/DEF:__init__:Function]

    # [DEF:_escape_like:Function]
    # @PURPOSE: Escape LIKE wildcards so user text is matched literally.
    # @PRE: text is a string.
    # @POST: Backslash, '%' and '_' are each prefixed with a backslash.
    # @PARAM: text (str) - Raw search text.
    # @RETURN: str - Text safe to embed in a LIKE pattern used with ESCAPE '\'.
    @staticmethod
    def _escape_like(text: str) -> str:
        # The escape character itself must be escaped first.
        return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
    # [/DEF:_escape_like:Function]

    # [DEF:_dump_metadata:Function]
    # @PURPOSE: Serialize log metadata to a JSON string for metadata_json.
    # @PRE: None.
    # @POST: Returns None for empty/missing metadata, otherwise a JSON string.
    # @PARAM: metadata - Metadata attached to a log entry (may be None).
    # @RETURN: Optional[str] - JSON text or None.
    @staticmethod
    def _dump_metadata(metadata) -> Optional[str]:
        if not metadata:
            return None

        def fallback(value):
            # Values json cannot encode natively: ISO format for datetimes,
            # str() for anything else, so that a single odd value does not
            # make the whole batch insert fail.
            if isinstance(value, datetime):
                return value.isoformat()
            return str(value)

        return json.dumps(metadata, default=fallback)
    # [/DEF:_dump_metadata:Function]

    # [DEF:add_logs:Function]
    # @PURPOSE: Batch insert log entries for a task.
    # @PRE: logs is a list of LogEntry objects.
    # @POST: All logs inserted into task_logs table in one transaction; on
    #        failure nothing is inserted and the error is logged (not raised).
    # @PARAM: task_id (str) - The task ID.
    # @PARAM: logs (List[LogEntry]) - Log entries to insert.
    # @SIDE_EFFECT: Writes to task_logs table.
    def add_logs(self, task_id: str, logs: List[LogEntry]) -> None:
        # Nothing to insert: return before opening a session.
        if not logs:
            return
        with belief_scope("TaskLogPersistenceService.add_logs", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                session.add_all([
                    TaskLogRecord(
                        task_id=task_id,
                        timestamp=log.timestamp,
                        level=log.level,
                        # Entries without a source are attributed to "system".
                        source=log.source or "system",
                        message=log.message,
                        metadata_json=self._dump_metadata(log.metadata),
                    )
                    for log in logs
                ])
                # Single commit for the whole batch.
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to add logs for task {task_id}: {e}")
            finally:
                session.close()
    # [/DEF:add_logs:Function]

    # [DEF:get_logs:Function]
    # @PURPOSE: Query logs for a task with filtering and pagination.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns list of TaskLog objects matching filters, oldest first.
    # @PARAM: task_id (str) - The task ID.
    # @PARAM: log_filter (LogFilter) - Filter parameters (level, source,
    #         search, offset, limit).
    # @RETURN: List[TaskLog] - Filtered log entries.
    def get_logs(self, task_id: str, log_filter: LogFilter) -> List[TaskLog]:
        with belief_scope("TaskLogPersistenceService.get_logs", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                query = session.query(TaskLogRecord).filter(TaskLogRecord.task_id == task_id)

                # Apply filters
                if log_filter.level:
                    # The filter value is upper-cased before comparison.
                    query = query.filter(TaskLogRecord.level == log_filter.level.upper())
                if log_filter.source:
                    query = query.filter(TaskLogRecord.source == log_filter.source)
                if log_filter.search:
                    # Case-insensitive substring match. '%' and '_' typed by
                    # the user are escaped so they match literally instead of
                    # acting as LIKE wildcards.
                    search_pattern = f"%{self._escape_like(log_filter.search)}%"
                    query = query.filter(
                        TaskLogRecord.message.ilike(search_pattern, escape="\\")
                    )

                # Order by timestamp ascending (oldest first), then paginate.
                query = query.order_by(TaskLogRecord.timestamp.asc())
                records = query.offset(log_filter.offset).limit(log_filter.limit).all()

                logs = []
                for record in records:
                    metadata = None
                    if record.metadata_json:
                        try:
                            metadata = json.loads(record.metadata_json)
                        except json.JSONDecodeError:
                            # Corrupt metadata is dropped; the log row is still returned.
                            metadata = None

                    logs.append(TaskLog(
                        id=record.id,
                        task_id=record.task_id,
                        timestamp=record.timestamp,
                        level=record.level,
                        source=record.source,
                        message=record.message,
                        metadata=metadata,
                    ))
                return logs
            finally:
                session.close()
    # [/DEF:get_logs:Function]

    # [DEF:get_log_stats:Function]
    # @PURPOSE: Get statistics about logs for a task.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns LogStats with counts by level and source.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: LogStats - Statistics about task logs.
    def get_log_stats(self, task_id: str) -> LogStats:
        with belief_scope("TaskLogPersistenceService.get_log_stats", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                from sqlalchemy import func

                # Total number of log rows for the task.
                total_count = session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id == task_id
                ).count()

                # Row counts grouped by level.
                level_counts = session.query(
                    TaskLogRecord.level,
                    func.count(TaskLogRecord.id)
                ).filter(
                    TaskLogRecord.task_id == task_id
                ).group_by(TaskLogRecord.level).all()
                by_level = {level: count for level, count in level_counts}

                # Row counts grouped by source.
                source_counts = session.query(
                    TaskLogRecord.source,
                    func.count(TaskLogRecord.id)
                ).filter(
                    TaskLogRecord.task_id == task_id
                ).group_by(TaskLogRecord.source).all()
                by_source = {source: count for source, count in source_counts}

                return LogStats(
                    total_count=total_count,
                    by_level=by_level,
                    by_source=by_source
                )
            finally:
                session.close()
    # [/DEF:get_log_stats:Function]

    # [DEF:get_sources:Function]
    # @PURPOSE: Get unique sources for a task's logs.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns list of unique source strings.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: List[str] - Unique source names.
    def get_sources(self, task_id: str) -> List[str]:
        with belief_scope("TaskLogPersistenceService.get_sources", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                from sqlalchemy import distinct

                sources = session.query(distinct(TaskLogRecord.source)).filter(
                    TaskLogRecord.task_id == task_id
                ).all()
                # Each result row holds a single column: the source value.
                return [row[0] for row in sources]
            finally:
                session.close()
    # [/DEF:get_sources:Function]

    # [DEF:delete_logs_for_task:Function]
    # @PURPOSE: Delete all logs for a specific task.
    # @PRE: task_id is a valid task ID.
    # @POST: All logs for the task are deleted; on failure the transaction is
    #        rolled back and the error is logged (not raised).
    # @PARAM: task_id (str) - The task ID.
    # @SIDE_EFFECT: Deletes from task_logs table.
    def delete_logs_for_task(self, task_id: str) -> None:
        with belief_scope("TaskLogPersistenceService.delete_logs_for_task", f"task_id={task_id}"):
            session: Session = TasksSessionLocal()
            try:
                session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id == task_id
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete logs for task {task_id}: {e}")
            finally:
                session.close()
    # [/DEF:delete_logs_for_task:Function]

    # [DEF:delete_logs_for_tasks:Function]
    # @PURPOSE: Delete all logs for multiple tasks.
    # @PRE: task_ids is a list of task IDs.
    # @POST: All logs for the tasks are deleted; on failure the transaction is
    #        rolled back and the error is logged (not raised).
    # @PARAM: task_ids (List[str]) - List of task IDs.
    def delete_logs_for_tasks(self, task_ids: List[str]) -> None:
        # Nothing to delete: return before opening a session.
        if not task_ids:
            return
        with belief_scope("TaskLogPersistenceService.delete_logs_for_tasks"):
            session: Session = TasksSessionLocal()
            try:
                session.query(TaskLogRecord).filter(
                    TaskLogRecord.task_id.in_(task_ids)
                ).delete(synchronize_session=False)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.error(f"Failed to delete logs for tasks: {e}")
            finally:
                session.close()
    # [/DEF:delete_logs_for_tasks:Function]
|
|
|
|
# [/DEF:TaskLogPersistenceService:Class]
|
|
# [/DEF:TaskPersistenceModule:Module]