from fastapi import FastAPI, UploadFile, File, Form, HTTPException
import logging
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import json
import os
import time
from typing import List, Optional, Tuple
import re
import math
import warnings

import torch
# Suppress torch.load FutureWarning - we control our own model files.
# warnings matches this regex against the START of the warning text, hence
# the leading ".*".
warnings.filterwarnings("ignore", message=".*weights_only=False.*", category=FutureWarning)

# Optional backends. Each import is attempted independently; a module-level
# _HAS_* flag records whether it worked, and on any exception (not only
# ImportError) the imported name is set to None.
try:
    import open_clip
    _HAS_OPEN_CLIP = True
except Exception:
    open_clip = None
    _HAS_OPEN_CLIP = False

# BLIP-2 captioning classes from Hugging Face transformers.
try:
    from transformers import AutoProcessor, Blip2ForConditionalGeneration
    _HAS_BLIP2 = True
except Exception:
    AutoProcessor = None
    Blip2ForConditionalGeneration = None
    _HAS_BLIP2 = False

# OpenAI's `clip` package. A successful import is not enough: the flag also
# requires a `load` attribute, presumably to reject a different module that
# happens to be importable as `clip` (TODO confirm).
try:
    import clip
    _HAS_OPENAI_CLIP = hasattr(clip, "load")
except Exception:
    clip = None
    _HAS_OPENAI_CLIP = False

APP_DIR = os.path.dirname(os.path.abspath(__file__))
TAG_DB_PATH = os.path.join(APP_DIR, "tag_db.json")
TAG_CLASSIFIER_PATH = os.path.join(APP_DIR, "tag_classifier.pt")
TAG_FINETUNE_PATH = os.path.join(APP_DIR, "finetuned_clip.pt")
TAG_EMBEDDINGS_PATH = os.path.join(APP_DIR, "tag_embeddings.pt")
TRAINING_QUEUE_DIR = os.path.join(APP_DIR, "training_queue")

# Auto-create training queue directory
os.makedirs(TRAINING_QUEUE_DIR, exist_ok=True)

DEFAULT_TAGS = []

LOG_FILE_PATH = os.path.join(APP_DIR, "logs.txt")

# Configure logging: console + rotating file (keeps last 5 MB)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("ai-tag")

from logging.handlers import RotatingFileHandler
_file_handler = RotatingFileHandler(
    LOG_FILE_PATH, maxBytes=5 * 1024 * 1024, backupCount=2, encoding="utf-8"
)
_file_handler.setLevel(logging.INFO)
_file_handler.setFormatter(logging.Formatter(
    "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
))
# Attach to root logger so ALL log output (uvicorn, asyncio, ai-tag) goes to the file
_root_logger = logging.getLogger()
_root_logger.addHandler(_file_handler)
_root_logger.setLevel(logging.INFO)

# Uvicorn creates its own loggers with propagate=False, so they skip the root handler.
# Explicitly attach the file handler to every uvicorn logger, and disable propagation
# to prevent duplicate entries (root logger already has the same handler).
for _uv_name in ("uvicorn", "uvicorn.access", "uvicorn.error"):
    _uv_logger = logging.getLogger(_uv_name)
    _uv_logger.addHandler(_file_handler)
    _uv_logger.propagate = False

class _IgnoreConnectionReset(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        msg = record.getMessage()
        if (
            "ConnectionResetError" in msg
            or "WinError 10054" in msg
            or "ProactorBasePipeTransport" in msg
            or "_call_connection_lost" in msg
        ):
            return False
        return True

logging.getLogger("asyncio").addFilter(_IgnoreConnectionReset())

# Training status tracker
class TrainingStatus:
    """Progress and health snapshot for the current (or most recent) training run.

    ``to_dict()`` yields a JSON-safe view: NaN/Inf losses are reported as
    ``None``, and ``stale`` says whether an active run has stopped reporting.
    """

    # Component losses that keep their previous value when a new value cannot
    # be converted to float. ``last_loss`` is handled on its own because an
    # unconvertible value clears it instead.
    _COMPONENT_LOSS_FIELDS = (
        "loss_embed",
        "loss_classifier",
        "loss_clip",
        "loss_clip_pre",
        "loss_vlm",
    )

    def __init__(self):
        self.active = False
        self.trained = 0
        self.total = 0
        self.last_update = 0
        self.session_id = None
        self.mode = "browser"  # "browser" or "background"
        self.cancelled = False
        self.error = None
        self._reset_losses()

    def _reset_losses(self):
        """Forget every recorded loss value."""
        self.last_loss = None
        for field_name in self._COMPONENT_LOSS_FIELDS:
            setattr(self, field_name, None)

    def start(self, total: int, session_id: str = None, mode: str = "browser"):
        """Begin a new run of *total* items, clearing the previous run's results."""
        self.active = True
        self.trained = 0
        self.total = total
        self.last_update = time.time()
        self.session_id = session_id
        self.mode = mode
        self.cancelled = False
        self.error = None
        self._reset_losses()

    def update(self, trained: int, loss: float = None, loss_embed: float = None, loss_classifier: float = None, loss_clip: float = None, loss_clip_pre: float = None, loss_vlm: float = None):
        """Record progress plus any loss values supplied.

        A ``None`` argument leaves the stored value unchanged. An
        unconvertible ``loss`` resets ``last_loss`` to None, while an
        unconvertible component loss keeps its previous value.
        """
        self.trained = trained
        self.last_update = time.time()

        if loss is not None:
            try:
                self.last_loss = float(loss)
            except Exception:
                self.last_loss = None

        supplied = (loss_embed, loss_classifier, loss_clip, loss_clip_pre, loss_vlm)
        for field_name, value in zip(self._COMPONENT_LOSS_FIELDS, supplied):
            if value is None:
                continue
            try:
                converted = float(value)
            except Exception:
                continue  # keep the previously recorded value
            setattr(self, field_name, converted)

    def finish(self, error: str = None):
        """Mark the run as no longer active, recording an optional error message."""
        self.active = False
        self.error = error

    def cancel(self):
        """Flag the run as cancelled (the flag is only stored here)."""
        self.cancelled = True

    def is_stale(self, timeout: int = 60) -> bool:
        """Check if training appears stuck (no update in timeout seconds).

        Inactive runs are never stale. Background runs always use a fixed
        300-second window; *timeout* is ignored for them.
        """
        if not self.active:
            return False
        limit = 300 if self.mode == "background" else timeout
        return time.time() - self.last_update > limit

    @staticmethod
    def _safe_float(v):
        """Return None for NaN/Inf — they are not JSON-compliant."""
        if v is None:
            return None
        try:
            number = float(v)
        except Exception:
            return None
        return number if math.isfinite(number) else None

    def to_dict(self):
        """Return a JSON-serialisable snapshot of the status."""
        clean = self._safe_float
        if self.total > 0:
            percent = self.trained / self.total * 100
        else:
            percent = 0
        return {
            "active": self.active,
            "trained": self.trained,
            "total": self.total,
            "progress": round(percent, 1),
            "last_update": self.last_update,
            "stale": self.is_stale(),
            "session_id": self.session_id,
            "mode": self.mode,
            "cancelled": self.cancelled,
            "error": self.error,
            "loss": clean(self.last_loss),
            "loss_embed": clean(self.loss_embed),
            "loss_classifier": clean(self.loss_classifier),
            "loss_clip": clean(self.loss_clip),
            "loss_clip_pre": clean(self.loss_clip_pre),
            "loss_vlm": clean(self.loss_vlm)
        }

# Shared training progress tracker and the background training task handle
# (initially None).
_training_status = TrainingStatus()
_background_training_task = None

app = FastAPI(title="AI Tag Server")

# Browser origins allowed to call this server: the public sites plus local
# development hosts.
_CORS_ALLOWED_ORIGINS = [
    "https://www.celestia-dominance.click",
    "https://celestiadominance.com",
    "https://www.celestiadominance.com",
    "https://localhost:8443",
    "https://127.0.0.1:8443",
    "http://localhost:8000",
    "http://127.0.0.1:8000",
    "http://localhost",
    "http://127.0.0.1",
    "http://localhost:80",
    "http://localhost:8080",
    "http://127.0.0.1:8080",
]

# NOTE(review): for credentialed requests browsers treat "*" in
# Access-Control-Expose-Headers literally, not as a wildcard, so
# expose_headers=["*"] may expose nothing when credentials are sent — confirm.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

# Middleware to handle Private Network Access (Chrome/Edge requirement for localhost access from public sites)
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

class PrivateNetworkAccessMiddleware(BaseHTTPMiddleware):
    """Add ``Access-Control-Allow-Private-Network: true`` to every response.

    Chrome/Edge Private Network Access requires this header on the preflight
    response before a public site may call a server on localhost. The header
    is set on all responses, preflight (OPTIONS) or not.
    """

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)
        response.headers["Access-Control-Allow-Private-Network"] = "true"
        return response


app.add_middleware(PrivateNetworkAccessMiddleware)

_device = "cuda" if torch.cuda.is_available() else "cpu"
REQUIRE_CUDA = os.getenv("REQUIRE_CUDA", "0") == "1"
if REQUIRE_CUDA and _device != "cuda":
    raise RuntimeError("CUDA required but not available. Install a CUDA-enabled PyTorch build.")

_model = None
_preprocess = None
_tokenizer = None
_fallback_labels = None
_backend = "resnet"
_caption_model = None
_caption_processor = None
_classifier = None
_classifier_tags = []
_classifier_dim = None

# Tag Embedding Network - real neural network for learned tag representations
_tag_embed_net = None
_tag_embed_tags = []
_tag_embed_dim = None
TAG_EMBED_HIDDEN = 256  # Hidden layer size
TAG_EMBED_LR = 5e-3  # Learning rate for tag embedding training (increased)
TAG_EMBED_STEPS = 20  # Training steps per image (increased)

# Persistent optimizers — reused across images for better training continuity
_tag_embed_optimizer = None
_classifier_optimizer = None
_finetune_optimizer = None
_finetune_scaler = None  # Persistent GradScaler for fine-tuning


class TagEmbeddingNet(torch.nn.Module):
    """Neural network that learns tag embeddings via contrastive learning.

    Unlike simple centroid averaging, this network:
    - Has learnable parameters trained with backpropagation
    - Uses a margin-based contrastive loss that pulls an image towards its
      positive tags and pushes it away from the most similar negative tags
    - Learns a residual non-linear transform of the image features

    The attribute names ``tag_embeddings``, ``image_transform`` and
    ``log_temperature`` become the state_dict keys of saved checkpoints, so
    they must stay stable for previously saved files to load.
    """

    def __init__(self, num_tags: int, feature_dim: int, hidden_dim: int = 256):
        super().__init__()
        self.num_tags = num_tags
        self.feature_dim = feature_dim

        # One learnable vector per tag: random Gaussian directions projected
        # onto the unit sphere (random high-dimensional vectors are close to
        # orthogonal, which gives the tags a well-spread start), then scaled
        # down to leave room for learning.
        init_embeddings = torch.randn(num_tags, feature_dim)
        init_embeddings = init_embeddings / init_embeddings.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        init_embeddings = init_embeddings * 0.5
        self.tag_embeddings = torch.nn.Parameter(init_embeddings)

        # Small MLP applied residually to image features (see transform_image).
        self.image_transform = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Linear(hidden_dim, feature_dim),
        )

        # Learnable temperature stored in log space (initial value 0.07); it
        # scales the logits returned by forward() but not the training loss.
        self.log_temperature = torch.nn.Parameter(torch.tensor(0.07).log())

    @property
    def temperature(self):
        """Logit temperature, clamped to [0.01, 1.0]."""
        return self.log_temperature.exp().clamp(min=0.01, max=1.0)

    def get_tag_embeddings(self, normalize: bool = True):
        """Return all tag embeddings ``[num_tags, feature_dim]``, optionally L2-normalised."""
        emb = self.tag_embeddings
        if normalize:
            emb = emb / emb.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        return emb

    def get_tag_embedding(self, tag_idx: int, normalize: bool = True):
        """Return the embedding ``[feature_dim]`` of one tag, optionally L2-normalised."""
        emb = self.tag_embeddings[tag_idx]
        if normalize:
            emb = emb / emb.norm().clamp(min=1e-6)
        return emb

    def transform_image(self, image_features: torch.Tensor, normalize: bool = True):
        """Apply the residual MLP ``x + 0.5 * mlp(x)``, optionally L2-normalising the last dim."""
        transformed = image_features + 0.5 * self.image_transform(image_features)
        if normalize:
            transformed = transformed / transformed.norm(dim=-1, keepdim=True).clamp(min=1e-6)
        return transformed

    def forward(self, image_features: torch.Tensor, positive_tag_indices: List[int] = None):
        """Compute temperature-scaled cosine similarities between images and all tags.

        Args:
            image_features: [batch, feature_dim] or [feature_dim]
            positive_tag_indices: Unused; accepted for call compatibility.

        Returns:
            similarity: [batch, num_tags], or [num_tags] when the batch has a
            single row.
        """
        if image_features.dim() == 1:
            image_features = image_features.unsqueeze(0)

        img_transformed = self.transform_image(image_features, normalize=True)
        tag_emb = self.get_tag_embeddings(normalize=True)

        similarity = (img_transformed @ tag_emb.T) / self.temperature

        return similarity.squeeze(0) if similarity.shape[0] == 1 else similarity

    def contrastive_loss(self, image_features: torch.Tensor, positive_indices: List[int]):
        """Margin-based contrastive loss: pull positive tags up, push hard negatives down.

        Args:
            image_features: ``[feature_dim]`` for one image or
                ``[batch, feature_dim]``. With a batch, every row shares the
                same positive tags and the per-image terms are averaged, so a
                batch of identical rows gives the same loss as a single row.
            positive_indices: Tag indices that are positive for the image(s).
                Indices outside ``[0, num_tags)`` are ignored.

        Returns:
            Scalar loss tensor. A constant zero tensor (with no gradient) is
            returned when there are no valid positive indices. The loss is the
            sum of:
            - mean ``1 - cos`` over the positive tags,
            - 0.5 * mean ``relu(cos + 0.1)`` over the (up to) 10 most similar
              negative tags of each image,
            - diversity regularisation on the tag embeddings themselves.
        """
        device = image_features.device
        if not positive_indices:
            return torch.tensor(0.0, device=device)

        pos_mask = torch.zeros(self.num_tags, device=device, dtype=torch.bool)
        for idx in positive_indices:
            if 0 <= idx < self.num_tags:
                pos_mask[idx] = True
        if not pos_mask.any():
            return torch.tensor(0.0, device=device)

        if image_features.dim() == 1:
            image_features = image_features.unsqueeze(0)
        img_transformed = self.transform_image(image_features, normalize=True)
        tag_emb = self.get_tag_embeddings(normalize=True)

        # Raw cosine similarity in [-1, 1] (no temperature), shape [batch, num_tags].
        sims = img_transformed @ tag_emb.T

        # Positive term: 0 when sim=1, 2 when sim=-1. The mask selects tag
        # columns, so this is correct for any batch size.
        pos_loss = (1.0 - sims[:, pos_mask]).mean()

        # Negative term with hard-negative mining per image: penalise the
        # top-k most similar negatives whenever their similarity exceeds -0.1.
        neg_mask = ~pos_mask
        if neg_mask.any():
            neg_sims = sims[:, neg_mask]
            k = min(10, neg_sims.shape[-1])
            hard_neg_sims = torch.topk(neg_sims, k=k, dim=-1).values
            neg_loss = torch.relu(hard_neg_sims + 0.1).mean()
        else:
            neg_loss = torch.tensor(0.0, device=device)

        total_loss = pos_loss + 0.5 * neg_loss

        # Diversity regularisation: keep tag embeddings from collapsing onto
        # each other, which would let one tag dominate every prediction.
        if self.num_tags > 1:
            tag_sims = tag_emb @ tag_emb.T  # [num_tags, num_tags]
            # Drop the diagonal (self-similarity = 1.0)
            mask = ~torch.eye(self.num_tags, dtype=torch.bool, device=tag_emb.device)
            off_diag = tag_sims[mask]

            # Penalise any positive similarity between different tags
            diversity_loss = torch.relu(off_diag).mean()

            # Extra quadratic penalty if the most similar pair exceeds 0.3
            max_off_diag = off_diag.max()
            dominant_penalty = torch.relu(max_off_diag - 0.3) ** 2

            total_loss = total_loss + 1.0 * diversity_loss + 2.0 * dominant_penalty

        return total_loss


