"""
Admin endpoints for monitoring and management
"""
import asyncio
import secrets
from datetime import datetime, timedelta
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status, Header
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models import AIJob, PublishJob, CostTracking
from app.schemas import MetricsResponse
from app.utils.observability import logger, metrics_collector
from config import settings

# Router for the admin endpoints in this module; every route below declares
# Depends(verify_admin_key), so all of them require a valid admin key header.
router = APIRouter()


async def verify_admin_key(x_admin_key: Optional[str] = Header(None)):
    """Verify the admin API key sent with the request.

    Used as a FastAPI dependency by every admin endpoint.

    Args:
        x_admin_key: Value of the admin-key request header (FastAPI maps the
            ``x_admin_key`` parameter to the ``X-Admin-Key`` header), or
            ``None`` when the header is absent.

    Returns:
        The validated key.

    Raises:
        HTTPException: 401 if the header is missing/empty, if no secret is
            configured, or if the key does not match.
    """
    # TODO: in production, use a dedicated admin key instead of
    # settings.secret_key.
    expected = settings.secret_key
    # compare_digest takes time independent of where the inputs differ, so
    # the key cannot be recovered piece by piece from response timing.
    # Compare UTF-8 bytes because compare_digest rejects non-ASCII str input.
    if (
        not x_admin_key
        or not expected
        or not secrets.compare_digest(
            x_admin_key.encode("utf-8"), expected.encode("utf-8")
        )
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin key"
        )
    return x_admin_key


@router.get("/metrics", response_model=MetricsResponse)
async def get_metrics(
    db: AsyncSession = Depends(get_db),
    admin_key: str = Depends(verify_admin_key)
):
    """
    Get system metrics

    Returns aggregated metrics about jobs, costs, and performance.

    Job counts come from one GROUP BY query per dimension instead of one
    COUNT query per status and per job type. ``jobs_by_status`` always lists
    every tracked status (zero when there are no jobs with it);
    ``jobs_by_type`` lists only tracked job types with at least one job.
    ``active_workers`` is 0 when the Celery worker count cannot be obtained.

    Raises:
        HTTPException: 500 if any database query fails.
    """
    tracked_statuses = ("pending", "processing", "completed", "failed", "cancelled")
    tracked_types = (
        "proposal", "content_ideas", "auto_reply", "ads_generate",
        "forecast", "matchmaking", "translate", "contract_draft",
        "support_rag", "custom",
    )

    def group_key(value):
        # NOTE(review): unwraps Enum members to their value in case the column
        # is an Enum type; plain strings (and None) pass through unchanged.
        return getattr(value, "value", value)

    try:
        # Total jobs
        total_jobs_result = await db.execute(select(func.count(AIJob.id)))
        total_jobs = total_jobs_result.scalar()

        # Jobs by status: zero-filled, in tracked order, untracked values ignored
        status_result = await db.execute(
            select(AIJob.status, func.count(AIJob.id)).group_by(AIJob.status)
        )
        status_counts = {group_key(s): c for s, c in status_result}
        jobs_by_status = {s: status_counts.get(s, 0) for s in tracked_statuses}

        # Jobs by type: only tracked types with a non-zero count, in tracked order
        type_result = await db.execute(
            select(AIJob.job_type, func.count(AIJob.id)).group_by(AIJob.job_type)
        )
        type_counts = {group_key(t): c for t, c in type_result}
        jobs_by_type = {
            t: type_counts[t] for t in tracked_types if type_counts.get(t, 0) > 0
        }

        # Total cost (SUM is NULL when there are no rows)
        cost_result = await db.execute(select(func.sum(CostTracking.cost_usd)))
        total_cost = cost_result.scalar() or 0.0

        # Average processing time in seconds
        avg_time_stmt = select(
            func.avg(
                func.extract('epoch', AIJob.completed_at - AIJob.started_at)
            )
        ).where(
            AIJob.status == "completed",
            AIJob.started_at.isnot(None),
            AIJob.completed_at.isnot(None)
        )
        avg_time_result = await db.execute(avg_time_stmt)
        avg_processing_time = avg_time_result.scalar() or 0.0

        # Active workers (from Celery); best-effort, stays 0 on failure
        active_workers = 0
        try:
            from app.celery_app import celery_app
            # inspect().stats() is a synchronous broadcast that waits for
            # worker replies; run it in the default executor so the event
            # loop keeps serving other requests meanwhile.
            loop = asyncio.get_running_loop()
            stats = await loop.run_in_executor(
                None, celery_app.control.inspect().stats
            )
            if stats:
                active_workers = len(stats)
        except Exception as e:
            logger.warning("Failed to get worker count", error=str(e))

        return MetricsResponse(
            total_jobs=total_jobs,
            jobs_by_status=jobs_by_status,
            jobs_by_type=jobs_by_type,
            total_cost_usd=round(total_cost, 2),
            avg_processing_time_seconds=round(avg_processing_time, 2),
            active_workers=active_workers
        )

    except Exception as e:
        logger.error("Failed to get metrics", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get metrics: {str(e)}"
        )


