# [DEF:TaskManagerModule:Module]
# @TIER: CRITICAL
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
# @LAYER: Core
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
# @INVARIANT: Task IDs are unique.
# @CONSTRAINT: Must use belief_scope for logging.

# [SECTION: IMPORTS]
import asyncio
import threading
import inspect
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional

from .models import Task, TaskStatus, LogEntry, LogFilter, LogStats
from .persistence import TaskPersistenceService, TaskLogPersistenceService
from .context import TaskContext
from ..logger import logger, belief_scope, should_log_task_level
# [/SECTION]


# [DEF:TaskManager:Class]
# @SEMANTICS: task, manager, lifecycle, execution, state
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
# @TIER: CRITICAL
# @INVARIANT: Task IDs are unique within the registry.
# @INVARIANT: Each task has exactly one status at any time.
# @INVARIANT: Log entries are never deleted after being added to a task.
class TaskManager:
    """Manages the lifecycle of tasks, including their creation, execution, and state tracking."""

    # Log flush interval in seconds
    LOG_FLUSH_INTERVAL = 2.0

    # [DEF:__init__:Function]
    # @PURPOSE: Initialize the TaskManager with dependencies.
    # @PRE: plugin_loader is initialized.
    # @POST: TaskManager is ready to accept tasks.
    # @PARAM: plugin_loader - The plugin loader instance.
    def __init__(self, plugin_loader):
        with belief_scope("TaskManager.__init__"):
            self.plugin_loader = plugin_loader
            self.tasks: Dict[str, Task] = {}
            self.subscribers: Dict[str, List[asyncio.Queue]] = {}
            self.executor = ThreadPoolExecutor(max_workers=5)  # For CPU-bound plugin execution
            self.persistence_service = TaskPersistenceService()
            self.log_persistence_service = TaskLogPersistenceService()

            # Log buffer: task_id -> List[LogEntry]
            self._log_buffer: Dict[str, List[LogEntry]] = {}
            self._log_buffer_lock = threading.Lock()

            # Flusher thread for batch writing logs
            self._flusher_stop_event = threading.Event()
            self._flusher_thread = threading.Thread(target=self._flusher_loop, daemon=True)
            self._flusher_thread.start()

            try:
                self.loop = asyncio.get_running_loop()
            except RuntimeError:
                self.loop = asyncio.get_event_loop()
            self.task_futures: Dict[str, asyncio.Future] = {}

            # Load persisted tasks on startup
            self.load_persisted_tasks()
    # [/DEF:__init__:Function]

    # [DEF:_flusher_loop:Function]
    # @PURPOSE: Background thread that periodically flushes log buffer to database.
    # @PRE: TaskManager is initialized.
    # @POST: Logs are batch-written to database every LOG_FLUSH_INTERVAL seconds.
    def _flusher_loop(self):
        """Background thread that flushes log buffer to database."""
        with belief_scope("_flusher_loop"):
            while not self._flusher_stop_event.is_set():
                self._flush_logs()
                self._flusher_stop_event.wait(self.LOG_FLUSH_INTERVAL)
    # [/DEF:_flusher_loop:Function]
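    # @EXAMPLE: Graceful shutdown of the flusher (hypothetical sketch; the
    # module defines no shutdown() method, so a caller would drive the
    # existing primitives directly):
    #
    #   manager._flusher_stop_event.set()     # stop the periodic loop
    #   manager._flusher_thread.join()        # wait for the final cycle
    #   manager._flush_logs()                 # drain anything still buffered
    #   manager.executor.shutdown(wait=True)  # release the worker threads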
    # [DEF:_flush_logs:Function]
    # @PURPOSE: Flush all buffered logs to the database.
    # @PRE: None.
    # @POST: All buffered logs are written to task_logs table.
    def _flush_logs(self):
        """Flush all buffered logs to the database."""
        with belief_scope("_flush_logs"):
            with self._log_buffer_lock:
                task_ids = list(self._log_buffer.keys())
            for task_id in task_ids:
                with self._log_buffer_lock:
                    logs = self._log_buffer.pop(task_id, [])
                if logs:
                    try:
                        self.log_persistence_service.add_logs(task_id, logs)
                    except Exception as e:
                        logger.error(f"Failed to flush logs for task {task_id}: {e}")
                        # Re-add logs to buffer on failure
                        with self._log_buffer_lock:
                            if task_id not in self._log_buffer:
                                self._log_buffer[task_id] = []
                            self._log_buffer[task_id].extend(logs)
    # [/DEF:_flush_logs:Function]

    # [DEF:_flush_task_logs:Function]
    # @PURPOSE: Flush logs for a specific task immediately.
    # @PRE: task_id exists.
    # @POST: Task's buffered logs are written to database.
    # @PARAM: task_id (str) - The task ID.
    def _flush_task_logs(self, task_id: str):
        """Flush logs for a specific task immediately."""
        with belief_scope("_flush_task_logs"):
            with self._log_buffer_lock:
                logs = self._log_buffer.pop(task_id, [])
            if logs:
                try:
                    self.log_persistence_service.add_logs(task_id, logs)
                except Exception as e:
                    logger.error(f"Failed to flush logs for task {task_id}: {e}")
    # [/DEF:_flush_task_logs:Function]

    # [DEF:create_task:Function]
    # @PURPOSE: Creates and queues a new task for execution.
    # @PRE: Plugin with plugin_id exists. Params are valid.
    # @POST: Task is created, added to registry, and scheduled for execution.
    # @PARAM: plugin_id (str) - The ID of the plugin to run.
    # @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
    # @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
    # @RETURN: Task - The created task instance.
    # @THROWS: ValueError if plugin not found or params invalid.
    async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
        with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
            if not self.plugin_loader.has_plugin(plugin_id):
                logger.error(f"Plugin with ID '{plugin_id}' not found.")
                raise ValueError(f"Plugin with ID '{plugin_id}' not found.")
            # Resolve the plugin eagerly so instantiation errors surface here
            # rather than inside the background task.
            self.plugin_loader.get_plugin(plugin_id)
            if not isinstance(params, dict):
                logger.error("Task parameters must be a dictionary.")
                raise ValueError("Task parameters must be a dictionary.")

            task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
            self.tasks[task.id] = task
            self.persistence_service.persist_task(task)
            logger.info(f"Task {task.id} created and scheduled for execution")
            self.loop.create_task(self._run_task(task.id))  # Schedule task for execution
            return task
    # [/DEF:create_task:Function]
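    # @EXAMPLE: Creating a task from an async context (hypothetical sketch;
    # "backup_db" and its params are placeholder values):
    #
    #   manager = TaskManager(plugin_loader)
    #   task = await manager.create_task("backup_db", {"target": "prod"}, user_id="alice")
    #   print(task.id, task.status)  # initial status before _run_task flips it to RUNNING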
    # [DEF:_run_task:Function]
    # @PURPOSE: Internal method to execute a task with TaskContext support.
    # @PRE: Task exists in registry.
    # @POST: Task is executed, status updated to SUCCESS or FAILED.
    # @PARAM: task_id (str) - The ID of the task to run.
    async def _run_task(self, task_id: str):
        with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
            task = self.tasks[task_id]
            plugin = self.plugin_loader.get_plugin(task.plugin_id)
            logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")

            task.status = TaskStatus.RUNNING
            task.started_at = datetime.now(timezone.utc)
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'", source="system")

            try:
                # Prepare params and check if the plugin supports the new TaskContext
                params = {**task.params, "_task_id": task_id}

                # Check if the plugin's execute method accepts a 'context' parameter
                sig = inspect.signature(plugin.execute)
                accepts_context = 'context' in sig.parameters

                if accepts_context:
                    # Create TaskContext for new-style plugins
                    context = TaskContext(
                        task_id=task_id,
                        add_log_fn=self._add_log,
                        params=params,
                        default_source="plugin"
                    )
                    if asyncio.iscoroutinefunction(plugin.execute):
                        task.result = await plugin.execute(params, context=context)
                    else:
                        task.result = await self.loop.run_in_executor(
                            self.executor,
                            lambda: plugin.execute(params, context=context)
                        )
                else:
                    # Backward compatibility: old-style plugins without context
                    if asyncio.iscoroutinefunction(plugin.execute):
                        task.result = await plugin.execute(params)
                    else:
                        task.result = await self.loop.run_in_executor(
                            self.executor, plugin.execute, params
                        )

                logger.info(f"Task {task_id} completed successfully")
                task.status = TaskStatus.SUCCESS
                self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'", source="system")
            except Exception as e:
                logger.error(f"Task {task_id} failed: {e}")
                task.status = TaskStatus.FAILED
                self._add_log(task_id, "ERROR", f"Task failed: {e}", source="system",
                              metadata={"error_type": type(e).__name__})
            finally:
                task.finished_at = datetime.now(timezone.utc)
                # Flush any remaining buffered logs before persisting the task
                self._flush_task_logs(task_id)
                self.persistence_service.persist_task(task)
                logger.info(f"Task {task_id} execution finished with status: {task.status}")
    # [/DEF:_run_task:Function]

    # [DEF:resolve_task:Function]
    # @PURPOSE: Resumes a task that is awaiting mapping.
    # @PRE: Task exists and is in AWAITING_MAPPING state.
    # @POST: Task status updated to RUNNING, params updated, execution resumed.
    # @PARAM: task_id (str) - The ID of the task.
    # @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
    # @THROWS: ValueError if task not found or not awaiting mapping.
    async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
        with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task or task.status != TaskStatus.AWAITING_MAPPING:
                raise ValueError(f"Task {task_id} is not awaiting mapping.")

            # Update task params with the resolution
            task.params.update(resolution_params)
            task.status = TaskStatus.RUNNING
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

            # Signal the future to continue
            if task_id in self.task_futures:
                self.task_futures[task_id].set_result(True)
    # [/DEF:resolve_task:Function]
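    # @EXAMPLE: Pause/resume handshake for mapping resolution (hypothetical
    # sketch; plugin-side and API-side steps shown together, mapping values
    # are placeholders):
    #
    #   # inside a plugin that discovers it needs a schema mapping:
    #   await manager.wait_for_resolution(task_id)   # parks the task as AWAITING_MAPPING
    #
    #   # later, from the API layer, once the user supplies the mapping:
    #   await manager.resolve_task(task_id, {"mapping": {"old_col": "new_col"}})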
    # [DEF:wait_for_resolution:Function]
    # @PURPOSE: Pauses execution and waits for a resolution signal.
    # @PRE: Task exists.
    # @POST: Execution pauses until future is set.
    # @PARAM: task_id (str) - The ID of the task.
    async def wait_for_resolution(self, task_id: str):
        with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                return

            task.status = TaskStatus.AWAITING_MAPPING
            self.persistence_service.persist_task(task)
            self.task_futures[task_id] = self.loop.create_future()
            try:
                await self.task_futures[task_id]
            finally:
                if task_id in self.task_futures:
                    del self.task_futures[task_id]
    # [/DEF:wait_for_resolution:Function]

    # [DEF:wait_for_input:Function]
    # @PURPOSE: Pauses execution and waits for user input.
    # @PRE: Task exists.
    # @POST: Execution pauses until future is set via resume_task_with_password.
    # @PARAM: task_id (str) - The ID of the task.
    async def wait_for_input(self, task_id: str):
        with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                return

            # Status is already set to AWAITING_INPUT by await_input()
            self.task_futures[task_id] = self.loop.create_future()
            try:
                await self.task_futures[task_id]
            finally:
                if task_id in self.task_futures:
                    del self.task_futures[task_id]
    # [/DEF:wait_for_input:Function]

    # [DEF:get_task:Function]
    # @PURPOSE: Retrieves a task by its ID.
    # @PRE: task_id is a string.
    # @POST: Returns Task object or None.
    # @PARAM: task_id (str) - ID of the task.
    # @RETURN: Optional[Task] - The task or None.
    def get_task(self, task_id: str) -> Optional[Task]:
        with belief_scope("TaskManager.get_task", f"task_id={task_id}"):
            return self.tasks.get(task_id)
    # [/DEF:get_task:Function]

    # [DEF:get_all_tasks:Function]
    # @PURPOSE: Retrieves all registered tasks.
    # @PRE: None.
    # @POST: Returns list of all Task objects.
    # @RETURN: List[Task] - All tasks.
    def get_all_tasks(self) -> List[Task]:
        with belief_scope("TaskManager.get_all_tasks"):
            return list(self.tasks.values())
    # [/DEF:get_all_tasks:Function]

    # [DEF:get_tasks:Function]
    # @PURPOSE: Retrieves tasks with pagination and optional filters.
    # @PRE: limit and offset are non-negative integers.
    # @POST: Returns a list of tasks sorted by started_at descending.
    # @PARAM: limit (int) - Maximum number of tasks to return.
    # @PARAM: offset (int) - Number of tasks to skip.
    # @PARAM: status (Optional[TaskStatus]) - Filter by task status.
    # @PARAM: plugin_ids (Optional[List[str]]) - Filter by plugin IDs.
    # @PARAM: completed_only (bool) - Only return SUCCESS or FAILED tasks.
    # @RETURN: List[Task] - List of tasks matching criteria.
    def get_tasks(
        self,
        limit: int = 10,
        offset: int = 0,
        status: Optional[TaskStatus] = None,
        plugin_ids: Optional[List[str]] = None,
        completed_only: bool = False
    ) -> List[Task]:
        with belief_scope("TaskManager.get_tasks"):
            tasks = list(self.tasks.values())

            if status:
                tasks = [t for t in tasks if t.status == status]
            if plugin_ids:
                plugin_id_set = set(plugin_ids)
                tasks = [t for t in tasks if t.plugin_id in plugin_id_set]
            if completed_only:
                tasks = [t for t in tasks if t.status in [TaskStatus.SUCCESS, TaskStatus.FAILED]]

            # Sort by started_at descending with tolerant handling of mixed
            # tz-aware/naive values.
            def sort_key(task: Task) -> float:
                started_at = task.started_at
                if started_at is None:
                    return float("-inf")
                if not isinstance(started_at, datetime):
                    return float("-inf")
                if started_at.tzinfo is None:
                    return started_at.replace(tzinfo=timezone.utc).timestamp()
                return started_at.timestamp()

            tasks.sort(key=sort_key, reverse=True)
            return tasks[offset:offset + limit]
    # [/DEF:get_tasks:Function]
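    # @EXAMPLE: Paginated queries (hypothetical values; the plugin ID is a
    # placeholder):
    #
    #   failed = manager.get_tasks(limit=20, offset=0, status=TaskStatus.FAILED)
    #   recent = manager.get_tasks(limit=10, plugin_ids=["backup_db"], completed_only=True)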
    # [DEF:get_task_logs:Function]
    # @PURPOSE: Retrieves logs for a specific task (from memory for running, persistence for completed).
    # @PRE: task_id is a string.
    # @POST: Returns list of LogEntry or TaskLog objects.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: log_filter (Optional[LogFilter]) - Filter parameters.
    # @RETURN: List[LogEntry] - List of log entries.
    def get_task_logs(self, task_id: str, log_filter: Optional[LogFilter] = None) -> List[LogEntry]:
        with belief_scope("TaskManager.get_task_logs", f"task_id={task_id}"):
            task = self.tasks.get(task_id)

            # For completed tasks, fetch from persistence
            if task and task.status in [TaskStatus.SUCCESS, TaskStatus.FAILED]:
                if log_filter is None:
                    log_filter = LogFilter()
                task_logs = self.log_persistence_service.get_logs(task_id, log_filter)
                # Convert TaskLog to LogEntry for backward compatibility
                return [
                    LogEntry(
                        timestamp=log.timestamp,
                        level=log.level,
                        message=log.message,
                        source=log.source,
                        metadata=log.metadata
                    )
                    for log in task_logs
                ]

            # For running/pending tasks, return from memory
            return task.logs if task else []
    # [/DEF:get_task_logs:Function]

    # [DEF:get_task_log_stats:Function]
    # @PURPOSE: Get statistics about logs for a task.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns LogStats with counts by level and source.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: LogStats - Statistics about task logs.
    def get_task_log_stats(self, task_id: str) -> LogStats:
        with belief_scope("TaskManager.get_task_log_stats", f"task_id={task_id}"):
            return self.log_persistence_service.get_log_stats(task_id)
    # [/DEF:get_task_log_stats:Function]

    # [DEF:get_task_log_sources:Function]
    # @PURPOSE: Get unique sources for a task's logs.
    # @PRE: task_id is a valid task ID.
    # @POST: Returns list of unique source strings.
    # @PARAM: task_id (str) - The task ID.
    # @RETURN: List[str] - Unique source names.
    def get_task_log_sources(self, task_id: str) -> List[str]:
        with belief_scope("TaskManager.get_task_log_sources", f"task_id={task_id}"):
            return self.log_persistence_service.get_sources(task_id)
    # [/DEF:get_task_log_sources:Function]

    # [DEF:_add_log:Function]
    # @PURPOSE: Adds a log entry to a task buffer and notifies subscribers.
    # @PRE: Task exists.
    # @POST: Log added to buffer and pushed to queues (if level meets task_log_level filter).
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: level (str) - Log level.
    # @PARAM: message (str) - Log message.
    # @PARAM: source (str) - Source component (default: "system").
    # @PARAM: metadata (Optional[Dict]) - Additional structured data.
    # @PARAM: context (Optional[Dict]) - Legacy context (for backward compatibility).
    def _add_log(
        self,
        task_id: str,
        level: str,
        message: str,
        source: str = "system",
        metadata: Optional[Dict[str, Any]] = None,
        context: Optional[Dict[str, Any]] = None
    ):
        with belief_scope("TaskManager._add_log", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                return

            # Filter logs based on task_log_level configuration
            if not should_log_task_level(level):
                return

            # Create log entry with new fields
            log_entry = LogEntry(
                level=level,
                message=message,
                source=source,
                metadata=metadata,
                context=context  # Keep for backward compatibility
            )

            # Add to in-memory logs (for backward compatibility with legacy JSON field)
            task.logs.append(log_entry)

            # Add to buffer for batch persistence
            with self._log_buffer_lock:
                if task_id not in self._log_buffer:
                    self._log_buffer[task_id] = []
                self._log_buffer[task_id].append(log_entry)

            # Notify subscribers (for real-time WebSocket updates)
            if task_id in self.subscribers:
                for queue in self.subscribers[task_id]:
                    self.loop.call_soon_threadsafe(queue.put_nowait, log_entry)
    # [/DEF:_add_log:Function]
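    # @EXAMPLE: Inspecting a task's logs after completion (a sketch using only
    # the accessors defined above; LogFilter is left at its defaults since its
    # fields are defined in .models):
    #
    #   logs = manager.get_task_logs(task.id)           # List[LogEntry]
    #   stats = manager.get_task_log_stats(task.id)     # counts by level/source
    #   sources = manager.get_task_log_sources(task.id) # e.g. ["system", "plugin"]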
    # [DEF:subscribe_logs:Function]
    # @PURPOSE: Subscribes to real-time logs for a task.
    # @PRE: task_id is a string.
    # @POST: Returns an asyncio.Queue for log entries.
    # @PARAM: task_id (str) - ID of the task.
    # @RETURN: asyncio.Queue - Queue for log entries.
    async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
        with belief_scope("TaskManager.subscribe_logs", f"task_id={task_id}"):
            queue = asyncio.Queue()
            if task_id not in self.subscribers:
                self.subscribers[task_id] = []
            self.subscribers[task_id].append(queue)
            return queue
    # [/DEF:subscribe_logs:Function]

    # [DEF:unsubscribe_logs:Function]
    # @PURPOSE: Unsubscribes from real-time logs for a task.
    # @PRE: task_id is a string, queue is asyncio.Queue.
    # @POST: Queue removed from subscribers.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: queue (asyncio.Queue) - Queue to remove.
    def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
        with belief_scope("TaskManager.unsubscribe_logs", f"task_id={task_id}"):
            if task_id in self.subscribers:
                if queue in self.subscribers[task_id]:
                    self.subscribers[task_id].remove(queue)
                if not self.subscribers[task_id]:
                    del self.subscribers[task_id]
    # [/DEF:unsubscribe_logs:Function]
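    # @EXAMPLE: Streaming logs to a client (hypothetical WebSocket-style
    # consumer; `send_to_client` is a placeholder):
    #
    #   queue = await manager.subscribe_logs(task.id)
    #   try:
    #       while True:
    #           entry = await queue.get()
    #           await send_to_client(entry.message)
    #   finally:
    #       manager.unsubscribe_logs(task.id, queue)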
    # [DEF:load_persisted_tasks:Function]
    # @PURPOSE: Load persisted tasks using persistence service.
    # @PRE: None.
    # @POST: Persisted tasks loaded into self.tasks.
    def load_persisted_tasks(self) -> None:
        with belief_scope("TaskManager.load_persisted_tasks"):
            loaded_tasks = self.persistence_service.load_tasks(limit=100)
            for task in loaded_tasks:
                if task.id not in self.tasks:
                    self.tasks[task.id] = task
    # [/DEF:load_persisted_tasks:Function]

    # [DEF:await_input:Function]
    # @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
    # @PRE: Task exists and is in RUNNING state.
    # @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: input_request (Dict) - Details about required input.
    # @THROWS: ValueError if task not found or not RUNNING.
    def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
        with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            if task.status != TaskStatus.RUNNING:
                raise ValueError(f"Task {task_id} is not RUNNING (current: {task.status})")

            task.status = TaskStatus.AWAITING_INPUT
            task.input_required = True
            task.input_request = input_request
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", "Task paused for user input",
                          metadata={"input_request": input_request})
    # [/DEF:await_input:Function]

    # [DEF:resume_task_with_password:Function]
    # @PURPOSE: Resume a task that is awaiting input with provided passwords.
    # @PRE: Task exists and is in AWAITING_INPUT state.
    # @POST: Task status changed to RUNNING, passwords injected, task resumed.
    # @PARAM: task_id (str) - ID of the task.
    # @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
    # @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
    def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
        with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
            task = self.tasks.get(task_id)
            if not task:
                raise ValueError(f"Task {task_id} not found")
            if task.status != TaskStatus.AWAITING_INPUT:
                raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {task.status})")
            if not isinstance(passwords, dict) or not passwords:
                raise ValueError("Passwords must be a non-empty dictionary")

            task.params["passwords"] = passwords
            task.input_required = False
            task.input_request = None
            task.status = TaskStatus.RUNNING
            self.persistence_service.persist_task(task)
            self._add_log(task_id, "INFO", "Task resumed with passwords",
                          metadata={"databases": list(passwords.keys())})

            if task_id in self.task_futures:
                self.task_futures[task_id].set_result(True)
    # [/DEF:resume_task_with_password:Function]

    # [DEF:clear_tasks:Function]
    # @PURPOSE: Clears tasks based on status filter (also deletes associated logs).
    # @PRE: status is Optional[TaskStatus].
    # @POST: Tasks matching filter (or all non-active) cleared from registry and database.
    # @PARAM: status (Optional[TaskStatus]) - Filter by task status.
    # @RETURN: int - Number of tasks cleared.
    def clear_tasks(self, status: Optional[TaskStatus] = None) -> int:
        with belief_scope("TaskManager.clear_tasks"):
            tasks_to_remove = []
            for task_id, task in list(self.tasks.items()):
                # With an explicit status, remove only matching tasks. Without
                # one, remove everything except active tasks: RUNNING plus the
                # paused-but-live AWAITING_INPUT and AWAITING_MAPPING states.
                should_remove = False
                if status:
                    if task.status == status:
                        should_remove = True
                else:
                    if task.status not in [TaskStatus.RUNNING, TaskStatus.AWAITING_INPUT, TaskStatus.AWAITING_MAPPING]:
                        should_remove = True
                if should_remove:
                    tasks_to_remove.append(task_id)

            for tid in tasks_to_remove:
                # Cancel pending future if one exists (e.g. AWAITING_INPUT/MAPPING)
                if tid in self.task_futures:
                    self.task_futures[tid].cancel()
                    del self.task_futures[tid]
                del self.tasks[tid]

            # Remove from persistence (task_records and task_logs via CASCADE)
            self.persistence_service.delete_tasks(tasks_to_remove)
            # Also explicitly delete logs (in case CASCADE is not set up)
            if tasks_to_remove:
                self.log_persistence_service.delete_logs_for_tasks(tasks_to_remove)

            logger.info(f"Cleared {len(tasks_to_remove)} tasks.")
            return len(tasks_to_remove)
    # [/DEF:clear_tasks:Function]

# [/DEF:TaskManager:Class]
# [/DEF:TaskManagerModule:Module]
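# @EXAMPLE: End-to-end password-input flow (hypothetical sketch stitching the
# methods above together; plugin-side and API-side steps shown in sequence,
# with placeholder values):
#
#   1. A RUNNING plugin discovers it needs credentials:
#        manager.await_input(task_id, {"type": "password", "databases": ["prod"]})
#        await manager.wait_for_input(task_id)   # parks until the future resolves
#   2. The API layer collects the password from the user and resumes:
#        manager.resume_task_with_password(task_id, {"prod": "s3cret"})
#   3. The plugin continues with task.params["passwords"] available.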