def _load_tag_embed_net() -> None:
    """Load the tag embedding network from TAG_EMBEDDINGS_PATH, if the file exists.

    The checkpoint is expected to be a dict with ``tags`` (tag names in
    embedding-row order), ``state`` (a TagEmbeddingNet state_dict) and
    ``feature_dim``, the layout written by ``_save_tag_embed_net``. A missing
    file, or a checkpoint with an empty or missing entry, leaves the current
    globals untouched.

    On success the network, tag list and feature dimension are published
    together and the persistent optimizer is reset. An optimizer created for
    a previous network would otherwise keep updating that old network's
    parameters instead of the freshly loaded ones.

    Any error while reading or restoring the checkpoint is logged as a
    warning and clears the network globals (and the optimizer). It is never
    raised.
    """
    global _tag_embed_net, _tag_embed_tags, _tag_embed_dim, _tag_embed_optimizer
    if not os.path.exists(TAG_EMBEDDINGS_PATH):
        return
    try:
        # NOTE: torch.load unpickles the file; only load checkpoints this server wrote.
        data = torch.load(TAG_EMBEDDINGS_PATH, map_location=_device)
        tags = data.get("tags")
        state = data.get("state")
        dim = data.get("feature_dim")
        if not tags or not state or not dim:
            return
        feature_dim = int(dim)
        tag_list = list(tags)
        # Build and restore into locals first, so the globals never point at
        # a half-initialised network if load_state_dict fails.
        net = TagEmbeddingNet(
            num_tags=len(tag_list),
            feature_dim=feature_dim,
            hidden_dim=TAG_EMBED_HIDDEN
        )
        net.load_state_dict(state)
        net.to(_device)
        net.eval()
    except Exception as exc:
        logger.warning("Failed to load tag embedding network: %s", exc)
        _tag_embed_net = None
        _tag_embed_tags = []
        _tag_embed_dim = None
        _tag_embed_optimizer = None
        return

    _tag_embed_net = net
    _tag_embed_tags = tag_list
    _tag_embed_dim = feature_dim
    _tag_embed_optimizer = None  # bound to the previous network's parameters
    logger.info("Loaded tag embedding network: %d tags, dim=%d", len(_tag_embed_tags), _tag_embed_dim)


def _save_tag_embed_net() -> None:
    """Save the tag embedding network to TAG_EMBEDDINGS_PATH.

    Writes a dict with ``tags``, ``state`` (the network's state_dict) and
    ``feature_dim``. Does nothing (debug log only) when the network, its tag
    list or its dimension is not set.

    The checkpoint is written to a temporary file next to the target and then
    moved into place with ``os.replace``. An interrupted save therefore never
    leaves a truncated checkpoint at TAG_EMBEDDINGS_PATH. Failures are logged
    as warnings and not raised.
    """
    if _tag_embed_net is None or not _tag_embed_tags or _tag_embed_dim is None:
        logger.debug("Skipping tag embed save: network not initialized")
        return
    # Same directory as the target, so the final rename stays on one filesystem.
    tmp_path = TAG_EMBEDDINGS_PATH + ".tmp"
    try:
        torch.save(
            {
                "tags": _tag_embed_tags,
                "state": _tag_embed_net.state_dict(),
                "feature_dim": _tag_embed_dim,
            },
            tmp_path
        )
        os.replace(tmp_path, TAG_EMBEDDINGS_PATH)
        logger.info("✓ Saved tag_embeddings.pt (%d tags, dim=%d)", len(_tag_embed_tags), _tag_embed_dim)
    except Exception as exc:
        logger.warning("Failed to save tag embedding network: %s", exc)
        # Best-effort cleanup of a partially written temporary file.
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def _ensure_tag_embed_net(all_tags: List[str], feature_dim: int) -> None:
    """Ensure the tag embedding network exists and covers every tag in *all_tags*.

    - If no network exists yet, or its feature dimension differs from
      *feature_dim*, a fresh network is created for *all_tags*. Duplicates
      are removed, keeping first-occurrence order.
    - Otherwise every tag in *all_tags* that the network does not know yet is
      appended. The learned embeddings of existing tags and the shared
      image-transform/temperature weights are copied into the enlarged
      network. New tags are detected by membership, not by list length, so a
      caller that drops some tags while adding others still gets the new ones.

    Tags missing from *all_tags* are kept, so indices of already-trained tags
    never shift. Whenever a new network object is created, the persistent
    optimizer is reset because it would otherwise keep stepping the old
    network's parameters.
    """
    global _tag_embed_net, _tag_embed_tags, _tag_embed_dim, _tag_embed_optimizer

    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_tags = list(dict.fromkeys(all_tags))

    if _tag_embed_net is None or _tag_embed_dim != feature_dim:
        net = TagEmbeddingNet(
            num_tags=len(unique_tags),
            feature_dim=feature_dim,
            hidden_dim=TAG_EMBED_HIDDEN
        ).to(_device)
        net.train()
        _tag_embed_dim = feature_dim
        _tag_embed_tags = unique_tags
        _tag_embed_net = net
        _tag_embed_optimizer = None  # Reset — new parameters
        logger.info("Created new tag embedding network: %d tags, dim=%d", len(_tag_embed_tags), _tag_embed_dim)
        return

    known = set(_tag_embed_tags)
    new_tags = [t for t in unique_tags if t not in known]
    if not new_tags:
        return

    old_state = _tag_embed_net.state_dict()
    old_num_tags = len(_tag_embed_tags)
    expanded_tags = _tag_embed_tags + new_tags

    new_net = TagEmbeddingNet(
        num_tags=len(expanded_tags),
        feature_dim=_tag_embed_dim,
        hidden_dim=TAG_EMBED_HIDDEN
    ).to(_device)

    # Copy the learned rows for existing tags; new rows keep their fresh init.
    new_state = new_net.state_dict()
    new_state["tag_embeddings"][:old_num_tags] = old_state["tag_embeddings"]
    # Carry over the shared (tag-count independent) weights unchanged.
    for key, value in old_state.items():
        if key != "tag_embeddings" and key in new_state:
            new_state[key] = value

    new_net.load_state_dict(new_state)
    new_net.train()

    # Publish only after the expanded network is fully built.
    _tag_embed_tags = expanded_tags
    _tag_embed_net = new_net
    _tag_embed_optimizer = None  # Reset — new parameters
    logger.info("Expanded tag embedding network: %d -> %d tags", old_num_tags, len(_tag_embed_tags))


def _train_tag_embed_net(image_features: torch.Tensor, tags: List[str]) -> float:
    """Run TAG_EMBED_STEPS optimisation steps of the tag embedding net on one image.

    Args:
        image_features: Image embedding, ``[feature_dim]`` or ``[1, feature_dim]``;
            it is detached and moved to ``_device`` before training.
        tags: Tag names that are positive for this image. Names the network
            does not know are ignored.

    Returns:
        The loss of the final step, measured before that step's optimiser
        update, or 0.0 when nothing was trained (no network, no tags, or no
        known tags). The persistent AdamW optimiser is created on first use
        and reused afterwards. The network is left in eval mode.
    """
    global _tag_embed_net, _tag_embed_optimizer

    if _tag_embed_net is None or not tags:
        return 0.0

    index_of = {name: position for position, name in enumerate(_tag_embed_tags)}
    positive_indices = [index_of[name] for name in tags if name in index_of]
    if not positive_indices:
        return 0.0

    features = image_features.detach().to(_device)
    if features.dim() == 2:
        features = features.squeeze(0)

    _tag_embed_net.train()
    if _tag_embed_optimizer is None:
        _tag_embed_optimizer = torch.optim.AdamW(_tag_embed_net.parameters(), lr=TAG_EMBED_LR)

    last_loss = 0.0
    for _step in range(TAG_EMBED_STEPS):
        _tag_embed_optimizer.zero_grad(set_to_none=True)
        step_loss = _tag_embed_net.contrastive_loss(features, positive_indices)
        # contrastive_loss returns a constant (gradient-free) zero when no
        # positive index is valid; skip the update in that case.
        if step_loss.requires_grad:
            step_loss.backward()
            _tag_embed_optimizer.step()
        last_loss = float(step_loss.detach().cpu().item())

    _tag_embed_net.eval()
    return last_loss


# Environment-driven settings; boolean flags are enabled only by the exact value "1".
_env = os.getenv
FINETUNE_ENABLED = _env("FINETUNE_ENABLED", "1") == "1"
FINETUNE_STEPS = int(_env("FINETUNE_STEPS", "2"))
FINETUNE_LR = float(_env("FINETUNE_LR", "1e-6"))
# Freeze image encoder during fine-tune (safe for small data)
FREEZE_VISUAL = _env("FREEZE_VISUAL", "1") == "1"
BATCH_SIZE = int(_env("BATCH_SIZE", "4"))

USE_BLIP2 = _env("USE_BLIP2", "1") == "1"
BLIP2_MODEL_ID = _env("BLIP2_MODEL_ID", "")  # empty = pick by VRAM
HF_TOKEN = _env("HF_TOKEN", "")
OPENCLIP_MODEL = _env("OPENCLIP_MODEL", "ViT-B-32")
OPENCLIP_PRETRAINED = _env("OPENCLIP_PRETRAINED", "openai")
del _env

def _is_scratch_pretrained(value: str) -> bool:
    return str(value or "").strip().lower() in {"", "none", "scratch", "random"}


def _select_blip2_model() -> str:
    """Choose which BLIP-2 checkpoint to load.

    An explicit BLIP2_MODEL_ID always wins. Otherwise the choice scales with
    the total memory of CUDA device 0: >= 48 GiB -> flan-t5-xxl, >= 24 GiB ->
    flan-t5-xl. Less memory, a CPU-only device, or an unreadable device all
    fall back to flan-t5-base.
    """
    if BLIP2_MODEL_ID:
        return BLIP2_MODEL_ID

    smallest = "Salesforce/blip2-flan-t5-base"
    if _device != "cuda":
        return smallest

    try:
        vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
    except Exception:
        vram_gb = 0.0

    # (minimum memory in GiB, checkpoint), checked largest first.
    tiers = (
        (48, "Salesforce/blip2-flan-t5-xxl"),
        (24, "Salesforce/blip2-flan-t5-xl"),
    )
    for min_gb, model_id in tiers:
        if vram_gb >= min_gb:
            return model_id
    return smallest

# Select the image backend in priority order: OpenCLIP, then OpenAI CLIP, then
# a torchvision ResNet-50 classifier. Each branch sets _backend, _model and
# _preprocess; OpenCLIP also sets _tokenizer and ResNet sets _fallback_labels.
if _HAS_OPEN_CLIP:
    _backend = "open_clip"
    # "none"/"scratch"/"random"/empty OPENCLIP_PRETRAINED -> random init.
    pretrained_arg = None if _is_scratch_pretrained(OPENCLIP_PRETRAINED) else OPENCLIP_PRETRAINED
    # The discarded middle return value is presumably the training-time transform.
    _model, _, _preprocess = open_clip.create_model_and_transforms(OPENCLIP_MODEL, pretrained=pretrained_arg)
    _tokenizer = open_clip.get_tokenizer(OPENCLIP_MODEL)
    _model.to(_device)
    _model.eval()
    if pretrained_arg is None:
        logger.warning("OpenCLIP started from scratch (no pretrained weights)")