@router.get("/costs")
async def get_cost_breakdown(
    days: int = 7,
    db: AsyncSession = Depends(get_db),
    admin_key: str = Depends(verify_admin_key)
):
    """
    Get cost breakdown by provider and model

    Returns cost analysis for the specified time period. Sums
    ``CostTracking.cost_usd`` and ``CostTracking.tokens_total`` over rows
    whose ``created_at`` falls within the last ``days`` days. The totals are
    grouped by provider and by (provider, model).

    Args:
        days: Look-back window in days; must be at least 1.

    Raises:
        HTTPException: 400 if ``days`` is below 1 or too large to compute a
            start date; 500 if a database query fails.
    """
    # Validate before the try block below, whose broad ``except`` would
    # otherwise turn these 400 responses into 500s.
    if days < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="days must be at least 1"
        )
    try:
        since = datetime.utcnow() - timedelta(days=days)
    except OverflowError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="days is too large"
        ) from None

    def rounded_cost(value):
        # SUM() yields NULL when every cost_usd in a group is NULL.
        return round(value or 0.0, 2)

    try:
        # Cost by provider
        provider_stmt = select(
            CostTracking.provider,
            func.sum(CostTracking.cost_usd).label('total_cost'),
            func.sum(CostTracking.tokens_total).label('total_tokens')
        ).where(
            CostTracking.created_at >= since
        ).group_by(CostTracking.provider)

        provider_result = await db.execute(provider_stmt)
        cost_by_provider = [
            {
                "provider": row.provider,
                "total_cost_usd": rounded_cost(row.total_cost),
                "total_tokens": row.total_tokens
            }
            for row in provider_result
        ]

        # Cost by model
        model_stmt = select(
            CostTracking.provider,
            CostTracking.model,
            func.sum(CostTracking.cost_usd).label('total_cost'),
            func.sum(CostTracking.tokens_total).label('total_tokens')
        ).where(
            CostTracking.created_at >= since
        ).group_by(CostTracking.provider, CostTracking.model)

        model_result = await db.execute(model_stmt)
        cost_by_model = [
            {
                "provider": row.provider,
                "model": row.model,
                "total_cost_usd": rounded_cost(row.total_cost),
                "total_tokens": row.total_tokens
            }
            for row in model_result
        ]

        return {
            "period_days": days,
            "cost_by_provider": cost_by_provider,
            "cost_by_model": cost_by_model
        }

    except Exception as e:
        logger.error("Failed to get cost breakdown", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get cost breakdown: {str(e)}"
        )


@router.get("/workers")
async def get_worker_status(admin_key: str = Depends(verify_admin_key)):
    """
    Get Celery worker status

    Returns one entry per responding worker with its pool implementation,
    max concurrency, and counts of active and registered tasks, plus the
    total number of workers.

    Raises:
        HTTPException: 500 if the Celery app cannot be loaded or inspected.
    """
    try:
        from app.celery_app import celery_app

        def inspect_workers():
            # Each inspect call is a synchronous broadcast that waits for
            # worker replies; run all three together on one executor thread.
            inspector = celery_app.control.inspect()
            return inspector.stats(), inspector.active(), inspector.registered()

        loop = asyncio.get_running_loop()
        stats, active_tasks, registered_tasks = await loop.run_in_executor(
            None, inspect_workers
        )
        # Replies are None when no worker answers.
        active_tasks = active_tasks or {}
        registered_tasks = registered_tasks or {}

        workers = []
        for worker_name, worker_stats in (stats or {}).items():
            pool = worker_stats.get("pool") or {}
            workers.append({
                "name": worker_name,
                "status": "active",
                "pool": pool.get("implementation"),
                "max_concurrency": pool.get("max-concurrency"),
                "active_tasks": len(active_tasks.get(worker_name) or []),
                "registered_tasks": len(registered_tasks.get(worker_name) or [])
            })

        return {
            "workers": workers,
            "total_workers": len(workers)
        }

    except Exception as e:
        logger.error("Failed to get worker status", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get worker status: {str(e)}"
        )


@router.post("/workers/scale")
async def scale_workers(
    count: int,
    admin_key: str = Depends(verify_admin_key)
):
    """
    Acknowledge a request to scale Celery workers.

    Placeholder only: nothing is started or stopped here. The request is
    validated, logged and echoed back. Actual scaling depends on deployment
    method.

    Raises:
        HTTPException: 400 if ``count`` is outside the range 1..10.
    """
    in_range = 1 <= count <= 10
    if not in_range:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Worker count must be between 1 and 10"
        )

    logger.info("Worker scaling requested", count=count)

    message = f"Worker scaling to {count} requested"
    note = "Actual scaling depends on deployment configuration"
    return {"message": message, "note": note}


@router.post("/cache/clear")
async def clear_cache(admin_key: str = Depends(verify_admin_key)):
    """
    Clear Redis cache

    Clears all cached data by issuing FLUSHDB against the Redis database
    selected by ``settings.redis_url``.

    NOTE(review): FLUSHDB deletes every key in that database, not only cache
    entries. If anything else (e.g. a Celery broker/result backend) shares
    this URL, its data is wiped too. Confirm it uses a separate DB.

    Raises:
        HTTPException: 500 if connecting to or flushing Redis fails.
    """
    try:
        import redis.asyncio as redis
        r = redis.from_url(settings.redis_url)
        try:
            await r.flushdb()
        finally:
            # Release the client's connections even when FLUSHDB fails.
            await r.close()

        logger.info("Cache cleared")

        return {"message": "Cache cleared successfully"}

    except Exception as e:
        logger.error("Failed to clear cache", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to clear cache: {str(e)}"
        )


@router.get("/config")
async def get_config(admin_key: str = Depends(verify_admin_key)):
    """
    Get current configuration (sanitized)

    Returns the values of an explicit allow-list of settings, keyed by
    attribute name. Settings not on the list (such as ``secret_key`` or
    ``redis_url``) are never exposed.
    """
    exposed_settings = (
        "app_env",
        "app_debug",
        "vector_store_type",
        "embedding_model",
        "openai_default_model",
        "anthropic_default_model",
        "local_llm_enabled",
        "pii_redaction_enabled",
        "track_costs",
        "enable_metrics",
        "test_mode",
        "mock_providers",
    )
    return {name: getattr(settings, name) for name in exposed_settings}