598 lines
28 KiB
Python
598 lines
28 KiB
Python
# [DEF:TaskManagerModule:Module]
|
|
# @TIER: CRITICAL
|
|
# @SEMANTICS: task, manager, lifecycle, execution, state
|
|
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking. It uses a thread pool to run plugins asynchronously.
|
|
# @LAYER: Core
|
|
# @RELATION: Depends on PluginLoader to get plugin instances. It is used by the API layer to create and query tasks.
|
|
# @INVARIANT: Task IDs are unique.
|
|
# @CONSTRAINT: Must use belief_scope for logging.
|
|
|
|
# [SECTION: IMPORTS]
|
|
import asyncio
|
|
import threading
|
|
import inspect
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from datetime import datetime, timezone
|
|
from typing import Dict, Any, List, Optional
|
|
|
|
from .models import Task, TaskStatus, LogEntry, LogFilter, LogStats
|
|
from .persistence import TaskPersistenceService, TaskLogPersistenceService
|
|
from .context import TaskContext
|
|
from ..logger import logger, belief_scope, should_log_task_level
|
|
# [/SECTION]
|
|
|
|
# [DEF:TaskManager:Class]
|
|
# @SEMANTICS: task, manager, lifecycle, execution, state
|
|
# @PURPOSE: Manages the lifecycle of tasks, including their creation, execution, and state tracking.
|
|
# @TIER: CRITICAL
|
|
# @INVARIANT: Task IDs are unique within the registry.
|
|
# @INVARIANT: Each task has exactly one status at any time.
|
|
# @INVARIANT: Log entries are never deleted after being added to a task.
|
|
class TaskManager:
|
|
"""
|
|
Manages the lifecycle of tasks, including their creation, execution, and state tracking.
|
|
"""
|
|
|
|
# Log flush interval in seconds
|
|
LOG_FLUSH_INTERVAL = 2.0
|
|
|
|
# [DEF:__init__:Function]
|
|
# @PURPOSE: Initialize the TaskManager with dependencies.
|
|
# @PRE: plugin_loader is initialized.
|
|
# @POST: TaskManager is ready to accept tasks.
|
|
# @PARAM: plugin_loader - The plugin loader instance.
|
|
def __init__(self, plugin_loader):
    """Initialize the TaskManager and start the background log flusher.

    :param plugin_loader: Loader used to resolve plugin instances by ID.
    """
    with belief_scope("TaskManager.__init__"):
        self.plugin_loader = plugin_loader
        self.tasks: Dict[str, Task] = {}
        self.subscribers: Dict[str, List[asyncio.Queue]] = {}
        # Thread pool for synchronous (CPU-bound) plugin execution.
        self.executor = ThreadPoolExecutor(max_workers=5)
        self.persistence_service = TaskPersistenceService()
        self.log_persistence_service = TaskLogPersistenceService()

        # Per-task log buffer (task_id -> entries), guarded by a lock since
        # logs may arrive from executor threads.
        self._log_buffer: Dict[str, List[LogEntry]] = {}
        self._log_buffer_lock = threading.Lock()

        # Daemon thread that batch-writes buffered logs to the database.
        self._flusher_stop_event = threading.Event()
        self._flusher_thread = threading.Thread(target=self._flusher_loop, daemon=True)
        self._flusher_thread.start()

        # Prefer the currently running loop; fall back to the default loop
        # when constructed outside async context.
        try:
            self.loop = asyncio.get_running_loop()
        except RuntimeError:
            self.loop = asyncio.get_event_loop()
        self.task_futures: Dict[str, asyncio.Future] = {}

        # Restore previously persisted tasks into the in-memory registry.
        self.load_persisted_tasks()
