"""
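Vector store service for RAG.
Backends: FAISS and Weaviate, plus an in-memory mock for tests. PGVector is
not implemented here; unknown backend types fall back to FAISS.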
"""
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
import os
import pickle

from config import settings
from app.utils.observability import logger
from app.services.ai_providers import get_ai_provider


class VectorStore(ABC):
    """Abstract interface shared by every vector store backend.

    Concrete stores implement storage, similarity search and deletion;
    embedding generation is shared and delegates to the "openai" AI provider.
    """

    @abstractmethod
    async def add_document(
        self,
        tenant_id: str,
        document_id: str,
        content: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Store *content* for *tenant_id* and return the backend's vector id."""

    @abstractmethod
    async def search(
        self,
        tenant_id: str,
        query: str,
        top_k: int = 5,
        document_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* of the tenant's documents most similar to *query*."""

    @abstractmethod
    async def delete_document(self, vector_id: str) -> bool:
        """Remove the vector identified by *vector_id*; return True on success."""

    async def generate_embedding(self, text: str) -> List[float]:
        """Embed *text* using the "openai" AI provider."""
        embedder = get_ai_provider("openai")
        embedding = await embedder.generate_embedding(text)
        return embedding


class FAISSVectorStore(VectorStore):
    """FAISS-based vector store.

    Vectors live in a FAISS index (a new index is ``IndexFlatL2``); per-vector
    records (tenant, document id, content, caller metadata) live in
    ``self.documents``, keyed by the vector's position in the index. Both are
    persisted under ``settings.faiss_index_path`` as ``index.faiss`` and
    ``documents.pkl``.

    NOTE(review): every add/delete rewrites both files synchronously inside an
    async method, which is O(index size) per call and blocks the event loop.
    """

    def __init__(self):
        self.index_path = settings.faiss_index_path
        self.dimension = settings.embedding_dimension
        self.index = None
        self.documents = {}  # vector position (int) -> stored document record
        self._load_index()

    def _index_files(self):
        """Return the ``(index_file, docs_file)`` paths inside ``self.index_path``."""
        return (
            os.path.join(self.index_path, "index.faiss"),
            os.path.join(self.index_path, "documents.pkl"),
        )

    def _load_index(self):
        """Load the FAISS index and document records from disk, or start empty."""
        import faiss

        # Both files live *inside* index_path, so that directory itself must
        # exist. (Creating only os.path.dirname(index_path) left it missing, and
        # raised FileNotFoundError for a bare relative name like "faiss_index".)
        os.makedirs(self.index_path, exist_ok=True)

        index_file, docs_file = self._index_files()

        if os.path.exists(index_file):
            self.index = faiss.read_index(index_file)
            logger.info("FAISS index loaded", path=index_file)
        else:
            self.index = faiss.IndexFlatL2(self.dimension)
            logger.info("New FAISS index created", dimension=self.dimension)

        if os.path.exists(docs_file):
            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file; this is only safe while documents.pkl is written solely by
            # this service and the directory is not writable by others.
            with open(docs_file, 'rb') as f:
                self.documents = pickle.load(f)
            logger.info("Document metadata loaded", count=len(self.documents))

    def _save_index(self):
        """Persist the FAISS index and document records to ``self.index_path``.

        Each file is written to a ``.tmp`` sibling and moved into place with
        ``os.replace``, so a crash mid-write cannot leave a truncated file.
        """
        import faiss

        os.makedirs(self.index_path, exist_ok=True)
        index_file, docs_file = self._index_files()

        tmp_index = index_file + ".tmp"
        faiss.write_index(self.index, tmp_index)
        os.replace(tmp_index, index_file)

        tmp_docs = docs_file + ".tmp"
        with open(tmp_docs, 'wb') as f:
            pickle.dump(self.documents, f)
        os.replace(tmp_docs, docs_file)

        logger.info("FAISS index saved", path=index_file)

    def _as_row(self, embedding: List[float]):
        """Return *embedding* as a (1, d) float32 array matching the index dimension.

        Raises:
            ValueError: if the embedding length differs from ``self.index.d``
                (FAISS would otherwise fail with an opaque assertion).
        """
        import numpy as np

        row = np.asarray([embedding], dtype=np.float32)
        if row.ndim != 2 or row.shape[1] != self.index.d:
            raise ValueError(
                f"Embedding dimension {row.shape[-1]} does not match "
                f"FAISS index dimension {self.index.d}"
            )
        return row

    async def add_document(
        self,
        tenant_id: str,
        document_id: str,
        content: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Embed *content*, append it to the index and persist.

        Returns:
            The vector's position in the index, as a string.
        """
        embedding_array = self._as_row(await self.generate_embedding(content))

        # The new vector's position is the current size of the index; no await
        # happens between reading ntotal and add(), so this cannot race in asyncio.
        vector_id = self.index.ntotal
        self.index.add(embedding_array)

        self.documents[vector_id] = {
            "tenant_id": tenant_id,
            "document_id": document_id,
            "content": content,
            "metadata": metadata
        }

        self._save_index()

        logger.info(
            "Document added to FAISS",
            vector_id=vector_id,
            document_id=document_id
        )

        return str(vector_id)

    def _collect_matches(
        self,
        distances,
        indices,
        tenant_id: str,
        document_type: Optional[str],
        top_k: int
    ) -> List[Dict[str, Any]]:
        """Filter one row of FAISS search output to the tenant/type and format it."""
        results = []
        for distance, idx in zip(distances, indices):
            if idx == -1:  # FAISS pads with -1 once neighbours run out
                break

            doc = self.documents.get(int(idx))
            if not doc:  # deleted (metadata removed) vector
                continue
            if doc["tenant_id"] != tenant_id:
                continue
            if document_type and doc["metadata"].get("document_type") != document_type:
                continue

            results.append({
                "document_id": doc["document_id"],
                "title": doc["metadata"].get("title", ""),
                "content": doc["content"],
                # Map the FAISS distance (smaller is closer) into (0, 1].
                "score": float(1.0 / (1.0 + distance)),
                "metadata": doc["metadata"]
            })

            if len(results) >= top_k:
                break
        return results

    async def search(
        self,
        tenant_id: str,
        query: str,
        top_k: int = 5,
        document_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* of the tenant's nearest documents to *query*.

        All tenants share one index and deleted vectors remain in it, so a
        fixed window of nearest neighbours can be exhausted by non-matching
        entries. The window starts at ``top_k * 2`` and doubles until enough
        matches are found or the whole index has been scanned.
        """
        if top_k <= 0:
            return []

        query_array = self._as_row(await self.generate_embedding(query))

        results: List[Dict[str, Any]] = []
        total = self.index.ntotal
        k = min(top_k * 2, total)
        while k > 0:
            distances, indices = self.index.search(query_array, k)
            results = self._collect_matches(
                distances[0], indices[0], tenant_id, document_type, top_k
            )
            if len(results) >= top_k or k >= total:
                break
            k = min(k * 2, total)

        logger.info("FAISS search completed", query_length=len(query), results=len(results))

        return results

    async def delete_document(self, vector_id: str) -> bool:
        """Logically delete a vector by dropping its stored record.

        The vector itself stays in the FAISS index; search skips positions
        with no record. Returns False for unknown or non-integer ids.
        """
        try:
            vector_id_int = int(vector_id)
        except (TypeError, ValueError):
            logger.warning("Invalid FAISS vector id", vector_id=vector_id)
            return False

        if vector_id_int not in self.documents:
            return False

        del self.documents[vector_id_int]
        self._save_index()
        logger.info("Document marked as deleted", vector_id=vector_id)
        return True


class WeaviateVectorStore(VectorStore):
    """Weaviate-based vector store (``weaviate.Client`` API).

    Objects are stored in the ``Document`` class with caller-supplied vectors
    (``vectorizer: none``). Tenant isolation is a ``where`` filter on
    ``tenant_id`` applied to every search.

    NOTE(review): the client calls below are synchronous and run inside async
    methods, so they block the event loop for the duration of each request.
    """

    def __init__(self):
        import weaviate

        auth = (
            weaviate.AuthApiKey(api_key=settings.weaviate_api_key)
            if settings.weaviate_api_key
            else None
        )
        self.client = weaviate.Client(
            url=settings.weaviate_url,
            auth_client_secret=auth
        )
        self._ensure_schema()

    def _ensure_schema(self):
        """Create the ``Document`` class if it does not already exist.

        NOTE(review): newer Weaviate servers deprecate the ``string`` data type
        in favour of ``text`` and require ``nestedProperties`` for ``object``
        properties such as ``metadata`` — confirm against the deployed version.
        """
        schema = {
            "class": "Document",
            "vectorizer": "none",  # We provide our own vectors
            "properties": [
                {"name": "tenant_id", "dataType": ["string"]},
                {"name": "document_id", "dataType": ["string"]},
                {"name": "content", "dataType": ["text"]},
                {"name": "title", "dataType": ["string"]},
                {"name": "document_type", "dataType": ["string"]},
                {"name": "metadata", "dataType": ["object"]},
            ]
        }

        try:
            self.client.schema.create_class(schema)
            logger.info("Weaviate schema created")
        except Exception as e:
            # "Already exists" is the normal outcome on every start after the
            # first. Anything else (connection, auth, invalid schema) is a real
            # failure and must not be hidden at info level.
            if "already exists" in str(e).lower():
                logger.info("Weaviate schema already exists", error=str(e))
            else:
                logger.warning("Failed to create Weaviate schema", error=str(e))

    async def add_document(
        self,
        tenant_id: str,
        document_id: str,
        content: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Embed *content*, store it as a ``Document`` object, return its UUID."""
        embedding = await self.generate_embedding(content)

        vector_id = self.client.data_object.create(
            data_object={
                "tenant_id": tenant_id,
                "document_id": document_id,
                "content": content,
                "title": metadata.get("title", ""),
                "document_type": metadata.get("document_type", ""),
                "metadata": metadata
            },
            class_name="Document",
            vector=embedding
        )

        logger.info(
            "Document added to Weaviate",
            vector_id=vector_id,
            document_id=document_id
        )

        return vector_id

    @staticmethod
    def _build_where_filter(tenant_id: str, document_type: Optional[str]) -> Dict[str, Any]:
        """Build the GraphQL ``where`` filter: tenant match, plus type when given."""
        tenant_filter = {
            "path": ["tenant_id"],
            "operator": "Equal",
            "valueString": tenant_id
        }
        if not document_type:
            return tenant_filter
        return {
            "operator": "And",
            "operands": [
                tenant_filter,
                {
                    "path": ["document_type"],
                    "operator": "Equal",
                    "valueString": document_type
                }
            ]
        }

    async def search(
        self,
        tenant_id: str,
        query: str,
        top_k: int = 5,
        document_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* of the tenant's documents nearest to *query*.

        GraphQL errors are logged; if the response carries no data (``"data"``
        or ``"Document"`` is null), an empty list is returned instead of
        crashing on a ``None`` lookup.
        """
        if top_k <= 0:
            return []

        query_embedding = await self.generate_embedding(query)

        result = (
            self.client.query
            .get("Document", ["document_id", "title", "content", "metadata"])
            .with_near_vector({"vector": query_embedding})
            .with_where(self._build_where_filter(tenant_id, document_type))
            .with_limit(top_k)
            .with_additional(["distance"])
            .do()
        )

        if result.get("errors"):
            logger.error("Weaviate search returned errors", errors=result["errors"])

        # Any level may be null when the query fails, so `or` guards each step.
        data = result.get("data") or {}
        items = (data.get("Get") or {}).get("Document") or []

        results = [
            {
                "document_id": item["document_id"],
                "title": item.get("title") or "",
                "content": item["content"],
                # Assumes a distance metric in [0, 1] (e.g. cosine) — TODO confirm.
                "score": 1.0 - float(item["_additional"]["distance"]),
                "metadata": item.get("metadata") or {}
            }
            for item in items
        ]

        logger.info("Weaviate search completed", results=len(results))

        return results

    async def delete_document(self, vector_id: str) -> bool:
        """Delete a ``Document`` object by UUID; log and return False on failure."""
        try:
            self.client.data_object.delete(vector_id, class_name="Document")
        except Exception as e:
            # Best-effort by design: callers only get a boolean.
            logger.error("Failed to delete from Weaviate", vector_id=vector_id, error=str(e))
            return False
        logger.info("Document deleted from Weaviate", vector_id=vector_id)
        return True


class MockVectorStore(VectorStore):
    """In-memory vector store for tests; no embeddings are computed."""

    def __init__(self):
        self.documents = {}  # vector_id (uuid4 string) -> stored document record

    async def add_document(
        self,
        tenant_id: str,
        document_id: str,
        content: str,
        metadata: Dict[str, Any]
    ) -> str:
        """Store the document under a fresh uuid4 string and return that id."""
        import uuid
        vector_id = str(uuid.uuid4())

        self.documents[vector_id] = {
            "tenant_id": tenant_id,
            "document_id": document_id,
            "content": content,
            "metadata": metadata
        }

        logger.info("Mock document added", vector_id=vector_id)
        return vector_id

    async def search(
        self,
        tenant_id: str,
        query: str,
        top_k: int = 5,
        document_type: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* of the tenant's documents with a fixed 0.9 score.

        Filters by tenant (and by ``metadata["document_type"]`` when given,
        matching the real backends) *before* applying the ``top_k`` limit;
        previously the limit was applied first, so a tenant's documents could be
        missed when other tenants' documents came earlier. *query* is ignored.
        """
        results = []
        for doc in self.documents.values():
            if len(results) >= top_k:
                break
            if doc["tenant_id"] != tenant_id:
                continue
            if document_type and doc["metadata"].get("document_type") != document_type:
                continue
            results.append({
                "document_id": doc["document_id"],
                "title": doc["metadata"].get("title", ""),
                "content": doc["content"],
                "score": 0.9,
                "metadata": doc["metadata"]
            })

        return results

    async def delete_document(self, vector_id: str) -> bool:
        """Remove the stored document; return False if the id is unknown."""
        if vector_id in self.documents:
            del self.documents[vector_id]
            return True
        return False


def get_vector_store() -> VectorStore:
    """Build the vector store selected by configuration.

    Returns a MockVectorStore when ``settings.mock_providers`` or
    ``settings.test_mode`` is truthy. Otherwise dispatches on
    ``settings.vector_store_type`` (case-insensitive, surrounding whitespace
    ignored): "faiss" or "weaviate". Any other value, including an unset one
    (e.g. "pgvector", which has no implementation here), logs a warning and
    falls back to FAISS.
    """
    if settings.mock_providers or settings.test_mode:
        return MockVectorStore()

    store_type = (settings.vector_store_type or "").strip().lower()

    if store_type == "weaviate":
        return WeaviateVectorStore()
    if store_type != "faiss":
        # Structured kwargs, consistent with every other log call in this module.
        logger.warning("Unknown vector store type, using FAISS", store_type=store_type)
    return FAISSVectorStore()


# Global vector store instance.
# NOTE: built at import time by get_vector_store(), so importing this module
# constructs the configured backend (FAISS loads or creates its on-disk index;
# Weaviate creates a client and ensures its schema). Any exception raised while
# constructing the backend propagates out of the import.
vector_store_service: VectorStore = get_vector_store()