elif _HAS_OPENAI_CLIP:
    _backend = "clip"
    # Device placement is requested through clip.load's device argument.
    _model, _preprocess = clip.load("ViT-B/32", device=_device)
    _model.eval()
else:
    # Fallback: neither open_clip nor openai/clip is available, so use a
    # torchvision ResNet-50 with its default pretrained weights.
    from torchvision import transforms
    from torchvision.models import resnet50, ResNet50_Weights

    weights = ResNet50_Weights.DEFAULT
    # Class names shipped with the weights' metadata, if any.
    _fallback_labels = weights.meta.get("categories") if weights and weights.meta else None
    _model = resnet50(weights=weights)
    _model.to(_device)
    _model.eval()
    # Resize -> 224 center crop -> tensor -> normalise with ImageNet mean/std.
    _preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

logger.info("AI tag server device=%s backend=%s", _device, _backend)

# Optional BLIP-2 captioning model. Loaded only when the transformers import
# succeeded and USE_BLIP2 is enabled. Any failure during loading disables
# captioning (both handles reset to None) instead of aborting startup.
if _HAS_BLIP2 and USE_BLIP2:
    try:
        # Half precision on GPU, full precision on CPU.
        dtype = torch.float16 if _device == "cuda" else torch.float32
        selected_model_id = _select_blip2_model()
        if "itm" in selected_model_id.lower():
            logger.warning("BLIP2 ITM models are not suitable for captioning. Use blip2-flan-t5-* or blip2-opt-* for tagging.")
        # Only pass an auth token to Hugging Face when one is configured.
        token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}
        _caption_processor = AutoProcessor.from_pretrained(selected_model_id, **token_arg)
        # On CUDA the loader places the weights (device_map="auto"); on CPU
        # the model is moved to _device explicitly below.
        use_device_map = _device == "cuda"
        _caption_model = Blip2ForConditionalGeneration.from_pretrained(
            selected_model_id,
            **token_arg,
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            device_map="auto" if use_device_map else None
        )
        if not use_device_map:
            _caption_model.to(_device)
        _caption_model.eval()
        logger.info("BLIP2 enabled model=%s", selected_model_id)
    except Exception as exc:
        _caption_model = None
        _caption_processor = None
        logger.warning("BLIP2 disabled: %s", exc)


# ============================================================================
#  VISION LANGUAGE MODEL (VLM) — True AI Vision
#  Understands images like GPT-4V/Claude Vision and picks tags from YOUR vocabulary.
# ============================================================================
# VLM settings from the environment; boolean flags are enabled only by "1".
_vlm_env = os.getenv
TAGGING_MODE = _vlm_env("TAGGING_MODE", "clip")  # clip | vlm | hybrid
VLM_ENABLED = _vlm_env("VLM_ENABLED", "0") == "1"
VLM_MODEL_ID = _vlm_env("VLM_MODEL_ID", "vikhyatk/moondream2")
VLM_MAX_TAGS = int(_vlm_env("VLM_MAX_TAGS", "10"))
VLM_LORA_ENABLED = _vlm_env("VLM_LORA_ENABLED", "0") == "1"
VLM_LORA_LR = float(_vlm_env("VLM_LORA_LR", "1e-4"))
VLM_LORA_RANK = int(_vlm_env("VLM_LORA_RANK", "8"))
VLM_LORA_PATH = os.path.join(APP_DIR, "vlm_lora_adapter")
del _vlm_env

# Runtime VLM state, filled in by _load_vlm().
_vlm_model = None
_vlm_tokenizer = None
_vlm_processor = None  # For models that use a processor (LLaVA, Qwen, Florence)
_vlm_ready = False
_vlm_is_moondream = False  # moondream2 has a special API
_vlm_is_florence = False  # Florence-2 uses task tokens, not free-form prompts

def _load_vlm() -> None:
    """Load the Vision Language Model (and LoRA adapter if available).

    Does nothing when VLM_ENABLED is off and TAGGING_MODE is "clip". Otherwise
    VLM_MODEL_ID is loaded along one of three paths, chosen by substring match
    on the model id:

    - "moondream": AutoModelForCausalLM + AutoTokenizer with
      trust_remote_code, after a best-effort patch that gives the remote
      model class an ``all_tied_weights_keys`` property if it lacks one.
    - "florence": a hand-built Florence2Processor (slow tokenizer +
      CLIPImageProcessor), config/init-context compatibility patches applied
      before the model is instantiated, and explicit tying of the shared
      embeddings afterwards.
    - anything else: AutoProcessor + AutoModelForCausalLM, using
      device_map="auto" on CUDA.

    If a directory exists at VLM_LORA_PATH it is loaded as a PEFT adapter on
    top of the base model. A failure there is logged and the base model kept.

    Side effects: sets the module globals _vlm_model, _vlm_tokenizer,
    _vlm_processor, _vlm_is_moondream, _vlm_is_florence and _vlm_ready. A
    missing transformers install is logged and returns early. Any exception
    inside the main loading block is logged and leaves model, tokenizer and
    processor as None with _vlm_ready False; nothing is re-raised.
    """
    global _vlm_model, _vlm_tokenizer, _vlm_processor, _vlm_ready, _vlm_is_moondream, _vlm_is_florence

    # Skip unless the VLM is explicitly enabled or a non-CLIP tagging mode is selected.
    if not VLM_ENABLED and TAGGING_MODE == "clip":
        logger.info("VLM disabled (VLM_ENABLED=0, TAGGING_MODE=clip)")
        return

    try:
        from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor as AP
    except ImportError:
        logger.warning("VLM disabled: transformers library not available")
        return

    model_id = VLM_MODEL_ID
    # The loading path below is chosen by substring match on the model id.
    _vlm_is_moondream = "moondream" in model_id.lower()
    _vlm_is_florence = "florence" in model_id.lower()
    token_arg = {"token": HF_TOKEN} if HF_TOKEN else {}
    dtype = torch.float16 if _device == "cuda" else torch.float32

    logger.info("Loading VLM model: %s (dtype=%s, device=%s)", model_id, dtype, _device)

    try:
        if _vlm_is_moondream:
            # moondream2: uses AutoModelForCausalLM with trust_remote_code
            # Best-effort compatibility patch: resolve the remote model class
            # named in config.json's auto_map and, if it has no
            # `all_tied_weights_keys` attribute, add a property exposing
            # `_tied_weights_keys` (or {}). Presumably needed by newer
            # transformers during loading — TODO confirm the version. A
            # failure here is only logged; loading continues.
            try:
                from transformers.dynamic_module_utils import get_class_from_dynamic_module
                from huggingface_hub import hf_hub_download as _hf_dl
                import json as _json

                _cfg_path = _hf_dl(model_id, "config.json", **token_arg)
                with open(_cfg_path, "r", encoding="utf-8") as _f:
                    _raw_cfg = _json.load(_f)
                _model_ref = _raw_cfg.get("auto_map", {}).get("AutoModelForCausalLM", "")
                if _model_ref:
                    _MdCls = get_class_from_dynamic_module(_model_ref, model_id, **token_arg)
                    if not hasattr(_MdCls, "all_tied_weights_keys"):
                        @property
                        def all_tied_weights_keys(self):
                            return getattr(self, "_tied_weights_keys", {}) or {}

                        _MdCls.all_tied_weights_keys = all_tied_weights_keys
                        logger.info("Patched moondream all_tied_weights_keys")
            except Exception as exc:
                logger.warning("Failed to patch moondream tied-weights keys: %s", exc)

            _vlm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                **token_arg
            )
            _vlm_tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, **token_arg)
            _vlm_processor = None
        elif _vlm_is_florence:
            # Florence-2 workaround: the custom config class + processor
            # are broken with newer transformers (forced_bos_token_id and
            # additional_special_tokens AttributeErrors).
            # Strategy:
            #   1) Load slow tokenizer + image processor separately
            #   2) Import Florence2Processor class from the HF module cache
            #   3) Construct processor manually with the slow tokenizer
            #   4) Patch Florence2LanguageConfig before model loading
            # Unlike the moondream patch, these steps are not individually
            # guarded: a failure aborts loading via the outer except (the
            # init-context patch below is the one best-effort exception).
            from transformers import CLIPImageProcessor
            from transformers.dynamic_module_utils import get_class_from_dynamic_module
            from huggingface_hub import hf_hub_download as _hf_dl
            import json as _json
            import importlib as _imp

            _fl_tok = AutoTokenizer.from_pretrained(
                model_id, use_fast=False, trust_remote_code=True, **token_arg
            )
            # Newer transformers raises AttributeError for
            # additional_special_tokens when the list is empty.
            # Florence2Processor.__init__ reads this property, so we
            # must ensure it exists before constructing the processor.
            try:
                _ = _fl_tok.additional_special_tokens
            except AttributeError:
                _fl_tok.additional_special_tokens = []

            _fl_img_proc = CLIPImageProcessor.from_pretrained(model_id, **token_arg)

            # Import Florence2Processor directly (not in auto_map)
            _proc_ref = "processing_florence2.Florence2Processor"
            _Fl2Cls = get_class_from_dynamic_module(_proc_ref, model_id, **token_arg)
            _vlm_processor = _Fl2Cls(image_processor=_fl_img_proc, tokenizer=_fl_tok)

            # Patch Florence2LanguageConfig before model instantiation:
            # newer transformers' PretrainedConfig __getattribute__ raises
            # AttributeError for 'forced_bos_token_id' before __init__ sets it.
            # The patch adds a class-level default of None, found in the
            # module that defines the AutoConfig class from config.json.
            _cfg_path = _hf_dl(model_id, "config.json", **token_arg)
            with open(_cfg_path, "r", encoding="utf-8") as _f:
                _raw_cfg = _json.load(_f)
            _cfg_ref = _raw_cfg.get("auto_map", {}).get("AutoConfig", "")
            if _cfg_ref:
                _Fl2CfgCls = get_class_from_dynamic_module(_cfg_ref, model_id, **token_arg)
                _cfg_mod = _imp.import_module(_Fl2CfgCls.__module__)
                for _attr_name in ("Florence2LanguageConfig",):
                    _sub = getattr(_cfg_mod, _attr_name, None)
                    if _sub is not None and not hasattr(_sub, "forced_bos_token_id"):
                        _sub.forced_bos_token_id = None
                        logger.info("Patched %s.forced_bos_token_id", _attr_name)

            # Transformers 5 initializes models on meta by default via
            # get_init_context(). Florence-2 calls .item() in __init__,
            # which crashes on meta tensors, so override the init context
            # to avoid meta for this model.
            # NOTE(review): the replacement keeps only the dtype context and
            # no_tie_weights(); its signature must match the installed
            # transformers' get_init_context — TODO confirm on upgrades.
            _model_ref = _raw_cfg.get("auto_map", {}).get("AutoModelForCausalLM", "")
            if _model_ref:
                _Fl2ModelCls = get_class_from_dynamic_module(_model_ref, model_id, **token_arg)
                try:
                    from transformers.modeling_utils import local_torch_dtype as _local_torch_dtype
                    from transformers import initialization as _init

                    @classmethod
                    def _fl2_get_init_context(cls, dtype, is_quantized, _is_ds_init_called):
                        return [_local_torch_dtype(dtype, cls.__name__), _init.no_tie_weights()]

                    _Fl2ModelCls.get_init_context = _fl2_get_init_context
                    logger.info("Patched Florence2 init context to avoid meta tensors")
                except Exception as exc:
                    logger.warning("Failed to patch Florence2 init context: %s", exc)

            _vlm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=dtype,
                low_cpu_mem_usage=False,
                device_map=None,
                attn_implementation="eager",   # Florence-2 lacks _supports_sdpa
                **token_arg
            )
            _vlm_tokenizer = None
        else:
            # Generic VLM: LLaVA, Qwen-VL, etc.
            _vlm_processor = AP.from_pretrained(model_id, trust_remote_code=True, **token_arg)
            _vlm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=True,
                torch_dtype=dtype,
                low_cpu_mem_usage=True,
                device_map="auto" if _device == "cuda" else None,
                **token_arg
            )
            _vlm_tokenizer = None  # Processor handles tokenization

        # Move to device (if no device_map was used). transformers sets
        # hf_device_map when a device_map placed the weights, which here can
        # only happen on the generic path on CUDA.
        if not hasattr(_vlm_model, 'hf_device_map'):
            _vlm_model.to(_device)

        if _vlm_is_florence:
            # Florence-2 checkpoints store the shared embedding only; tie
            # encoder/decoder + lm_head to the shared weight explicitly.
            try:
                _fl_lm = getattr(_vlm_model, "language_model", None)
                _fl_core = getattr(_fl_lm, "model", None)
                _fl_shared = getattr(_fl_core, "shared", None)
                if _fl_shared is not None:
                    _enc = getattr(_fl_core, "encoder", None)
                    _dec = getattr(_fl_core, "decoder", None)
                    if _enc is not None and hasattr(_enc, "embed_tokens"):
                        _enc.embed_tokens.weight = _fl_shared.weight
                    if _dec is not None and hasattr(_dec, "embed_tokens"):
                        _dec.embed_tokens.weight = _fl_shared.weight
                    if _fl_lm is not None and hasattr(_fl_lm, "lm_head"):
                        _fl_lm.lm_head.weight = _fl_shared.weight
                    if hasattr(_vlm_model, "tie_weights"):
                        _vlm_model.tie_weights()
                    logger.info("Tied Florence-2 shared embeddings")
            except Exception as exc:
                logger.warning("Failed to tie Florence-2 embeddings: %s", exc)
        _vlm_model.eval()

        # Load LoRA adapter if it exists: wrap the base model with the PEFT
        # adapter stored at VLM_LORA_PATH. On failure the base model is kept.
        if os.path.isdir(VLM_LORA_PATH):
            try:
                from peft import PeftModel
                _vlm_model = PeftModel.from_pretrained(_vlm_model, VLM_LORA_PATH)
                _vlm_model.eval()
                logger.info("VLM LoRA adapter loaded from %s", VLM_LORA_PATH)
            except Exception as exc:
                logger.warning("Failed to load VLM LoRA adapter: %s", exc)

        _vlm_ready = True
        logger.info("VLM ready: %s (moondream=%s, florence=%s)", model_id, _vlm_is_moondream, _vlm_is_florence)

    except Exception as exc:
        # Any failure leaves the VLM fully disabled rather than half-loaded.
        _vlm_model = None
        _vlm_tokenizer = None
        _vlm_processor = None
        _vlm_ready = False
        logger.error("Failed to load VLM: %s", exc, exc_info=True)