|
|
# [/DEF:__init__:Function]
|
|
|
|
# [DEF:_flusher_loop:Function]
|
|
# @PURPOSE: Background thread that periodically flushes log buffer to database.
|
|
# @PRE: TaskManager is initialized.
|
|
# @POST: Logs are batch-written to database every LOG_FLUSH_INTERVAL seconds.
|
|
def _flusher_loop(self):
    """Periodically flush buffered logs until the stop event is set."""
    with belief_scope("_flusher_loop"):
        stop = self._flusher_stop_event
        while True:
            if stop.is_set():
                break
            self._flush_logs()
            # wait() doubles as an interruptible sleep.
            stop.wait(self.LOG_FLUSH_INTERVAL)
|
|
# [/DEF:_flusher_loop:Function]
|
|
|
|
# [DEF:_flush_logs:Function]
|
|
# @PURPOSE: Flush all buffered logs to the database.
|
|
# @PRE: None.
|
|
# @POST: All buffered logs are written to task_logs table.
|
|
def _flush_logs(self):
    """Drain every task's buffered logs and batch-write them to the database.

    On a write failure the drained entries are returned to the buffer so
    they can be retried on the next flush cycle.
    """
    with belief_scope("_flush_logs"):
        # Snapshot the task IDs first; entries are popped one task at a
        # time so the lock is never held across a database write.
        with self._log_buffer_lock:
            pending_ids = list(self._log_buffer.keys())

        for tid in pending_ids:
            with self._log_buffer_lock:
                entries = self._log_buffer.pop(tid, [])

            if not entries:
                continue
            try:
                self.log_persistence_service.add_logs(tid, entries)
            except Exception as e:
                logger.error(f"Failed to flush logs for task {tid}: {e}")
                # Put the entries back so a later flush retries them.
                with self._log_buffer_lock:
                    self._log_buffer.setdefault(tid, []).extend(entries)
|
|
# [/DEF:_flush_logs:Function]
|
|
|
|
# [DEF:_flush_task_logs:Function]
|
|
# @PURPOSE: Flush logs for a specific task immediately.
|
|
# @PRE: task_id exists.
|
|
# @POST: Task's buffered logs are written to database.
|
|
# @PARAM: task_id (str) - The task ID.
|
|
def _flush_task_logs(self, task_id: str):
    """Synchronously write one task's buffered logs to the database.

    :param task_id: The task whose buffer should be drained now.
    """
    with belief_scope("_flush_task_logs"):
        with self._log_buffer_lock:
            pending = self._log_buffer.pop(task_id, [])

        if not pending:
            return
        try:
            self.log_persistence_service.add_logs(task_id, pending)
        except Exception as e:
            # Best-effort: log and drop on failure (no retry here).
            logger.error(f"Failed to flush logs for task {task_id}: {e}")
|
|
# [/DEF:_flush_task_logs:Function]
|
|
|
|
# [DEF:create_task:Function]
|
|
# @PURPOSE: Creates and queues a new task for execution.
|
|
# @PRE: Plugin with plugin_id exists. Params are valid.
|
|
# @POST: Task is created, added to registry, and scheduled for execution.
|
|
# @PARAM: plugin_id (str) - The ID of the plugin to run.
|
|
# @PARAM: params (Dict[str, Any]) - Parameters for the plugin.
|
|
# @PARAM: user_id (Optional[str]) - ID of the user requesting the task.
|
|
# @RETURN: Task - The created task instance.
|
|
# @THROWS: ValueError if plugin not found or params invalid.
|
|
async def create_task(self, plugin_id: str, params: Dict[str, Any], user_id: Optional[str] = None) -> Task:
    """Validate inputs, register a new Task, and schedule it for execution.

    :param plugin_id: ID of the plugin to run.
    :param params: Plugin parameters (must be a dict).
    :param user_id: Optional ID of the requesting user.
    :return: The newly created Task.
    :raises ValueError: if the plugin is unknown or params is not a dict.
    """
    with belief_scope("TaskManager.create_task", f"plugin_id={plugin_id}"):
        if not self.plugin_loader.has_plugin(plugin_id):
            logger.error(f"Plugin with ID '{plugin_id}' not found.")
            raise ValueError(f"Plugin with ID '{plugin_id}' not found.")

        # NOTE(review): result intentionally discarded — presumably forces
        # the plugin to load eagerly so failures surface here; confirm.
        self.plugin_loader.get_plugin(plugin_id)

        if not isinstance(params, dict):
            logger.error("Task parameters must be a dictionary.")
            raise ValueError("Task parameters must be a dictionary.")

        task = Task(plugin_id=plugin_id, params=params, user_id=user_id)
        self.tasks[task.id] = task
        self.persistence_service.persist_task(task)
        logger.info(f"Task {task.id} created and scheduled for execution")
        # Fire-and-forget execution on the event loop.
        self.loop.create_task(self._run_task(task.id))
        return task
