diff --git a/backend/src/api/routes/assistant.py b/backend/src/api/routes/assistant.py
index 3fbf70c..b30d79c 100644
--- a/backend/src/api/routes/assistant.py
+++ b/backend/src/api/routes/assistant.py
@@ -1239,7 +1239,7 @@ async def _dispatch_intent(
         )
         provider = LLMProviderService(db).get_provider(provider_id)
         provider_model = provider.default_model if provider else ""
-        if not is_multimodal_model(provider_model):
+        if not is_multimodal_model(provider_model, provider.provider_type if provider else None):
             raise HTTPException(
                 status_code=422,
                 detail=(
diff --git a/backend/src/api/routes/tasks.py b/backend/src/api/routes/tasks.py
index 31935b7..2ff59fd 100755
--- a/backend/src/api/routes/tasks.py
+++ b/backend/src/api/routes/tasks.py
@@ -83,9 +83,13 @@ async def create_task(
         db_provider = llm_service.get_provider(provider_id)
         if not db_provider:
             raise ValueError(f"LLM Provider {provider_id} not found")
-        if request.plugin_id == "llm_dashboard_validation" and not is_multimodal_model(db_provider.default_model):
-            raise ValueError(
-                "Selected provider model is not multimodal for dashboard validation"
+        if request.plugin_id == "llm_dashboard_validation" and not is_multimodal_model(
+            db_provider.default_model,
+            db_provider.provider_type,
+        ):
+            raise HTTPException(
+                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+                detail="Selected provider model is not multimodal for dashboard validation",
             )
     finally:
         db.close()
diff --git a/backend/src/plugins/llm_analysis/plugin.py b/backend/src/plugins/llm_analysis/plugin.py
index 29e11f9..d76fd5b 100644
--- a/backend/src/plugins/llm_analysis/plugin.py
+++ b/backend/src/plugins/llm_analysis/plugin.py
@@ -109,7 +109,7 @@ class DashboardValidationPlugin(PluginBase):
         llm_log.debug(f"  Base URL: {db_provider.base_url}")
         llm_log.debug(f"  Default Model: {db_provider.default_model}")
         llm_log.debug(f"  Is Active: {db_provider.is_active}")
-        if not is_multimodal_model(db_provider.default_model):
+        if not is_multimodal_model(db_provider.default_model, db_provider.provider_type):
             raise ValueError(
                 "Dashboard validation requires a multimodal model (image input support)."
             )
diff --git a/backend/src/services/__tests__/test_llm_prompt_templates.py b/backend/src/services/__tests__/test_llm_prompt_templates.py
index 3fac778..4509cb4 100644
--- a/backend/src/services/__tests__/test_llm_prompt_templates.py
+++ b/backend/src/services/__tests__/test_llm_prompt_templates.py
@@ -74,6 +74,7 @@ def test_render_prompt_replaces_known_placeholders():
 
 def test_is_multimodal_model_detects_known_vision_models():
     assert is_multimodal_model("gpt-4o") is True
     assert is_multimodal_model("claude-3-5-sonnet") is True
+    assert is_multimodal_model("stepfun/step-3.5-flash:free", "openrouter") is True
     assert is_multimodal_model("text-only-model") is False
 # [/DEF:test_is_multimodal_model_detects_known_vision_models:Function]
diff --git a/backend/src/services/llm_prompt_templates.py b/backend/src/services/llm_prompt_templates.py
index 97cc746..aabbb6e 100644
--- a/backend/src/services/llm_prompt_templates.py
+++ b/backend/src/services/llm_prompt_templates.py
@@ -9,7 +9,7 @@
 from __future__ import annotations
 
 from copy import deepcopy
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 
 
 # [DEF:DEFAULT_LLM_PROMPTS:Constant]
@@ -131,10 +131,21 @@ def normalize_llm_settings(llm_settings: Any) -> Dict[str, Any]:
 # @PURPOSE: Heuristically determine whether model supports image input required for dashboard validation.
 # @PRE: model_name may be empty or mixed-case.
 # @POST: Returns True when model likely supports multimodal input.
-def is_multimodal_model(model_name: str) -> bool:
+def is_multimodal_model(model_name: str, provider_type: Optional[str] = None) -> bool:
     token = (model_name or "").strip().lower()
     if not token:
         return False
+    provider = (provider_type or "").strip().lower()
+    text_only_markers = (
+        "text-only",
+        "embedding",
+        "rerank",
+        "whisper",
+        "tts",
+        "transcribe",
+    )
+    if any(marker in token for marker in text_only_markers):
+        return False
     multimodal_markers = (
         "gpt-4o",
         "gpt-4.1",
@@ -143,6 +154,17 @@
         "gemini",
         "claude-3",
         "claude-sonnet-4",
+        "omni",
+        "multimodal",
+        "pixtral",
+        "llava",
+        "internvl",
+        "qwen-vl",
+        "qwen2-vl",
     )
-    return any(marker in token for marker in multimodal_markers)
+    if any(marker in token for marker in multimodal_markers):
+        return True
+    # OpenRouter model ids are heterogeneous; gate the known StepFun multimodal family by provider.
+    if provider == "openrouter" and token.startswith("stepfun/step-3.5"):
+        return True
+    return False
 # [/DEF:is_multimodal_model:Function]