def _vlm_build_prompt(tag_list: List[str], max_tags: int = 0) -> str:
    """Return the instruction prompt that restricts the VLM to *tag_list*.

    Args:
        tag_list: Allowed tags, listed verbatim (comma-separated) in the prompt.
        max_tags: Upper bound stated to the model; ``<= 0`` means ``VLM_MAX_TAGS``.
    """
    limit = max_tags if max_tags > 0 else VLM_MAX_TAGS
    lines = [
        "You are an image tagger. Look at this image carefully and select ALL applicable tags from this list ONLY:",
        "[" + ", ".join(tag_list) + "]",
        "",
        "Rules:",
        "- Return ONLY tags from the list above, separated by commas.",
        f"- Return at most {limit} tags.",
        "- If no tags match, return: none",
        "- Do NOT add any tags not in the list.",
        "- Do NOT add explanations, just the tag names.",
    ]
    return "\n".join(lines)


def _vlm_parse_response(response: str, valid_tags: List[str], max_tags: int = 0) -> List[str]:
    """Parse a VLM response and keep only tags from the user's vocabulary.

    The response is split on commas, semicolons and newlines. Surrounding
    quotes, dashes and bullet characters are stripped from each item. Each item
    is then matched against ``valid_tags`` in order of preference:

    1. exact (case-insensitive) match;
    2. normalized match, treating runs of hyphens/underscores/whitespace as
       equivalent (``"dark sky"`` matches ``"dark-sky"``);
    3. partial match, where the item contains a valid tag or vice versa. This
       only applies to items of at least 3 characters, so stray fragments such
       as ``"a"`` do not match every tag.

    Each vocabulary tag is returned at most once, in first-mention order, using
    the vocabulary's own spelling.

    Args:
        response: Raw text generated by the VLM.
        valid_tags: The allowed tag vocabulary.
        max_tags: Maximum number of tags to return; ``<= 0`` uses ``VLM_MAX_TAGS``.

    Returns:
        The matched tags, or ``[]`` for an empty / "none" / "n/a" response.
    """
    limit = max_tags if max_tags > 0 else VLM_MAX_TAGS
    response = response.strip().lower()
    if response in ("none", "n/a", ""):
        return []

    # Split by comma, newline, or semicolon; drop quotes and list-bullet markers.
    raw_tags = re.split(r'[,;\n]+', response)
    raw_tags = [t.strip().strip('"').strip("'").strip('-').strip('•').strip() for t in raw_tags]

    valid_set = {t.lower(): t for t in valid_tags}
    valid_norm = {}
    for t in valid_tags:
        norm = re.sub(r'[-_\s]+', '-', t.strip().lower())
        if norm and norm not in valid_norm:
            valid_norm[norm] = t

    matched: List[str] = []
    # De-duplicate on the returned vocabulary tag itself, so the exact,
    # normalized and partial paths can never emit the same tag twice.
    emitted = set()
    for raw in raw_tags:
        if len(matched) >= limit:
            break
        raw_l = raw.lower().strip()
        if not raw_l:
            continue
        # Exact / normalized matches are authoritative: if the item names a known
        # tag, don't fall through to fuzzy matching (which could pick a different
        # tag that merely contains it).
        tag = valid_set.get(raw_l)
        if tag is None:
            tag = valid_norm.get(re.sub(r'[-_\s]+', '-', raw_l))
        if tag is None and len(raw_l) >= 3:
            tag = next(
                (
                    vt for vl, vt in valid_set.items()
                    if vl and vt not in emitted and (vl in raw_l or raw_l in vl)
                ),
                None,
            )
        if tag is not None and tag not in emitted:
            emitted.add(tag)
            matched.append(tag)

    return matched


def _vlm_prepare_image(image: Image.Image) -> Tuple[Image.Image, object, Optional[str]]:
    """Convert *image* into the forms the VLM code paths consume.

    Returns ``(rgb_image, processor_input, input_format)``:
      * ``rgb_image`` is *image* converted to RGB (the same object if it already was);
      * ``processor_input`` is a NumPy array of the pixels in channels-last layout,
        with grayscale stacked to 3 channels and any alpha channel dropped;
      * ``input_format`` is ``"channels_last"`` for the array.

    If NumPy is unavailable or the conversion fails, the PIL image itself is
    returned as ``processor_input`` and ``input_format`` is ``None``.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")

    try:
        import numpy as np
        pixels = np.asarray(rgb)
        if pixels.ndim == 2:
            pixels = np.stack([pixels] * 3, axis=-1)
        elif pixels.shape[-1] == 4:
            pixels = pixels[:, :, :3]
    except Exception:
        return rgb, rgb, None
    return rgb, pixels, "channels_last"


def _vlm_encode_inputs(text: str, image_input, input_format: Optional[str]) -> dict:
    """Run ``_vlm_processor`` on *text* + *image_input* and move the result onto the model.

    Every tensor-like value is moved to ``_device``. ``pixel_values`` is also
    cast to the dtype of the model's first parameter, so half-precision models
    receive half-precision pixels.
    """
    processor_kwargs = {"text": text, "images": image_input, "return_tensors": "pt"}
    if input_format:
        processor_kwargs["input_data_format"] = input_format
    inputs = _vlm_processor(**processor_kwargs)
    inputs = {k: v.to(_device) if hasattr(v, "to") else v for k, v in inputs.items()}
    pixel_values = inputs.get("pixel_values")
    if pixel_values is not None and hasattr(pixel_values, "to"):
        model_dtype = next(_vlm_model.parameters()).dtype
        inputs["pixel_values"] = pixel_values.to(_device, dtype=model_dtype)
    return inputs


def _vlm_match_caption_tags(caption: str, tag_list: List[str]) -> List[str]:
    """Return the tags from *tag_list* that occur as whole words in *caption*.

    Matching is case-insensitive. Multi-part tags match with any run of
    spaces/hyphens/underscores between their parts (``dark-sky`` matches
    "dark sky"). Tags shorter than two characters are skipped. Each tag is
    reported once, in *tag_list* order.
    """
    cap_low = caption.lower()
    matched: List[str] = []
    seen = set()
    for t in tag_list:
        tl = t.lower().strip()
        if len(tl) < 2 or t in seen:
            continue
        parts = [p for p in re.split(r'[-_\s]+', tl) if p]
        if not parts:
            continue
        # A single part joins to itself, so this also covers one-word tags.
        pattern = r'\b' + r'[\s_-]+'.join(re.escape(p) for p in parts) + r'\b'
        if re.search(pattern, cap_low):
            matched.append(t)
            seen.add(t)
    return matched


def _vlm_tag_image(image: Image.Image, tag_list: List[str], max_tags: int = 0) -> tuple[list, str]:
    """Tag *image* with the loaded VLM, restricted to the user's *tag_list*.

    * moondream: asks the constraining prompt via ``answer_question``.
    * Florence-2: generates a detailed caption and keeps the tags that appear in it.
    * other processor-based VLMs: generates an answer to the prompt.

    Answers from the moondream and generic paths are filtered through
    ``_vlm_parse_response``.

    Args:
        image: Input image (converted to RGB as needed).
        tag_list: Allowed tag vocabulary.
        max_tags: Maximum number of tags to return on every path; ``<= 0`` uses
            ``VLM_MAX_TAGS``. On the prompt-based paths the result is also
            capped by ``_vlm_parse_response``'s own limit.

    Returns:
        ``(tags, caption)``. ``caption`` is the Florence-2 caption on that path
        and ``""`` otherwise. Returns ``([], "")`` when the VLM is not ready, no
        inference path applies, or inference raises (the error is logged).
    """
    if not _vlm_ready or _vlm_model is None:
        return [], ""

    limit = max_tags if max_tags > 0 else VLM_MAX_TAGS
    image_rgb, image_input, input_format = _vlm_prepare_image(image)
    prompt = _vlm_build_prompt(tag_list, max_tags)

    try:
        with torch.no_grad():
            if _vlm_is_moondream:
                # moondream2 API: encode_image + answer_question
                enc_image = _vlm_model.encode_image(image_rgb)
                response = _vlm_model.answer_question(enc_image, prompt, _vlm_tokenizer)
            elif _vlm_is_florence and _vlm_processor is not None:
                # Florence-2 does not follow free-form prompts: caption, then match tags.
                task_tok = "<MORE_DETAILED_CAPTION>"
                inputs = _vlm_encode_inputs(task_tok, image_input, input_format)
                gen_ids = _vlm_model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=512,
                    do_sample=False,
                    use_cache=False,
                )
                gen_text = _vlm_processor.batch_decode(gen_ids, skip_special_tokens=False)[0]
                parsed = _vlm_processor.post_process_generation(
                    gen_text, task=task_tok,
                    image_size=(image_rgb.width, image_rgb.height),
                )
                caption = parsed.get(task_tok, "")
                matched_fl = _vlm_match_caption_tags(caption, tag_list)[:limit]
                logger.info("VLM(Florence) tags: %s (caption: %s)", matched_fl, caption[:200])
                return matched_fl, caption
            elif _vlm_processor is not None:
                # Generic processor-based VLM (LLaVA, Qwen-VL)
                inputs = _vlm_encode_inputs(prompt, image_input, input_format)
                outputs = _vlm_model.generate(**inputs, max_new_tokens=200, do_sample=False)
                # Decode only the generated tokens (skip the echoed prompt).
                input_ids = inputs.get("input_ids")
                input_len = input_ids.shape[-1] if input_ids is not None else 0
                response = _vlm_processor.decode(outputs[0][input_len:], skip_special_tokens=True)
            else:
                return [], ""

        tags = _vlm_parse_response(response, tag_list)[:limit]
        logger.info("VLM tags: %s (raw response: %s)", tags, response[:200])
        return tags, ""

    except Exception as exc:
        logger.error("VLM inference failed: %s", exc, exc_info=True)
        return [], ""


def _vlm_lora_trainable_params(model) -> list:
    """Return *model*'s trainable parameters, unfreezing LoRA weights if none are trainable.

    ``PeftModel.from_pretrained`` attaches adapters frozen unless
    ``is_trainable=True`` is passed, and the startup loader calls it without
    that flag. A restored adapter would therefore never receive updates.
    LoRA weights are identified by the ``lora_`` substring PEFT uses in
    parameter names.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    if params:
        return params
    for name, param in model.named_parameters():
        if "lora_" in name:
            param.requires_grad = True
    return [p for p in model.parameters() if p.requires_grad]


def _vlm_training_loss(prompt: str, expected_answer: str, image_input, input_format: Optional[str]):
    """Run one forward pass scoring *expected_answer* for the loaded VLM family.

    Returns the model's loss tensor, or ``None`` when no supported training path
    exists for the loaded model/processor combination.
    """
    def _encode(text: str) -> dict:
        # Processor inputs moved to the model's device; pixels cast to its dtype.
        processor_kwargs = {"text": text, "images": image_input, "return_tensors": "pt"}
        if input_format:
            processor_kwargs["input_data_format"] = input_format
        inputs = _vlm_processor(**processor_kwargs)
        inputs = {k: v.to(_device) if hasattr(v, 'to') else v for k, v in inputs.items()}
        if "pixel_values" in inputs and hasattr(inputs["pixel_values"], "to"):
            model_dtype = next(_vlm_model.parameters()).dtype
            inputs["pixel_values"] = inputs["pixel_values"].to(_device, dtype=model_dtype)
        return inputs

    if _vlm_is_moondream and _vlm_tokenizer is not None:
        # NOTE(review): this path feeds text only. The image is never encoded,
        # so moondream is trained on prompt/answer text alone. Confirm intended.
        full_text = f"Question: {prompt}\nAnswer: {expected_answer}"
        tokens = _vlm_tokenizer(full_text, return_tensors="pt", truncation=True, max_length=512)
        tokens = {k: v.to(_device) for k, v in tokens.items()}
        tokens["labels"] = tokens["input_ids"].clone()
        return _vlm_model(**tokens).loss

    if _vlm_is_florence and _vlm_processor is not None:
        # Florence-2: teach the caption task to produce the expected tag string.
        inputs = _encode("<MORE_DETAILED_CAPTION>")
        inputs["labels"] = _vlm_processor.tokenizer(
            expected_answer, return_tensors="pt",
            padding=True, truncation=True, max_length=512,
        ).input_ids.to(_device)
        return _vlm_model(**inputs).loss

    if _vlm_processor is not None:
        # Generic VLM: causal-LM loss over prompt + answer.
        inputs = _encode(f"{prompt}\nAnswer: {expected_answer}")
        if inputs.get("input_ids") is None:
            return None
        inputs["labels"] = inputs["input_ids"].clone()
        return _vlm_model(**inputs).loss

    return None