|
|
# [/DEF:create_task:Function]
|
|
|
|
# [DEF:_run_task:Function]
|
|
# @PURPOSE: Internal method to execute a task with TaskContext support.
|
|
# @PRE: Task exists in registry.
|
|
# @POST: Task is executed, status updated to SUCCESS or FAILED.
|
|
# @PARAM: task_id (str) - The ID of the task to run.
|
|
async def _run_task(self, task_id: str):
    """Execute a task's plugin and track status, timestamps, and logs.

    Supports both context-aware plugins (``execute(params, context=...)``)
    and legacy plugins (``execute(params)``). Coroutine plugins are awaited
    directly; synchronous plugins run in the thread pool so they do not
    block the event loop.

    :param task_id: ID of the task to run (must exist in self.tasks).
    """
    with belief_scope("TaskManager._run_task", f"task_id={task_id}"):
        task = self.tasks[task_id]
        plugin = self.plugin_loader.get_plugin(task.plugin_id)

        logger.info(f"Starting execution of task {task_id} for plugin '{plugin.name}'")
        task.status = TaskStatus.RUNNING
        # Fix: use a timezone-aware UTC timestamp. datetime.utcnow() is
        # deprecated and returns a naive value — the reason get_tasks()
        # needs tolerant mixed naive/aware sorting.
        task.started_at = datetime.now(timezone.utc)
        self.persistence_service.persist_task(task)
        self._add_log(task_id, "INFO", f"Task started for plugin '{plugin.name}'", source="system")

        try:
            # Inject the task ID so plugins can reference their own task.
            params = {**task.params, "_task_id": task_id}

            # New-style plugins declare a 'context' parameter on execute().
            sig = inspect.signature(plugin.execute)
            accepts_context = 'context' in sig.parameters

            if accepts_context:
                # Create TaskContext for new-style plugins.
                context = TaskContext(
                    task_id=task_id,
                    add_log_fn=self._add_log,
                    params=params,
                    default_source="plugin"
                )

                if asyncio.iscoroutinefunction(plugin.execute):
                    task.result = await plugin.execute(params, context=context)
                else:
                    # Run blocking plugin code off the event loop.
                    task.result = await self.loop.run_in_executor(
                        self.executor,
                        lambda: plugin.execute(params, context=context)
                    )
            else:
                # Backward compatibility: old-style plugins without context.
                if asyncio.iscoroutinefunction(plugin.execute):
                    task.result = await plugin.execute(params)
                else:
                    task.result = await self.loop.run_in_executor(
                        self.executor,
                        plugin.execute,
                        params
                    )

            logger.info(f"Task {task_id} completed successfully")
            task.status = TaskStatus.SUCCESS
            self._add_log(task_id, "INFO", f"Task completed successfully for plugin '{plugin.name}'", source="system")
        except Exception as e:
            logger.error(f"Task {task_id} failed: {e}")
            task.status = TaskStatus.FAILED
            self._add_log(task_id, "ERROR", f"Task failed: {e}", source="system", metadata={"error_type": type(e).__name__})
        finally:
            task.finished_at = datetime.now(timezone.utc)
            # Flush remaining buffered logs before the final persist so the
            # stored record reflects the complete log history.
            self._flush_task_logs(task_id)
            self.persistence_service.persist_task(task)
            logger.info(f"Task {task_id} execution finished with status: {task.status}")