def _vlm_train_on_tags(image: Image.Image, tag_list: List[str], all_tags: List[str]) -> float:
    """Fine-tune the VLM with LoRA on a single training example.

    The example pairs the standard tagging prompt (built from *all_tags*) with
    the comma-joined *tag_list* as the target answer. LoRA is attached on first
    use. An adapter restored from disk in frozen form is made trainable.

    Returns:
        The training loss (float). Returns 0.0 without updating anything when
        LoRA is disabled, the VLM is not loaded, *tag_list* is empty, peft is
        missing, there are no trainable parameters, the loss is unavailable or
        non-finite, or the step raises (logged). The model is always left in
        eval mode with gradients cleared.
    """
    global _vlm_model

    if not VLM_LORA_ENABLED or not _vlm_ready or _vlm_model is None:
        return 0.0
    if not tag_list:
        return 0.0

    _, image_input, input_format = _vlm_prepare_image(image)

    try:
        from peft import LoraConfig, get_peft_model, TaskType
    except ImportError:
        logger.warning("VLM LoRA training skipped: peft library not available")
        return 0.0

    # Apply LoRA if not already applied
    if not hasattr(_vlm_model, 'peft_config'):
        try:
            lora_config = LoraConfig(
                r=VLM_LORA_RANK,
                lora_alpha=VLM_LORA_RANK * 2,
                target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type=TaskType.CAUSAL_LM
            )
            _vlm_model = get_peft_model(_vlm_model, lora_config)
            logger.info("VLM LoRA applied: rank=%d, alpha=%d", VLM_LORA_RANK, VLM_LORA_RANK * 2)
        except Exception as exc:
            logger.error("Failed to apply LoRA to VLM: %s", exc)
            return 0.0

    trainable = _vlm_lora_trainable_params(_vlm_model)
    if not trainable:
        logger.warning("VLM LoRA training skipped: model has no trainable parameters")
        return 0.0

    prompt = _vlm_build_prompt(all_tags)
    expected_answer = ", ".join(tag_list)

    try:
        _vlm_model.train()
        loss = _vlm_training_loss(prompt, expected_answer, image_input, input_format)
        if loss is None or not math.isfinite(loss.item()):
            return 0.0

        loss.backward()

        # NOTE: the optimizer is rebuilt per call (stateless). AdamW's first step
        # from fresh state moves each parameter by roughly lr * sign(grad), so
        # VLM_LORA_LR acts as a per-step magnitude rather than an adaptive rate.
        from torch.optim import AdamW
        AdamW(trainable, lr=VLM_LORA_LR).step()

        return float(loss.detach().cpu().item())

    except Exception as exc:
        logger.error("VLM LoRA training step failed: %s", exc, exc_info=True)
        return 0.0
    finally:
        # Drop gradients on every exit so a failed step can't leak into the next
        # call, and always return the model to inference mode.
        for param in trainable:
            param.grad = None
        _vlm_model.eval()


def _vlm_save_lora() -> None:
    """Write the LoRA adapter of the current VLM to VLM_LORA_PATH.

    Does nothing when no VLM is loaded or it has no PEFT adapter (no
    ``peft_config`` attribute). Save errors are logged, not raised.
    """
    model = _vlm_model
    if model is None or not hasattr(model, 'peft_config'):
        return
    try:
        model.save_pretrained(VLM_LORA_PATH)
    except Exception as exc:
        logger.error("Failed to save VLM LoRA adapter: %s", exc)
    else:
        logger.info("VLM LoRA adapter saved to %s", VLM_LORA_PATH)


# Load the VLM once, at import time (module-level side effect). Whether anything
# is actually loaded is decided inside _load_vlm (presumably via VLM_ENABLED;
# not visible in this chunk). Load failures are caught there and leave
# _vlm_ready False.
_load_vlm()


def _sanitize_tag(tag: str) -> str:
    """Normalise one tag to stripped lowercase, or return "" if it should be discarded.

    Non-strings, blank strings and tags flagged by ``_is_noise_tag`` all map to "".
    """
    if isinstance(tag, str):
        candidate = tag.strip().lower()
        if candidate and not _is_noise_tag(candidate):
            return candidate
    return ""


def _is_noise_tag(tag: str) -> bool:
    if len(tag) > 30:
        return True
    if "unsplash" in tag:
        return True
    if re.fullmatch(r"[a-f0-9]{8,}", tag):
        return True
    digit_count = sum(ch.isdigit() for ch in tag)
    letter_count = sum(ch.isalpha() for ch in tag)
    if digit_count >= 3 and letter_count >= 3 and len(tag) >= 10:
        return True
    return False


def _load_tag_db() -> List[str]:
    """Load the persisted tag vocabulary from TAG_DB_PATH.

    Returns:
        The stored tags passed through ``_sanitize_tag``, with discarded ones
        removed and file order kept. Returns ``[]`` when the file is missing,
        unreadable, not valid JSON, or not a JSON list. Read/format problems are
        logged as warnings rather than silently ignored, because callers may
        later rewrite the DB from this (possibly empty) result.
    """
    if not os.path.exists(TAG_DB_PATH):
        return []
    try:
        with open(TAG_DB_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as exc:
        logger.warning("Failed to read tag DB %s: %s", TAG_DB_PATH, exc)
        return []
    if not isinstance(data, list):
        logger.warning(
            "Ignoring tag DB %s: expected a JSON list, got %s",
            TAG_DB_PATH, type(data).__name__,
        )
        return []
    # _sanitize_tag maps non-strings and blank strings to "", which are dropped here.
    return [s for s in map(_sanitize_tag, data) if s]


def _save_tag_db(tags: List[str]) -> None:
    """Persist *tags* to TAG_DB_PATH as a sorted, de-duplicated JSON list.

    The JSON is written to a temporary file in the same directory and then
    moved over the target with ``os.replace``. A failure (unsortable or
    unserialisable tags, disk errors) therefore leaves the previous DB intact
    instead of truncating it. Failures are logged, not raised.
    """
    import tempfile

    tmp_path = None
    try:
        payload = sorted(set(tags))
        fd, tmp_path = tempfile.mkstemp(
            prefix="." + os.path.basename(TAG_DB_PATH) + ".",
            suffix=".tmp",
            dir=os.path.dirname(TAG_DB_PATH) or ".",
        )
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, TAG_DB_PATH)
        tmp_path = None  # now owned by TAG_DB_PATH; nothing to clean up
    except Exception as exc:
        logger.warning("Failed to save tag DB to %s: %s", TAG_DB_PATH, exc)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _load_classifier() -> None:
    """Restore the linear tag classifier from TAG_CLASSIFIER_PATH, if present.

    The checkpoint is a dict with ``tags`` (label order), ``state`` (Linear
    state_dict) and ``feature_dim``.
      * If the file is missing: the globals are left unchanged.
      * If the checkpoint is incomplete: the globals are left unchanged and a
        warning is logged.
      * If the file cannot be loaded: the classifier globals are reset to
        "no classifier" and a warning is logged.
      * On success: ``_classifier``, ``_classifier_tags`` and
        ``_classifier_dim`` are replaced together, and the optimizer is reset
        because it would reference the old parameters.
    """
    global _classifier, _classifier_tags, _classifier_dim, _classifier_optimizer
    if not os.path.exists(TAG_CLASSIFIER_PATH):
        return
    try:
        data = torch.load(TAG_CLASSIFIER_PATH, map_location=_device)
        tags = data.get("tags")
        state = data.get("state")
        dim = data.get("feature_dim")
        if not tags or not state or not dim:
            logger.warning("Ignoring incomplete tag classifier checkpoint: %s", TAG_CLASSIFIER_PATH)
            return
        feature_dim = int(dim)
        tag_names = list(tags)
        classifier = torch.nn.Linear(feature_dim, len(tag_names))
        classifier.load_state_dict(state)
        classifier.to(_device)
        classifier.eval()
    except Exception as exc:
        logger.warning("Failed to load tag classifier from %s: %s", TAG_CLASSIFIER_PATH, exc)
        _classifier = None
        _classifier_tags = []
        _classifier_dim = None
        return

    # Publish only a fully constructed classifier.
    _classifier = classifier
    _classifier_tags = tag_names
    _classifier_dim = feature_dim
    _classifier_optimizer = None  # Reset — new parameters


def _save_classifier() -> None:
    """Persist the tag classifier (tags, weights, feature dim) to TAG_CLASSIFIER_PATH.

    This is a no-op when no classifier has been built. The checkpoint is first
    written to a temporary file in the same directory and then moved into place
    with ``os.replace``, so an interrupted save never leaves a truncated
    checkpoint. Failures are logged, not raised.
    """
    if _classifier is None or not _classifier_tags or _classifier_dim is None:
        return
    import tempfile

    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(
            prefix="." + os.path.basename(TAG_CLASSIFIER_PATH) + ".",
            suffix=".tmp",
            dir=os.path.dirname(TAG_CLASSIFIER_PATH) or ".",
        )
        os.close(fd)
        torch.save(
            {
                "tags": _classifier_tags,
                "state": _classifier.state_dict(),
                "feature_dim": _classifier_dim,
            },
            tmp_path
        )
        os.replace(tmp_path, TAG_CLASSIFIER_PATH)
        tmp_path = None
    except Exception as exc:
        logger.warning("Failed to save tag classifier to %s: %s", TAG_CLASSIFIER_PATH, exc)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _ensure_classifier(all_tags: List[str], feature_dim: int) -> None:
    """Make sure the linear tag classifier covers *all_tags* at *feature_dim*.

    * No classifier yet, or a different feature dimension: build a fresh
      ``Linear(feature_dim, len(all_tags))`` with ``all_tags`` as its label order.
    * Some tag in *all_tags* has no output row: rebuild with the label order
      ``all_tags`` followed by any previously known tags not in *all_tags*.
      Learned weight/bias rows are copied **by tag name**. Callers pass a
      sorted tag list, so new tags can land in the middle of the order and
      positional copying would attach old weights to the wrong tags.
    * Otherwise the classifier is left untouched (it never shrinks).

    Whenever the classifier is replaced, it is put in train mode and the
    optimizer is reset, since the old optimizer references the old parameters.
    """
    global _classifier, _classifier_tags, _classifier_dim, _classifier_optimizer
    if _classifier is None or _classifier_dim != feature_dim:
        _classifier_dim = feature_dim
        _classifier_tags = list(all_tags)
        _classifier = torch.nn.Linear(feature_dim, len(_classifier_tags)).to(_device)
        _classifier.train()
        _classifier_optimizer = None  # Reset — new parameters
        return

    old_tags = _classifier_tags
    old_index = {t: i for i, t in enumerate(old_tags)}
    if all(t in old_index for t in all_tags):
        return

    incoming = list(dict.fromkeys(all_tags))
    incoming_set = set(incoming)
    new_tags = incoming + [t for t in old_tags if t not in incoming_set]

    new_classifier = torch.nn.Linear(feature_dim, len(new_tags)).to(_device)
    pairs = [(new_i, old_index[t]) for new_i, t in enumerate(new_tags) if t in old_index]
    if pairs:
        device = new_classifier.weight.device
        new_rows = torch.tensor([n for n, _ in pairs], dtype=torch.long, device=device)
        old_rows = torch.tensor([o for _, o in pairs], dtype=torch.long, device=device)
        with torch.no_grad():
            new_classifier.weight[new_rows] = _classifier.weight[old_rows]
            new_classifier.bias[new_rows] = _classifier.bias[old_rows]

    _classifier = new_classifier
    _classifier_tags = new_tags
    _classifier.train()
    _classifier_optimizer = None  # Reset — new parameters


def _load_finetune_weights() -> None:
    """Load fine-tuned CLIP weights from TAG_FINETUNE_PATH into ``_model``, if present.

    Only applies to the "open_clip" / "clip" backends. Loading is non-strict,
    so a checkpoint that only partly matches the model still loads. In that
    case the number of missing/unexpected keys is logged as a warning instead
    of claiming a clean load. Load errors are logged, not raised.
    """
    if not os.path.exists(TAG_FINETUNE_PATH):
        if _is_scratch_pretrained(OPENCLIP_PRETRAINED):
            logger.info("No finetuned weights found - starting from scratch")
        return
    if _backend not in {"open_clip", "clip"}:
        return
    try:
        state = torch.load(TAG_FINETUNE_PATH, map_location=_device)
        result = _model.load_state_dict(state, strict=False)
    except Exception as exc:
        logger.warning("Failed to load finetuned weights: %s", exc)
        return

    missing = list(getattr(result, "missing_keys", None) or [])
    unexpected = list(getattr(result, "unexpected_keys", None) or [])
    if missing or unexpected:
        logger.warning(
            "Finetuned CLIP weights from %s only partially matched the model "
            "(%d missing, %d unexpected keys; e.g. missing=%s unexpected=%s)",
            TAG_FINETUNE_PATH, len(missing), len(unexpected), missing[:3], unexpected[:3],
        )
    else:
        logger.info("✓ Loaded finetuned CLIP weights from: %s", TAG_FINETUNE_PATH)


def _save_finetune_weights() -> None:
    """Persist ``_model``'s state_dict to TAG_FINETUNE_PATH (CLIP backends only).

    The weights go to a temporary file in the same directory and are moved into
    place with ``os.replace``. An interrupted save of this (large) checkpoint
    therefore never replaces good weights with a truncated file. Failures are
    logged, not raised.
    """
    if _backend not in {"open_clip", "clip"}:
        return
    import tempfile

    tmp_path = None
    try:
        fd, tmp_path = tempfile.mkstemp(
            prefix="." + os.path.basename(TAG_FINETUNE_PATH) + ".",
            suffix=".tmp",
            dir=os.path.dirname(TAG_FINETUNE_PATH) or ".",
        )
        os.close(fd)
        torch.save(_model.state_dict(), tmp_path)
        os.replace(tmp_path, TAG_FINETUNE_PATH)
        tmp_path = None
    except Exception as exc:
        logger.warning("Failed to save finetuned weights: %s", exc)
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def _clip_similarity_loss(image_input: torch.Tensor, text_input: torch.Tensor) -> float:
    """Return ``-mean(cos_sim(image, prompts))`` for one image, without gradients.

    This is the fine-tuning objective of ``_finetune_on_tags``, evaluated
    outside the optimisation loop (autocast enabled on CUDA, as in training).
    """
    with torch.no_grad():
        with torch.amp.autocast("cuda", enabled=(_device == "cuda")):
            img_f = _model.encode_image(image_input)
            txt_f = _model.encode_text(text_input)
            img_f = img_f / img_f.norm(dim=-1, keepdim=True)
            txt_f = txt_f / txt_f.norm(dim=-1, keepdim=True)
            sims = (img_f @ txt_f.T).squeeze(0)
    return float((-sims.mean()).cpu().item())


def _finetune_on_tags(image: Image.Image, tag_list: List[str]) -> tuple:
    """Fine-tune the CLIP model on one image and its tags.

    Each non-blank tag becomes an "a photo of <tag>" prompt. The model is then
    optimised for ``max(1, FINETUNE_STEPS)`` steps to maximise the mean cosine
    similarity between the image embedding and the prompt embeddings. When
    FREEZE_VISUAL is set, parameters whose names start with "visual." are frozen
    during the update. Regardless of the outcome, the model is returned to eval
    mode with every parameter's ``requires_grad`` set back to True, and
    accumulated gradients are released.

    Returns:
        ``(pre_loss, post_loss)``: the negative mean cosine similarity measured
        without gradients before and after the update. Comparing the two shows
        whether the update actually moved the model. Both are 0.0 when
        fine-tuning is disabled, the backend is not CLIP-based, or there are no
        usable tags. Both equal ``pre_loss`` if a non-finite training loss
        aborts the update.
    """
    global _finetune_optimizer, _finetune_scaler

    if not FINETUNE_ENABLED:
        return (0.0, 0.0)
    if _backend not in {"open_clip", "clip"}:
        return (0.0, 0.0)
    if not tag_list:
        return (0.0, 0.0)

    prompts = [f"a photo of {t}" for t in tag_list if isinstance(t, str) and t.strip()]
    if not prompts:
        return (0.0, 0.0)

    image_input = _preprocess(image).unsqueeze(0).to(_device)
    if _backend == "open_clip":
        text_input = _tokenizer(prompts)
        if not isinstance(text_input, torch.Tensor):
            text_input = torch.tensor(text_input)
        text_input = text_input.to(_device)
    else:
        text_input = clip.tokenize(prompts).to(_device)

    # ── Measure PRE-training similarity (no gradients) ──────────────
    pre_loss = _clip_similarity_loss(image_input, text_input)

    use_amp = _device == "cuda"
    try:
        # ── Optionally freeze image encoder to prevent collapse with small datasets ──
        # When FREEZE_VISUAL=1 (default), only text encoder + projections are trainable.
        # Set FREEZE_VISUAL=0 for large-scale training where the encoder needs to learn.
        if FREEZE_VISUAL:
            for name, param in _model.named_parameters():
                param.requires_grad = not name.startswith("visual.")
        else:
            for param in _model.parameters():
                param.requires_grad = True

        # ── Optimisation loop ───────────────────────────────────────────
        _model.train()
        trainable_params = [p for p in _model.parameters() if p.requires_grad]
        if _finetune_optimizer is None or len(trainable_params) != len(_finetune_optimizer.param_groups[0]["params"]):
            # (Re)build the optimizer when the trainable set changed.
            _finetune_optimizer = torch.optim.AdamW(trainable_params, lr=FINETUNE_LR)
        if _finetune_scaler is None:
            _finetune_scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

        for step in range(max(1, FINETUNE_STEPS)):
            _finetune_optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast("cuda", enabled=use_amp):
                image_features = _model.encode_image(image_input)
                text_features = _model.encode_text(text_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                sims = (image_features @ text_features.T).squeeze(0)
                loss = -sims.mean()

            # Guard: if loss is NaN/Inf the gradients are garbage — abort early
            loss_val = float(loss.detach().cpu().item())
            if not math.isfinite(loss_val):
                logger.warning("CLIP finetune NaN/Inf detected at step %d — skipping remaining steps", step)
                # Return pre_loss for both values to signal no useful training happened
                return (pre_loss, pre_loss)

            _finetune_scaler.scale(loss).backward()
            _finetune_scaler.step(_finetune_optimizer)
            _finetune_scaler.update()
    finally:
        # Runs on success, early return and exceptions alike, so a failed update
        # can't leave the model in train mode with frozen parameters.
        if _finetune_optimizer is not None:
            # Free gradient memory; the optimizer's moment state is unaffected.
            _finetune_optimizer.zero_grad(set_to_none=True)
        _model.eval()
        for param in _model.parameters():
            param.requires_grad = True

    # ── Measure POST-training similarity, after the final update ─────
    post_loss = _clip_similarity_loss(image_input, text_input)
    return (pre_loss, post_loss)


# Restore persisted training state at import time (module-level side effects):
# the tag classifier and the fine-tuned CLIP weights (both loaders return early
# when their checkpoint file is absent), then the tag-embedding network.
_load_classifier()
_load_finetune_weights()
_load_tag_embed_net()

# For scratch mode: save initial random weights if no file exists yet
# This ensures consistent weights across restarts even without fine-tuning
if _is_scratch_pretrained(OPENCLIP_PRETRAINED) and not os.path.exists(TAG_FINETUNE_PATH):
    logger.info("Saving initial scratch weights for persistence...")
    _save_finetune_weights()


def _get_candidate_tags(existing_tags: List[str]) -> List[str]:
    """Return the tag vocabulary to choose from, de-duplicated in priority order.

    The order is: the caller's *existing_tags* (sanitized), then the persisted
    tag DB, then DEFAULT_TAGS. Only the first occurrence of each tag is kept,
    and empty tags are dropped.
    """
    stored = _load_tag_db()
    ordered: dict = {}
    for tag in (*_sanitize_tags(existing_tags), *stored, *DEFAULT_TAGS):
        ordered.setdefault(tag, None)
    return [tag for tag in ordered if tag]


def _sanitize_tags(tags) -> List[str]:
    """Sanitize a list of tags. Also handles [{filename, tags}] dicts gracefully."""
    if not isinstance(tags, list):
        return []
    result = []
    for t in tags:
        if isinstance(t, dict):
            inner = t.get("tags", [])
            if isinstance(inner, list):
                result.extend(_sanitize_tag(s) for s in inner if isinstance(s, str) and s.strip())
        elif isinstance(t, str) and t.strip():
            result.append(_sanitize_tag(t))
    return [r for r in result if r]


# Words dropped by _extract_caption_tags: function words plus generic
# subject/framing nouns ("man", "photo", "image", ...) that would otherwise
# become uninformative tags. Entries shorter than three letters are harmless but
# never match, because that function's regex only yields tokens of 3+ letters.
_STOPWORDS = {
    "a", "an", "the", "and", "or", "but", "if", "then", "than", "of", "to", "in", "on", "at",
    "for", "with", "without", "over", "under", "near", "by", "from", "is", "are", "was", "were",
    "be", "been", "being", "this", "that", "these", "those", "it", "its", "as", "into", "through",
    "around", "up", "down", "left", "right", "front", "back", "new", "old", "young", "man", "woman",
    "person", "people", "group", "photo", "picture", "image", "video", "clip", "scene", "view"
}


def _extract_caption_tags(caption: str) -> List[str]:
    """Turn a free-text caption into candidate tags.

    The caption is lowercased and every run of 3+ ASCII letters becomes a token.
    Tokens found in ``_STOPWORDS`` are dropped, and the rest are returned once
    each, in first-occurrence order. An empty or falsy caption gives ``[]``.
    """
    if not caption:
        return []
    words = re.findall(r"[a-zA-Z]{3,}", caption.lower())
    return list(dict.fromkeys(w for w in words if w not in _STOPWORDS))


def _caption_image(image: Image.Image) -> str:
    """Caption *image* with the optional captioning model (up to 24 new tokens).

    Returns "" when no captioning model/processor is loaded or when captioning
    fails. Failures are logged as warnings instead of being silently swallowed.
    ``pixel_values`` are cast to the caption model's parameter dtype (e.g. fp16
    on GPU) rather than unconditionally to half precision. An fp32 model on CUDA
    would otherwise fail with a dtype mismatch.
    """
    if _caption_model is None or _caption_processor is None:
        return ""
    try:
        inputs = _caption_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(_device) if hasattr(v, "to") else v for k, v in inputs.items()}
        pixel_values = inputs.get("pixel_values")
        if pixel_values is not None and hasattr(pixel_values, "to"):
            model_dtype = next(_caption_model.parameters()).dtype
            inputs["pixel_values"] = pixel_values.to(dtype=model_dtype)
        output = _caption_model.generate(**inputs, max_new_tokens=24)
        return _caption_processor.decode(output[0], skip_special_tokens=True)
    except Exception as exc:
        logger.warning("Image captioning failed: %s", exc)
        return ""