|
|
# [/DEF:_run_task:Function]
|
|
|
|
# [DEF:resolve_task:Function]
|
|
# @PURPOSE: Resumes a task that is awaiting mapping.
|
|
# @PRE: Task exists and is in AWAITING_MAPPING state.
|
|
# @POST: Task status updated to RUNNING, params updated, execution resumed.
|
|
# @PARAM: task_id (str) - The ID of the task.
|
|
# @PARAM: resolution_params (Dict[str, Any]) - Params to resolve the wait.
|
|
# @THROWS: ValueError if task not found or not awaiting mapping.
|
|
async def resolve_task(self, task_id: str, resolution_params: Dict[str, Any]):
    """Resume a task stuck in AWAITING_MAPPING with the supplied params.

    :param task_id: ID of the paused task.
    :param resolution_params: Values merged into the task's params.
    :raises ValueError: if the task is unknown or not awaiting mapping.
    """
    with belief_scope("TaskManager.resolve_task", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if task is None or task.status != TaskStatus.AWAITING_MAPPING:
            raise ValueError("Task is not awaiting mapping.")

        # Merge the resolution into the task's parameters and resume.
        task.params.update(resolution_params)
        task.status = TaskStatus.RUNNING
        self.persistence_service.persist_task(task)
        self._add_log(task_id, "INFO", "Task resumed after mapping resolution.")

        # Wake the coroutine blocked in wait_for_resolution(), if any.
        future = self.task_futures.get(task_id)
        if future is not None:
            future.set_result(True)
|
|
# [/DEF:resolve_task:Function]
|
|
|
|
# [DEF:wait_for_resolution:Function]
|
|
# @PURPOSE: Pauses execution and waits for a resolution signal.
|
|
# @PRE: Task exists.
|
|
# @POST: Execution pauses until future is set.
|
|
# @PARAM: task_id (str) - The ID of the task.
|
|
async def wait_for_resolution(self, task_id: str):
    """Block the calling coroutine until resolve_task() signals this task.

    Marks the task AWAITING_MAPPING and persists it before waiting. A
    no-op when the task is unknown.
    """
    with belief_scope("TaskManager.wait_for_resolution", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if task is None:
            return

        task.status = TaskStatus.AWAITING_MAPPING
        self.persistence_service.persist_task(task)
        future = self.loop.create_future()
        self.task_futures[task_id] = future

        try:
            await future
        finally:
            # Always drop the future so stale entries cannot accumulate.
            self.task_futures.pop(task_id, None)
|
|
# [/DEF:wait_for_resolution:Function]
|
|
|
|
# [DEF:wait_for_input:Function]
|
|
# @PURPOSE: Pauses execution and waits for user input.
|
|
# @PRE: Task exists.
|
|
# @POST: Execution pauses until future is set via resume_task_with_password.
|
|
# @PARAM: task_id (str) - The ID of the task.
|
|
async def wait_for_input(self, task_id: str):
    """Block until resume_task_with_password() signals this task.

    The AWAITING_INPUT status is expected to have been set already by
    await_input(). A no-op when the task is unknown.
    """
    with belief_scope("TaskManager.wait_for_input", f"task_id={task_id}"):
        if self.tasks.get(task_id) is None:
            return

        future = self.loop.create_future()
        self.task_futures[task_id] = future

        try:
            await future
        finally:
            # Always drop the future so stale entries cannot accumulate.
            self.task_futures.pop(task_id, None)
|
|
# [/DEF:wait_for_input:Function]
|
|
|
|
# [DEF:get_task:Function]
|
|
# @PURPOSE: Retrieves a task by its ID.
|
|
# @PRE: task_id is a string.
|
|
# @POST: Returns Task object or None.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @RETURN: Optional[Task] - The task or None.
|
|
def get_task(self, task_id: str) -> Optional[Task]:
    """Return the task with the given ID, or None when unknown."""
    with belief_scope("TaskManager.get_task", f"task_id={task_id}"):
        return self.tasks.get(task_id)
|
|
# [/DEF:get_task:Function]
|
|
|
|
# [DEF:get_all_tasks:Function]
|
|
# @PURPOSE: Retrieves all registered tasks.
|
|
# @PRE: None.
|
|
# @POST: Returns list of all Task objects.
|
|
# @RETURN: List[Task] - All tasks.
|
|
def get_all_tasks(self) -> List[Task]:
    """Return every registered task as a list."""
    with belief_scope("TaskManager.get_all_tasks"):
        return [*self.tasks.values()]
|
|
# [/DEF:get_all_tasks:Function]
|
|
|
|
# [DEF:get_tasks:Function]
|
|
# @PURPOSE: Retrieves tasks with pagination and optional status filter.
|
|
# @PRE: limit and offset are non-negative integers.
|
|
# @POST: Returns a list of tasks sorted by start_time descending.
|
|
# @PARAM: limit (int) - Maximum number of tasks to return.
|
|
# @PARAM: offset (int) - Number of tasks to skip.
|
|
# @PARAM: status (Optional[TaskStatus]) - Filter by task status.
|
|
# @RETURN: List[Task] - List of tasks matching criteria.
|
|
def get_tasks(
    self,
    limit: int = 10,
    offset: int = 0,
    status: Optional[TaskStatus] = None,
    plugin_ids: Optional[List[str]] = None,
    completed_only: bool = False
) -> List[Task]:
    """Return a page of tasks, newest-first by start time.

    :param limit: Maximum number of tasks returned.
    :param offset: Number of tasks skipped after sorting.
    :param status: Keep only tasks with this status, when given.
    :param plugin_ids: Keep only tasks for these plugins, when given.
    :param completed_only: Keep only SUCCESS/FAILED tasks, when True.
    :return: The selected, sorted, paginated tasks.
    """
    with belief_scope("TaskManager.get_tasks"):
        selected = list(self.tasks.values())
        if status:
            selected = [t for t in selected if t.status == status]
        if plugin_ids:
            wanted = set(plugin_ids)
            selected = [t for t in selected if t.plugin_id in wanted]
        if completed_only:
            terminal = (TaskStatus.SUCCESS, TaskStatus.FAILED)
            selected = [t for t in selected if t.status in terminal]

        def sort_key(task: Task) -> float:
            # Tolerate missing or non-datetime started_at values (sorted
            # last); naive timestamps are interpreted as UTC.
            ts = task.started_at
            if not isinstance(ts, datetime):
                return float("-inf")
            if ts.tzinfo is None:
                ts = ts.replace(tzinfo=timezone.utc)
            return ts.timestamp()

        selected.sort(key=sort_key, reverse=True)
        return selected[offset:offset + limit]
|
|
# [/DEF:get_tasks:Function]
|
|
|
|
# [DEF:get_task_logs:Function]
|
|
# @PURPOSE: Retrieves logs for a specific task (from memory for running, persistence for completed).
|
|
# @PRE: task_id is a string.
|
|
# @POST: Returns list of LogEntry or TaskLog objects.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @PARAM: log_filter (Optional[LogFilter]) - Filter parameters.
|
|
# @RETURN: List[LogEntry] - List of log entries.
|
|
def get_task_logs(self, task_id: str, log_filter: Optional[LogFilter] = None) -> List[LogEntry]:
    """Return logs for a task.

    Completed tasks (SUCCESS/FAILED) are read from the log persistence
    layer; running or pending tasks are served from in-memory logs. An
    unknown task yields an empty list.

    :param task_id: ID of the task.
    :param log_filter: Optional filter applied to persisted logs.
    """
    with belief_scope("TaskManager.get_task_logs", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if task is None:
            return []

        if task.status in (TaskStatus.SUCCESS, TaskStatus.FAILED):
            stored = self.log_persistence_service.get_logs(
                task_id,
                log_filter if log_filter is not None else LogFilter()
            )
            # Convert persisted TaskLog rows back into LogEntry objects
            # for backward compatibility with callers.
            return [
                LogEntry(
                    timestamp=row.timestamp,
                    level=row.level,
                    message=row.message,
                    source=row.source,
                    metadata=row.metadata
                )
                for row in stored
            ]

        # Running/pending tasks: serve the in-memory log list.
        return task.logs
|
|
# [/DEF:get_task_logs:Function]
|
|
|
|
# [DEF:get_task_log_stats:Function]
|
|
# @PURPOSE: Get statistics about logs for a task.
|
|
# @PRE: task_id is a valid task ID.
|
|
# @POST: Returns LogStats with counts by level and source.
|
|
# @PARAM: task_id (str) - The task ID.
|
|
# @RETURN: LogStats - Statistics about task logs.
|
|
def get_task_log_stats(self, task_id: str) -> LogStats:
    """Return per-level/per-source log statistics for a task."""
    with belief_scope("TaskManager.get_task_log_stats", f"task_id={task_id}"):
        return self.log_persistence_service.get_log_stats(task_id)
|
|
# [/DEF:get_task_log_stats:Function]
|
|
|
|
# [DEF:get_task_log_sources:Function]
|
|
# @PURPOSE: Get unique sources for a task's logs.
|
|
# @PRE: task_id is a valid task ID.
|
|
# @POST: Returns list of unique source strings.
|
|
# @PARAM: task_id (str) - The task ID.
|
|
# @RETURN: List[str] - Unique source names.
|
|
def get_task_log_sources(self, task_id: str) -> List[str]:
    """Return the unique log source names recorded for a task."""
    with belief_scope("TaskManager.get_task_log_sources", f"task_id={task_id}"):
        return self.log_persistence_service.get_sources(task_id)
|
|
# [/DEF:get_task_log_sources:Function]
|
|
|
|
# [DEF:_add_log:Function]
|
|
# @PURPOSE: Adds a log entry to a task buffer and notifies subscribers.
|
|
# @PRE: Task exists.
|
|
# @POST: Log added to buffer and pushed to queues (if level meets task_log_level filter).
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @PARAM: level (str) - Log level.
|
|
# @PARAM: message (str) - Log message.
|
|
# @PARAM: source (str) - Source component (default: "system").
|
|
# @PARAM: metadata (Optional[Dict]) - Additional structured data.
|
|
# @PARAM: context (Optional[Dict]) - Legacy context (for backward compatibility).
|
|
def _add_log(
    self,
    task_id: str,
    level: str,
    message: str,
    source: str = "system",
    metadata: Optional[Dict[str, Any]] = None,
    context: Optional[Dict[str, Any]] = None
):
    """Record a log entry for a task and push it to live subscribers.

    The entry is appended to the task's in-memory log list, queued in the
    batch-persistence buffer, and delivered to any subscriber queues.
    Entries below the configured task log level are dropped; unknown
    task IDs are ignored.

    :param task_id: ID of the task the entry belongs to.
    :param level: Log level string.
    :param message: Log message.
    :param source: Originating component (default "system").
    :param metadata: Optional structured data attached to the entry.
    :param context: Legacy context dict, kept for backward compatibility.
    """
    with belief_scope("TaskManager._add_log", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if task is None:
            return

        # Drop entries below the configured task_log_level threshold.
        if not should_log_task_level(level):
            return

        entry = LogEntry(
            level=level,
            message=message,
            source=source,
            metadata=metadata,
            context=context  # Keep for backward compatibility
        )

        # In-memory list (backward compatibility with legacy JSON field).
        task.logs.append(entry)

        # Buffer for the background batch flusher.
        with self._log_buffer_lock:
            self._log_buffer.setdefault(task_id, []).append(entry)

        # Real-time delivery for WebSocket subscribers; use
        # call_soon_threadsafe because this method may be invoked from
        # executor threads.
        for queue in self.subscribers.get(task_id, []):
            self.loop.call_soon_threadsafe(queue.put_nowait, entry)
|
|
# [/DEF:_add_log:Function]
|
|
|
|
# [DEF:subscribe_logs:Function]
|
|
# @PURPOSE: Subscribes to real-time logs for a task.
|
|
# @PRE: task_id is a string.
|
|
# @POST: Returns an asyncio.Queue for log entries.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @RETURN: asyncio.Queue - Queue for log entries.
|
|
async def subscribe_logs(self, task_id: str) -> asyncio.Queue:
    """Register and return a fresh queue receiving this task's log entries."""
    with belief_scope("TaskManager.subscribe_logs", f"task_id={task_id}"):
        queue: asyncio.Queue = asyncio.Queue()
        self.subscribers.setdefault(task_id, []).append(queue)
        return queue
|
|
# [/DEF:subscribe_logs:Function]
|
|
|
|
# [DEF:unsubscribe_logs:Function]
|
|
# @PURPOSE: Unsubscribes from real-time logs for a task.
|
|
# @PRE: task_id is a string, queue is asyncio.Queue.
|
|
# @POST: Queue removed from subscribers.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @PARAM: queue (asyncio.Queue) - Queue to remove.
|
|
def unsubscribe_logs(self, task_id: str, queue: asyncio.Queue):
    """Remove a subscriber queue; drop the task entry when none remain.

    :param task_id: ID of the task.
    :param queue: The queue previously returned by subscribe_logs().
    """
    with belief_scope("TaskManager.unsubscribe_logs", f"task_id={task_id}"):
        queues = self.subscribers.get(task_id)
        if queues is None:
            return
        if queue in queues:
            queues.remove(queue)
        if not queues:
            del self.subscribers[task_id]
|
|
# [/DEF:unsubscribe_logs:Function]
|
|
|
|
# [DEF:load_persisted_tasks:Function]
|
|
# @PURPOSE: Load persisted tasks using persistence service.
|
|
# @PRE: None.
|
|
# @POST: Persisted tasks loaded into self.tasks.
|
|
def load_persisted_tasks(self) -> None:
    """Load up to 100 persisted tasks into the registry.

    Tasks already present in memory are kept and never overwritten.
    """
    with belief_scope("TaskManager.load_persisted_tasks"):
        for task in self.persistence_service.load_tasks(limit=100):
            self.tasks.setdefault(task.id, task)
|
|
# [/DEF:load_persisted_tasks:Function]
|
|
|
|
# [DEF:await_input:Function]
|
|
# @PURPOSE: Transition a task to AWAITING_INPUT state with input request.
|
|
# @PRE: Task exists and is in RUNNING state.
|
|
# @POST: Task status changed to AWAITING_INPUT, input_request set, persisted.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @PARAM: input_request (Dict) - Details about required input.
|
|
# @THROWS: ValueError if task not found or not RUNNING.
|
|
def await_input(self, task_id: str, input_request: Dict[str, Any]) -> None:
    """Pause a RUNNING task, marking it AWAITING_INPUT with an input request.

    :param task_id: ID of the task to pause.
    :param input_request: Description of the input required from the user.
    :raises ValueError: if the task is unknown or not RUNNING.
    """
    with belief_scope("TaskManager.await_input", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found")
        if task.status != TaskStatus.RUNNING:
            raise ValueError(f"Task {task_id} is not RUNNING (current: {task.status})")

        task.status = TaskStatus.AWAITING_INPUT
        task.input_required = True
        task.input_request = input_request
        self.persistence_service.persist_task(task)
        # Fix: the dict was previously passed positionally, landing in the
        # 'source' parameter of _add_log; it belongs in 'metadata'.
        self._add_log(task_id, "INFO", "Task paused for user input", metadata={"input_request": input_request})
|
|
# [/DEF:await_input:Function]
|
|
|
|
# [DEF:resume_task_with_password:Function]
|
|
# @PURPOSE: Resume a task that is awaiting input with provided passwords.
|
|
# @PRE: Task exists and is in AWAITING_INPUT state.
|
|
# @POST: Task status changed to RUNNING, passwords injected, task resumed.
|
|
# @PARAM: task_id (str) - ID of the task.
|
|
# @PARAM: passwords (Dict[str, str]) - Mapping of database name to password.
|
|
# @THROWS: ValueError if task not found, not awaiting input, or passwords invalid.
|
|
def resume_task_with_password(self, task_id: str, passwords: Dict[str, str]) -> None:
    """Resume a task awaiting input by injecting database passwords.

    :param task_id: ID of the paused task.
    :param passwords: Mapping of database name to password.
    :raises ValueError: if the task is unknown, not AWAITING_INPUT, or
        passwords is not a non-empty dict.
    """
    with belief_scope("TaskManager.resume_task_with_password", f"task_id={task_id}"):
        task = self.tasks.get(task_id)
        if not task:
            raise ValueError(f"Task {task_id} not found")
        if task.status != TaskStatus.AWAITING_INPUT:
            raise ValueError(f"Task {task_id} is not AWAITING_INPUT (current: {task.status})")

        if not isinstance(passwords, dict) or not passwords:
            raise ValueError("Passwords must be a non-empty dictionary")

        task.params["passwords"] = passwords
        task.input_required = False
        task.input_request = None
        task.status = TaskStatus.RUNNING
        self.persistence_service.persist_task(task)
        # Fix: the dict was previously passed positionally, landing in the
        # 'source' parameter of _add_log; it belongs in 'metadata'. Only
        # database names are logged, never the passwords themselves.
        self._add_log(task_id, "INFO", "Task resumed with passwords", metadata={"databases": list(passwords.keys())})

        # Wake the coroutine blocked in wait_for_input(), if any.
        if task_id in self.task_futures:
            self.task_futures[task_id].set_result(True)
|
|
# [/DEF:resume_task_with_password:Function]
|
|
|
|
# [DEF:clear_tasks:Function]
|
|
# @PURPOSE: Clears tasks based on status filter (also deletes associated logs).
|
|
# @PRE: status is Optional[TaskStatus].
|
|
# @POST: Tasks matching filter (or all non-active) cleared from registry and database.
|
|
# @PARAM: status (Optional[TaskStatus]) - Filter by task status.
|
|
# @RETURN: int - Number of tasks cleared.
|
|
def clear_tasks(self, status: Optional[TaskStatus] = None) -> int:
    """Remove tasks (and their logs) from memory and persistence.

    With a status filter, only tasks in that status are removed. Without
    one, every task except active ones (RUNNING, AWAITING_INPUT,
    AWAITING_MAPPING) is removed.

    :param status: Optional status filter.
    :return: Number of tasks removed.
    """
    with belief_scope("TaskManager.clear_tasks"):
        active = (TaskStatus.RUNNING, TaskStatus.AWAITING_INPUT, TaskStatus.AWAITING_MAPPING)
        if status:
            removable = [tid for tid, t in list(self.tasks.items()) if t.status == status]
        else:
            # Default: clear everything that is not actively executing or
            # paused waiting for input/mapping.
            removable = [tid for tid, t in list(self.tasks.items()) if t.status not in active]

        for tid in removable:
            # Cancel any pending future (tasks paused for input/mapping).
            future = self.task_futures.pop(tid, None)
            if future is not None:
                future.cancel()
            del self.tasks[tid]

        # Remove task records (task_logs should follow via CASCADE).
        self.persistence_service.delete_tasks(removable)

        # Delete logs explicitly as well, in case CASCADE is not set up.
        if removable:
            self.log_persistence_service.delete_logs_for_tasks(removable)

        logger.info(f"Cleared {len(removable)} tasks.")
        return len(removable)
|
|
# [/DEF:clear_tasks:Function]
|
|
|
|
# [/DEF:TaskManager:Class]
|
|
# [/DEF:TaskManagerModule:Module]
|