@app.get("/health")
def health():
    """Liveness probe: always responds ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.get("/status")
def status():
    """Report server status: training progress plus tagging, embedding-net and VLM settings."""
    training = _training_status.to_dict()

    tag_embed_info = {
        "enabled": _tag_embed_net is not None,
        "num_tags": len(_tag_embed_tags) if _tag_embed_tags else 0,
        "feature_dim": _tag_embed_dim,
    }

    vlm_info = {
        "enabled": VLM_ENABLED,
        "ready": _vlm_ready,
        "model_id": VLM_MODEL_ID if VLM_ENABLED else None,
        "lora_enabled": VLM_LORA_ENABLED,
        # "loaded" here means an adapter directory exists on disk.
        "lora_loaded": os.path.isdir(VLM_LORA_PATH),
        "max_tags": VLM_MAX_TAGS,
    }

    return {
        "status": "ok",
        "training": training,
        "batch_size": BATCH_SIZE,
        "finetune_enabled": FINETUNE_ENABLED,
        "finetune_steps": FINETUNE_STEPS,
        "tagging_mode": TAGGING_MODE,
        "tag_embed_net": tag_embed_info,
        "vlm": vlm_info,
    }


@app.post("/ai-train-start")
async def ai_train_start(total: int = Form(0), session_id: str = Form(None)):
    """Signal training session start (for multi-tab sync).

    Form fields ``total`` (expected sample count, default 0) and optional
    ``session_id`` are forwarded to ``_training_status.start``.
    """
    _training_status.start(total, session_id)
    return {"status": "ok"}


@app.post("/ai-train-end")
async def ai_train_end():
    """Signal training session end by marking ``_training_status`` as finished."""
    _training_status.finish()
    return {"status": "ok"}


@app.post("/ai-train-cancel")
async def ai_train_cancel():
    """Cancel background training and purge the on-disk training queue.

    The training status is flagged as cancelled, and the background task (if
    one is still running) gets a cancellation request. Then every file in the
    queue directory is deleted right away.
    """
    _training_status.cancel()
    task = _background_training_task
    if task and not task.done():
        task.cancel()
    # Clean up remaining queue files
    _cleanup_training_queue()
    return {"status": "ok", "message": "Cancellation requested"}


def _cleanup_training_queue():
    """Delete every entry matched by ``*`` in TRAINING_QUEUE_DIR, best effort.

    Each removal failure is logged as a warning and skipped. A failure of the
    scan itself is also logged, never raised. Note that glob's ``*`` does not
    match dot-files, and directories fail ``os.remove`` (logged) rather than
    being deleted.
    """
    import glob

    try:
        pattern = os.path.join(TRAINING_QUEUE_DIR, "*")
        entries = glob.glob(pattern)
        for entry in entries:
            try:
                os.remove(entry)
            except Exception as e:
                logger.warning("Failed to remove queue file %s: %s", entry, e)
        logger.info("Training queue cleaned")
    except Exception as e:
        logger.warning("Failed to clean training queue: %s", e)


async def _process_training_queue():
    """Background task to process queued training samples in batches"""
    global _classifier_optimizer
    import asyncio
    import glob
    
    try:
        queue_files = sorted(glob.glob(os.path.join(TRAINING_QUEUE_DIR, "*.json")))
        total = len(queue_files)
        if total == 0:
            _training_status.finish()
            return
        
        _training_status.start(total, mode="background")
        trained = 0
        batch_size = max(1, min(BATCH_SIZE, 16))  # Clamp to 1-16
        
        # Load tag DB once — accumulate in memory, save once per batch
        accumulated_tags = set(_load_tag_db())
        
        # Process in batches
        for batch_start in range(0, len(queue_files), batch_size):
            # Check for cancellation
            if _training_status.cancelled:
                logger.info("Background training cancelled at %d/%d", trained, total)
                break
            
            batch_files = queue_files[batch_start:batch_start + batch_size]
            batch_images = []
            batch_tags = []
            batch_paths = []
            batch_meta_paths = []
            
            # Load batch
            for queue_file in batch_files:
                try:
                    with open(queue_file, 'r') as f:
                        meta = json.load(f)
                    
                    image_path = meta.get("image_path")
                    tags = meta.get("tags", [])
                    
                    if not image_path or not os.path.exists(image_path):
                        logger.warning("Queue item missing image: %s", queue_file)
                        os.remove(queue_file)
                        continue
                    
                    image = Image.open(image_path).convert("RGB")
                    batch_images.append(image)
                    batch_tags.append(tags)
                    batch_paths.append(image_path)
                    batch_meta_paths.append(queue_file)
                except Exception as exc:
                    logger.warning("Failed to load queue item %s: %s", queue_file, exc)
                    if os.path.exists(queue_file):
                        os.remove(queue_file)
            
            if not batch_images:
                continue
            
            # Encode batch together for efficiency
            try:
                with torch.no_grad():
                    # Stack preprocessed images into a batch tensor
                    batch_tensors = torch.stack([_preprocess(img) for img in batch_images]).to(_device)
                    
                    if _backend == "open_clip":
                        batch_features = _model.encode_image(batch_tensors)
                    elif _backend == "clip":
                        batch_features = _model.encode_image(batch_tensors)
                    else:
                        logits = _model(batch_tensors)
                        batch_features = torch.softmax(logits, dim=-1)
                    
                    batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
            except Exception as exc:
                logger.exception("Batch encoding failed: %s", exc)
                # Fallback: remove batch items and continue
                for qf in batch_meta_paths:
                    if os.path.exists(qf):
                        os.remove(qf)
                for ip in batch_paths:
                    if os.path.exists(ip):
                        os.remove(ip)
                continue
            
            # Process each item in batch
            loss_value = None
            for i, (image, tags) in enumerate(zip(batch_images, batch_tags)):
                embed_loss = 0.0
                clip_loss = 0.0
                try:
                    image_features = batch_features[i:i+1]  # Keep batch dimension
                    feature_dim = image_features.shape[-1]
                    
                    # Accumulate tags in memory (saved once per batch)
                    accumulated_tags.update(tags)
                    all_tags = sorted(accumulated_tags)
                    
                    # ============================================
                    # REAL NEURAL NETWORK TRAINING: Tag Embedding Network
                    # Uses contrastive learning with backpropagation
                    # ============================================
                    try:
                        _ensure_tag_embed_net(all_tags, feature_dim)
                        embed_loss = _train_tag_embed_net(image_features, tags)
                        # Note: Save happens once per batch below, not per image
                    except Exception as exc:
                        logger.warning("Tag embedding training skipped for item %d: %s", i, exc)
                    
                    # Fine-tune classifier (real training)
                    try:
                        _ensure_classifier(all_tags, feature_dim)
                        if _classifier is not None:
                            _classifier.train()
                            target = torch.zeros(len(_classifier_tags), device=_device)
                            tag_index = {t: i for i, t in enumerate(_classifier_tags)}
                            for t in tags:
                                idx = tag_index.get(t)
                                if idx is not None:
                                    target[idx] = 1.0
                            if _classifier_optimizer is None:
                                _classifier_optimizer = torch.optim.AdamW(_classifier.parameters(), lr=1e-3)
                            loss_fn = torch.nn.BCEWithLogitsLoss()
                            feats = image_features.detach()
                            for _ in range(5):
                                _classifier_optimizer.zero_grad(set_to_none=True)
                                logits = _classifier(feats).squeeze(0)
                                loss = loss_fn(logits, target)
                                loss.backward()
                                _classifier_optimizer.step()
                            try:
                                loss_value = float(loss.detach().cpu().item())
                            except Exception:
                                loss_value = None
                            # Note: Save happens once per batch below, not per image
                    except Exception as exc:
                        logger.warning("Classifier fine-tune skipped for item %d: %s", i, exc)
                    
                    # End-to-end fine-tune (per image)
                    clip_pre = 0.0
                    clip_post = 0.0
                    try:
                        clip_pre, clip_post = _finetune_on_tags(image, tags)
                    except Exception as exc:
                        logger.warning("Finetune skipped for item %d: %s", i, exc)

                    # VLM LoRA training (per image, if enabled)
                    vlm_loss = 0.0
                    if VLM_LORA_ENABLED and _vlm_ready:
                        try:
                            vlm_loss = _vlm_train_on_tags(image, tags, all_tags)
                        except Exception as exc:
                            logger.warning("VLM LoRA training skipped for item %d: %s", i, exc)
                    
                    # Log all losses to console (CLIP shows pre→post to track generalisation)
                    logger.info(
                        "[Training %d/%d] Tags: %s | Losses: Embed=%.4f, Classifier=%.4f, CLIP=%.4f→%.4f, VLM=%.4f",
                        trained + 1, total, tags[:3],
                        embed_loss,
                        loss_value if loss_value else 0.0,
                        clip_pre, clip_post,
                        vlm_loss
                    )
                    
                    trained += 1
                    _training_status.update(
                        trained, 
                        loss_value,
                        loss_embed=embed_loss,
                        loss_classifier=loss_value,
                        loss_clip=clip_post,
                        loss_clip_pre=clip_pre,
                        loss_vlm=vlm_loss
                    )
                    
                except Exception as exc:
                    logger.exception("Training failed for batch item %d: %s", i, exc)
            
            # Cleanup batch files
            for qf, ip in zip(batch_meta_paths, batch_paths):
                try:
                    if os.path.exists(qf):
                        os.remove(qf)
                    if os.path.exists(ip):
                        os.remove(ip)
                except Exception:
                    pass
            
            # Save models and tags once per batch (not per image)
            _save_tag_db(list(accumulated_tags))
            _save_tag_embed_net()
            _save_classifier()
            _save_finetune_weights()
            if VLM_LORA_ENABLED and _vlm_ready:
                _vlm_save_lora()
            
            # Small delay between batches to prevent overload
            await asyncio.sleep(0.05)
        
        if _training_status.cancelled:
            _training_status.finish()
            _cleanup_training_queue()  # Clean remaining files
            logger.info("Background training cancelled: %d/%d completed", trained, total)
        else:
            _training_status.finish()
            _cleanup_training_queue()  # Ensure queue is empty
            logger.info("Background training complete: %d/%d", trained, total)
            
    except Exception as exc:
        logger.exception("Background training error: %s", exc)
        _training_status.finish(error=str(exc))


@app.post("/ai-train-batch")
async def ai_train_batch(
    files: List[UploadFile] = File(...),
    tags_json: str = Form("[]")
):
    """
    Queue multiple files for background training.

    tags_json is a JSON array where each element corresponds to a file's tags.
    Example: [["tag1", "tag2"], ["tag3"], ...]
    Elements may also be objects of the form {"filename": ..., "tags": [...]}.

    Each accepted file is re-encoded as JPEG (quality 90) into TRAINING_QUEUE_DIR
    next to a JSON metadata file {"image_path": ..., "tags": [...]}; then a
    background task running _process_training_queue() is started.

    Returns:
        {"status": "ok", "queued": <int>, "mode": "background"}

    Raises:
        HTTPException(409): training is active and has not been cancelled.
        HTTPException(400): tags_json is not a JSON array, its length differs
            from the number of files, or no file could be queued.
    """
    global _background_training_task
    import asyncio
    import uuid

    # Check if already training
    if _training_status.active and not _training_status.cancelled:
        raise HTTPException(status_code=409, detail="Training already in progress")

    try:
        all_tags = json.loads(tags_json) if tags_json else []
    except ValueError:
        raise HTTPException(status_code=400, detail="tags_json is not valid JSON")
    if not isinstance(all_tags, list):
        # A JSON string/object could otherwise pass the length check by accident
        # (e.g. "ab" for two files) and be indexed element-by-element.
        raise HTTPException(status_code=400, detail="tags_json must be a JSON array")

    if len(files) != len(all_tags):
        raise HTTPException(
            status_code=400,
            detail=f"Mismatch: {len(files)} files but {len(all_tags)} tag arrays"
        )

    queued = 0
    for i, (file, raw_tags) in enumerate(zip(files, all_tags)):
        image_path = None
        meta_path = None
        try:
            raw = await file.read()
            # Validate it's an image; convert() loads the pixels, so the source
            # handle can be closed right away.
            with Image.open(io.BytesIO(raw)) as src:
                image = src.convert("RGB")

            # Support both formats: [{filename, tags}, ...] and [[tag1, tag2], ...]
            if isinstance(raw_tags, dict):
                raw_tags = raw_tags.get("tags", [])
            file_tags = _sanitize_tags(raw_tags) if isinstance(raw_tags, list) else []

            # Save to queue
            item_id = f"{int(time.time() * 1000)}_{i}_{uuid.uuid4().hex[:8]}"
            image_path = os.path.join(TRAINING_QUEUE_DIR, f"{item_id}.jpg")
            meta_path = os.path.join(TRAINING_QUEUE_DIR, f"{item_id}.json")

            image.save(image_path, "JPEG", quality=90)
            # Metadata is written last; queue size is counted from *.json files
            # (see /ai-train-queue), so the item is listed only once its image exists.
            with open(meta_path, "w", encoding="utf-8") as f:
                json.dump({"image_path": image_path, "tags": file_tags}, f)

            queued += 1
        except Exception as exc:
            logger.warning("Failed to queue file %s: %s", file.filename, exc)
            # Don't leave a half-written item (e.g. an image without metadata)
            # behind in the queue directory.
            for path in (meta_path, image_path):
                if path and os.path.exists(path):
                    try:
                        os.remove(path)
                    except OSError:
                        pass

    if queued == 0:
        raise HTTPException(status_code=400, detail="No valid files to queue")

    # Start background training; the module-level reference keeps the task
    # from being garbage-collected while it runs.
    _background_training_task = asyncio.create_task(_process_training_queue())

    logger.info("Queued %d files for background training", queued)
    return {"status": "ok", "queued": queued, "mode": "background"}


@app.get("/ai-train-queue")
def ai_train_queue():
    """Report the pending training-queue size and the training status.

    Every queued item has a ``*.json`` metadata file in ``TRAINING_QUEUE_DIR``;
    the number of such files is returned as ``queued``, alongside a snapshot
    from ``_training_status.to_dict()``.
    """
    import glob

    meta_pattern = os.path.join(TRAINING_QUEUE_DIR, "*.json")
    pending_count = sum(1 for _ in glob.iglob(meta_pattern))
    status_snapshot = _training_status.to_dict()
    return {
        "status": "ok",
        "queued": pending_count,
        "training": status_snapshot,
    }


@app.post("/ai-tags")
async def ai_tags(
    file: UploadFile = File(...),
    filename: str = Form(""),
    existingTags: str = Form("[]"),
    maxTags: int = Form(8)
):
    """
    Suggest tags for an uploaded image.

    Pipeline:
      1. Merge ``existingTags`` (JSON array) into the tag DB (vocabulary growth).
      2. Decode the upload as RGB.
      3. Build candidates: caption-derived tags first, then tag-DB candidates.
      4. In "vlm" mode with the VLM ready, return the VLM's tags directly if it
         produced any.
      5. Otherwise score candidates with the active backend (open_clip / clip
         text similarity, or label matching for the fallback model), blend in
         the tag embedding network (50/50) and the classifier head (60/40) when
         their feature dims match, and keep top tags scoring >= MIN_CONFIDENCE
         (or the top 3 if none do).
      6. In "vlm" / "hybrid" mode, combine with the VLM's tags.

    Args:
        file: Uploaded image.
        filename: Original filename (only logged here).
        existingTags: JSON array of tags already attached to the item.
        maxTags: Maximum number of tags to return; values below 1 act as 1.

    Returns:
        {"tags": [...], "caption": str}

    Raises:
        HTTPException(400): the upload is not a decodable image.
        HTTPException(500): model inference failed.
    """
    global _tag_embed_net, _tag_embed_tags, _tag_embed_dim
    global _classifier, _classifier_tags, _classifier_dim

    logger.info("AI tag request: filename=%s content_type=%s mode=%s", filename, file.content_type, TAGGING_MODE)

    # Clamp once so every path (VLM, CLIP top-k, hybrid merge) applies the same
    # limit; slicing with a raw value <= 0 would return [] or drop tags.
    max_tags = max(int(maxTags), 1)

    # Merge existing tags into the tag DB (simple "self-training" vocabulary growth)
    try:
        existing_list = json.loads(existingTags) if existingTags else []
    except Exception:
        existing_list = []

    existing_list = _sanitize_tags(existing_list)
    if existing_list:
        _save_tag_db(_load_tag_db() + existing_list)

    # Only handle images here (videos require server-side frame extraction)
    try:
        raw = await file.read()
        image = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        logger.exception("Failed to read image")
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}")

    caption = _caption_image(image)
    caption_tags = _extract_caption_tags(caption)
    candidates = _get_candidate_tags(existing_list)
    if caption_tags:
        # Caption tags first, then DB candidates; order-preserving de-duplication
        candidates = list(dict.fromkeys([*caption_tags, *candidates]))
        logger.info("BLIP2 caption: %s", caption)
    if not candidates:
        return {"tags": [], "caption": caption}

    vlm_tags = []
    vlm_caption = ""
    vlm_attempted = False

    # VLM-only mode: skip CLIP inference entirely for performance
    if TAGGING_MODE == "vlm" and _vlm_ready:
        vlm_attempted = True
        vlm_tags, vlm_caption = _vlm_tag_image(image, candidates, max_tags)
        if vlm_tags:
            logger.info("VLM-only tags (fast path): %s", vlm_tags)
            return {"tags": vlm_tags, "caption": vlm_caption or caption}
        logger.warning("VLM returned no tags, falling back to CLIP")

    try:
        with torch.no_grad():
            image_input = _preprocess(image).unsqueeze(0).to(_device)
            if _backend == "open_clip":
                text_input = _tokenizer(candidates)
                if not isinstance(text_input, torch.Tensor):
                    text_input = torch.tensor(text_input)
                text_input = text_input.to(_device)
                image_features = _model.encode_image(image_input)
                text_features = _model.encode_text(text_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                text_similarity = (image_features @ text_features.T).squeeze(0)
            elif _backend == "clip":
                text_input = clip.tokenize(candidates).to(_device)
                image_features = _model.encode_image(image_input)
                text_features = _model.encode_text(text_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                text_similarity = (image_features @ text_features.T).squeeze(0)
            else:
                # Fallback: embed images only; use centroid-only ranking
                logits = _model(image_input)
                probs = torch.softmax(logits, dim=-1).squeeze(0)
                # L2-normalise the features exactly as /ai-train does for every
                # backend, so the classifier and embedding net see inputs on the
                # same scale they were trained on. `probs` stays raw for the
                # label matching below.
                image_features = probs.unsqueeze(0)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                text_similarity = torch.zeros(len(candidates), device=probs.device)
                if _fallback_labels:
                    top_k_labels = min(50, len(_fallback_labels))
                    top_probs, top_label_indices = torch.topk(probs, k=top_k_labels)
                    top_labels = [_fallback_labels[i] for i in top_label_indices.tolist()]
                    top_scores = top_probs.tolist()

                    # Score each candidate by the best probability among the
                    # top labels that contain it as a substring.
                    for i, tag in enumerate(candidates):
                        tag_l = tag.strip().lower()
                        if not tag_l:
                            continue
                        best = 0.0
                        for label, score in zip(top_labels, top_scores):
                            if tag_l in label.lower():
                                if score > best:
                                    best = score
                        text_similarity[i] = best

        # ============================================
        # NEURAL NETWORK INFERENCE: Tag Embedding Network
        # Uses learned embeddings from contrastive training
        # ============================================
        embed_net_similarity = None

        # Try to load the network if not already loaded (may have been trained since startup)
        if _tag_embed_net is None:
            _load_tag_embed_net()

        # Debug: Check image features are unique per image
        img_hash = image_features.sum().item()
        logger.info("Image feature hash: %.4f, shape: %s, dim match: %s (net=%s, feat=%s)",
                     img_hash, image_features.shape,
                     _tag_embed_dim == image_features.shape[-1] if _tag_embed_dim else "N/A",
                     _tag_embed_dim, image_features.shape[-1] if image_features.dim() > 0 else None)

        if _tag_embed_net is not None and _tag_embed_dim == image_features.shape[-1]:
            try:
                _tag_embed_net.eval()
                tag_to_idx = {t: i for i, t in enumerate(_tag_embed_tags)}

                # Get neural network scores for all known tags
                with torch.no_grad():
                    # Transform image features through the network
                    img_transformed = _tag_embed_net.transform_image(image_features, normalize=True)
                    tag_embeddings = _tag_embed_net.get_tag_embeddings(normalize=True)
                    all_sims = (img_transformed @ tag_embeddings.T).squeeze(0)

                # Map to candidate tags (unknown tags keep a score of 0)
                embed_net_similarity = torch.zeros(len(candidates), device=image_features.device)
                for i, tag in enumerate(candidates):
                    idx = tag_to_idx.get(tag)
                    if idx is not None:
                        embed_net_similarity[i] = all_sims[idx]
                logger.info("EmbedNet: matched %d/%d candidates, sim range [%.3f, %.3f]",
                            (embed_net_similarity != 0).sum().item(), len(candidates),
                            embed_net_similarity.min().item(), embed_net_similarity.max().item())
            except Exception as exc:
                logger.debug("Tag embedding network inference failed: %s", exc)
                embed_net_similarity = None

        # Combine similarity scores: neural network + CLIP text similarity
        if embed_net_similarity is not None:
            # Neural network (50%) + CLIP text (50%)
            similarity = (0.5 * text_similarity) + (0.5 * embed_net_similarity)
            logger.info("Using EmbedNet+Text blend")
        else:
            # Fallback to pure CLIP text similarity
            similarity = text_similarity
            logger.info("EmbedNet skipped: net=%s, dim_match=%s",
                        _tag_embed_net is not None,
                        _tag_embed_dim == image_features.shape[-1] if _tag_embed_dim else "N/A")

        # Blend in classifier scores if available
        if _classifier is None:
            _load_classifier()  # Try to load if not already loaded

        if _classifier is not None and _classifier_dim == image_features.shape[-1]:
            try:
                _classifier.eval()
                # Inference only: don't build an autograd graph over the head's parameters
                with torch.no_grad():
                    cls_logits = _classifier(image_features)
                cls_probs = torch.sigmoid(cls_logits).squeeze(0)
                if _classifier_tags:
                    tag_index = {t: i for i, t in enumerate(_classifier_tags)}
                    cls_scores = torch.zeros(len(candidates), device=cls_probs.device)
                    for i, tag in enumerate(candidates):
                        idx = tag_index.get(tag)
                        if idx is not None:
                            cls_scores[i] = cls_probs[idx]
                    similarity = (0.6 * similarity) + (0.4 * cls_scores)
                    logger.info("Classifier: matched %d/%d, score range [%.3f, %.3f]",
                                (cls_scores != 0).sum().item(), len(candidates),
                                cls_scores.min().item(), cls_scores.max().item())
            except Exception as exc:
                # Best-effort blend: keep the un-blended similarity, but leave a trace.
                logger.debug("Classifier inference failed: %s", exc)
    except Exception as exc:
        logger.exception("Model inference failed")
        raise HTTPException(status_code=500, detail=f"Model error: {exc}")

    # Debug: Log final similarity stats
    logger.info("Final similarity range: [%.4f, %.4f], top-8 indices: %s",
                similarity.min().item(), similarity.max().item(),
                torch.topk(similarity, k=min(8, len(similarity))).indices.tolist())

    # Get top-k candidates with confidence threshold
    top_k = min(max_tags, len(candidates))
    top_values, top_indices = torch.topk(similarity, k=top_k)

    # Filter by confidence threshold - only return tags with decent similarity
    MIN_CONFIDENCE = 0.15  # Minimum similarity score to include a tag
    confident_mask = top_values >= MIN_CONFIDENCE

    if confident_mask.any():
        # Return only tags above threshold (up to max_tags)
        confident_indices = top_indices[confident_mask].tolist()
        confident_scores = top_values[confident_mask].tolist()
        tags = [candidates[i] for i in confident_indices]
        logger.info("AI tags generated (filtered): %s (scores: %s)",
                   tags, [f"{s:.3f}" for s in confident_scores])
    else:
        # Fallback: return top 3 if nothing passes threshold
        fallback_k = min(3, len(candidates))
        tags = [candidates[i] for i in top_indices[:fallback_k].tolist()]
        logger.info("AI tags generated (fallback, low confidence): %s", tags)

    clip_tags = tags  # Save CLIP results for potential hybrid merge

    # ======== VLM TAGGING MODE ========
    if TAGGING_MODE in ("vlm", "hybrid") and _vlm_ready:
        if not vlm_attempted:
            vlm_tags, vlm_caption = _vlm_tag_image(image, candidates, max_tags)
        if TAGGING_MODE == "vlm":
            # VLM only: use VLM results entirely
            tags = vlm_tags if vlm_tags else clip_tags  # Fallback to CLIP if VLM returns nothing
            logger.info("Using VLM-only tags: %s", tags)
        elif TAGGING_MODE == "hybrid":
            # Hybrid: merge CLIP + VLM results, prioritizing tags that appear in both
            clip_set = set(clip_tags)
            vlm_set = set(vlm_tags)
            # Tags in both systems get highest priority
            both = [t for t in clip_tags if t in vlm_set]
            # Tags only in VLM (VLM "saw" something CLIP missed)
            vlm_only = [t for t in vlm_tags if t not in clip_set]
            # Tags only in CLIP (CLIP scored it but VLM didn't mention it)
            clip_only = [t for t in clip_tags if t not in vlm_set]
            tags = (both + vlm_only + clip_only)[:max_tags]
            logger.info("Hybrid tags: %s (both=%s, vlm_only=%s, clip_only=%s)",
                        tags, both, vlm_only, clip_only)
    elif TAGGING_MODE in ("vlm", "hybrid") and not _vlm_ready:
        logger.warning("VLM requested but not ready, falling back to CLIP tags")

    return {"tags": tags, "caption": vlm_caption or caption}


@app.post("/ai-train")
async def ai_train(
    file: UploadFile = File(...),
    tags: str = Form("[]")
):
    """
    Run one immediate training pass on a single image and its tags.

    Stages (each is best-effort: a failure is logged and the next stage runs):
      1. Tag embedding network training step, then save.
      2. Classifier head: 5 optimizer steps of BCE-with-logits against a
         multi-hot target built from ``tags`` (AdamW, lr=1e-3, created on
         first use), then save.
      3. End-to-end fine-tune via _finetune_on_tags, then save.
      4. VLM LoRA step when VLM_LORA_ENABLED and the VLM is ready, then save.

    Args:
        file: Uploaded image.
        tags: JSON array of tag strings for the image.

    Returns:
        {"status": "ok", "trained": <sanitised tags>, "loss": <classifier loss or None>,
         "clip_pre": ..., "clip_post": ..., "vlm_loss": ...}

    Raises:
        HTTPException(400): no usable tags, or the upload is not an image.
        HTTPException(500): image feature extraction failed.
    """
    # Declared up front: the classifier stage may create the shared optimizer.
    global _classifier_optimizer

    try:
        tag_list = json.loads(tags) if tags else []
    except Exception:
        tag_list = []

    tag_list = _sanitize_tags(tag_list)
    if not tag_list:
        raise HTTPException(status_code=400, detail="No tags provided")

    try:
        raw = await file.read()
        image = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as exc:
        logger.exception("Failed to read image for training")
        raise HTTPException(status_code=400, detail=f"Invalid image: {exc}")

    # Extract L2-normalised image features (same error contract as /ai-tags).
    try:
        with torch.no_grad():
            image_input = _preprocess(image).unsqueeze(0).to(_device)
            if _backend in ("open_clip", "clip"):
                image_features = _model.encode_image(image_input)
            else:
                # Fallback model: class probabilities act as the feature vector
                image_features = torch.softmax(_model(image_input), dim=-1)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    except Exception as exc:
        logger.exception("Model inference failed during training")
        raise HTTPException(status_code=500, detail=f"Model error: {exc}")

    # Save tags to database
    _save_tag_db(_load_tag_db() + tag_list)
    all_tags = _load_tag_db()
    feature_dim = image_features.shape[-1]

    # ============================================
    # REAL NEURAL NETWORK TRAINING: Tag Embedding Network
    # ============================================
    embed_loss = 0.0
    try:
        _ensure_tag_embed_net(all_tags, feature_dim)
        embed_loss = _train_tag_embed_net(image_features, tag_list)
        _save_tag_embed_net()
        logger.debug("Tag embedding training loss: %.4f", embed_loss)
    except Exception as exc:
        logger.warning("Tag embedding training skipped: %s", exc)

    # Fine-tune classifier head (real training step)
    loss_value = None
    try:
        _ensure_classifier(all_tags, feature_dim)
        if _classifier is not None:
            _classifier.train()
            # Multi-hot target over the classifier's tag vocabulary
            target = torch.zeros(len(_classifier_tags), device=_device)
            tag_index = {t: i for i, t in enumerate(_classifier_tags)}
            for t in tag_list:
                idx = tag_index.get(t)
                if idx is not None:
                    target[idx] = 1.0
            if _classifier_optimizer is None:
                _classifier_optimizer = torch.optim.AdamW(_classifier.parameters(), lr=1e-3)
            loss_fn = torch.nn.BCEWithLogitsLoss()
            feats = image_features.detach()
            for _ in range(5):
                _classifier_optimizer.zero_grad(set_to_none=True)
                logits = _classifier(feats).squeeze(0)
                loss = loss_fn(logits, target)
                loss.backward()
                _classifier_optimizer.step()
            try:
                loss_value = float(loss.detach().cpu().item())
            except Exception:
                loss_value = None
            _save_classifier()
    except Exception as exc:
        logger.warning("Classifier fine-tune skipped: %s", exc)

    # End-to-end fine-tuning (updates encoder weights)
    clip_pre = 0.0
    clip_post = 0.0
    try:
        clip_pre, clip_post = _finetune_on_tags(image, tag_list)
        _save_finetune_weights()
    except Exception as exc:
        logger.warning("End-to-end finetune skipped: %s", exc)

    # VLM LoRA training (if enabled and VLM is loaded)
    vlm_loss = 0.0
    if VLM_LORA_ENABLED and _vlm_ready:
        try:
            vlm_loss = _vlm_train_on_tags(image, tag_list, all_tags)
            _vlm_save_lora()
            logger.debug("VLM LoRA training loss: %.4f", vlm_loss)
        except Exception as exc:
            logger.warning("VLM LoRA training skipped: %s", exc)

    # Update training status tracker (same fields the batch trainer reports)
    _training_status.update(
        _training_status.trained + 1,
        loss_value,
        loss_embed=embed_loss,
        loss_classifier=loss_value,
        loss_clip=clip_post,
        loss_clip_pre=clip_pre,
        loss_vlm=vlm_loss
    )

    logger.info("AI train updated tags: %s | CLIP %.4f→%.4f | VLM %.4f", tag_list, clip_pre, clip_post, vlm_loss)
    return {"status": "ok", "trained": tag_list, "loss": loss_value,
            "clip_pre": clip_pre, "clip_post": clip_post, "vlm_loss": vlm_loss}
