"""
Security utilities: PII redaction, encryption, token management
"""
import base64
import hashlib
import hmac
import json
import logging
import re
from typing import Dict, Any, Optional, List

from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
# The cryptography package exports the PBKDF2 KDF as ``PBKDF2HMAC``; ``PBKDF2``
# is kept as an alias so existing references to that name keep working.
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC as PBKDF2

from config import settings

logger = logging.getLogger(__name__)


class PIIRedactor:
    """Redact PII from text and data structures.

    Detection is regex-based (see ``PATTERNS``), so it is best-effort: values in
    formats the patterns do not cover pass through unchanged.
    """

    # Regex patterns for common PII, keyed by the names accepted in ``patterns``.
    # They are applied with re.IGNORECASE, in the order the names are configured;
    # an earlier pattern's replacement can consume text a later one would match.
    PATTERNS = {
        "email": r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
        "phone": r'\b(?:\+?1[-.]?)?\(?([0-9]{3})\)?[-.]?([0-9]{3})[-.]?([0-9]{4})\b',
        "ssn": r'\b\d{3}-\d{2}-\d{4}\b',
        "credit_card": r'\b\d{4}[-\s]?\d{4}[-\s]?\d{4}[-\s]?\d{4}\b',
        "ip_address": r'\b(?:\d{1,3}\.){3}\d{1,3}\b',
        "url": r'https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&/=]*)',
    }

    def __init__(self, patterns: Optional[List[str]] = None):
        """Initialize with specific patterns to redact.

        Args:
            patterns: Names of entries in ``PATTERNS`` to apply, in order. Unknown
                names are ignored. A falsy value (``None`` or an empty list) falls
                back to ``settings.pii_patterns``.

        ``settings.pii_redaction_enabled`` is read once here; when it is false,
        ``redact_text`` and ``redact_dict`` return their input unchanged.
        """
        self.patterns = patterns or settings.pii_patterns
        self.enabled = settings.pii_redaction_enabled

    def _active_patterns(self) -> List[str]:
        """Return the regex source of each configured, known pattern name, in order."""
        return [self.PATTERNS[name] for name in self.patterns if name in self.PATTERNS]

    def redact_text(self, text: str, replacement: str = "[REDACTED]") -> str:
        """Return ``text`` with every match of the active patterns replaced.

        Args:
            text: Input text. Empty/falsy input is returned as-is.
            replacement: Literal text substituted for each match. Backslashes and
                group references in it are NOT interpreted.

        Returns:
            The redacted text, or ``text`` unchanged when redaction is disabled.
        """
        if not self.enabled or not text:
            return text

        # A callable repl makes re.sub insert ``replacement`` verbatim instead of
        # parsing it as a template (where "\1" or a lone "\" would raise).
        def literal(_match: "re.Match") -> str:
            return replacement

        redacted = text
        for pattern in self._active_patterns():
            redacted = re.sub(pattern, literal, redacted, flags=re.IGNORECASE)
        return redacted

    def _redact_value(self, value: Any, replacement: str) -> Any:
        """Redact one value: strings, dicts and (arbitrarily nested) lists.

        Any other type is returned unchanged.
        """
        if isinstance(value, str):
            return self.redact_text(value, replacement)
        if isinstance(value, dict):
            return self.redact_dict(value, replacement)
        if isinstance(value, list):
            return [self._redact_value(item, replacement) for item in value]
        return value

    def redact_dict(self, data: Dict[str, Any], replacement: str = "[REDACTED]") -> Dict[str, Any]:
        """Recursively redact PII from the values of a dictionary.

        Keys are not redacted. Nested dicts and lists (including lists of lists)
        are walked. A new dict is returned; ``data`` itself is not modified.

        Returns:
            The redacted copy, or ``data`` itself when redaction is disabled.
        """
        if not self.enabled:
            return data

        return {key: self._redact_value(value, replacement) for key, value in data.items()}

    def has_pii(self, text: str) -> bool:
        """Return True if any active pattern matches ``text``.

        Unlike the redaction methods, this check runs even when redaction is
        disabled.
        """
        if not text:
            return False

        return any(
            re.search(pattern, text, flags=re.IGNORECASE)
            for pattern in self._active_patterns()
        )


class Encryptor:
    """Encrypt and decrypt sensitive strings and JSON dicts with Fernet.

    Ciphertexts produced here are the Fernet token passed through one more round
    of URL-safe base64; ``decrypt`` expects that same double encoding.
    """

    # Fixed key-derivation parameters. They must stay stable: changing any of
    # them derives a different key and makes previously encrypted data
    # undecryptable.
    # NOTE(review): a static, public salt weakens PBKDF2 against precomputation;
    # consider a per-deployment salt together with a data-migration path.
    _KDF_SALT = b'ai-microservice-salt'
    _KDF_ITERATIONS = 100000

    def __init__(self, key: Optional[str] = None):
        """Build the Fernet cipher from ``key``.

        Args:
            key: Secret material. A falsy value falls back to
                ``settings.encryption_key``. A key whose UTF-8 encoding is
                exactly 32 bytes is used directly as the raw Fernet key; any
                other length is stretched to 32 bytes with PBKDF2-HMAC-SHA256.

        Raises:
            ValueError: If no key was given and none is configured.
        """
        self.key = key or settings.encryption_key
        if self.key is None:
            raise ValueError("Encryption key is not configured")

        key_bytes = self.key.encode()
        if len(key_bytes) != 32:
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=self._KDF_SALT,
                iterations=self._KDF_ITERATIONS,
                backend=default_backend(),
            )
            raw_key = kdf.derive(key_bytes)
        else:
            raw_key = key_bytes

        # Fernet takes its 32-byte key in URL-safe base64 form.
        self.cipher = Fernet(base64.urlsafe_b64encode(raw_key))

    def encrypt(self, data: str) -> str:
        """Encrypt ``data`` and return the base64-wrapped Fernet token as text.

        Raises:
            Whatever the underlying encode/encrypt step raises; it is logged first.
        """
        try:
            token = self.cipher.encrypt(data.encode())
            return base64.urlsafe_b64encode(token).decode()
        except Exception as e:
            # Log the exception type and message only; repr() of some errors
            # (e.g. UnicodeEncodeError) would embed the plaintext.
            logger.error("Encryption error: %s: %s", type(e).__name__, e)
            raise

    def decrypt(self, encrypted_data: str) -> str:
        """Reverse ``encrypt``: unwrap the base64 layer, decrypt, decode as UTF-8.

        Raises:
            Whatever base64 decoding, Fernet decryption, or UTF-8 decoding
            raises; it is logged first.
        """
        try:
            token = base64.urlsafe_b64decode(encrypted_data.encode())
            return self.cipher.decrypt(token).decode()
        except Exception as e:
            # Some errors (e.g. Fernet's InvalidToken) have an empty message,
            # so the type name is what makes the log line useful.
            logger.error("Decryption error: %s: %s", type(e).__name__, e)
            raise

    def encrypt_dict(self, data: Dict[str, Any]) -> str:
        """Serialize ``data`` to JSON and encrypt it with ``encrypt``."""
        json_str = json.dumps(data)
        return self.encrypt(json_str)

    def decrypt_dict(self, encrypted_data: str) -> Dict[str, Any]:
        """Decrypt with ``decrypt`` and parse the result as JSON."""
        json_str = self.decrypt(encrypted_data)
        return json.loads(json_str)


class TokenManager:
    """Manage API tokens: encrypt for storage, decrypt on retrieval, mask for logs."""

    def __init__(self):
        # Each TokenManager builds its own Encryptor from the configured key.
        self.encryptor = Encryptor()

    def store_token(self, token: str, token_type: str = "api") -> str:
        """Encrypt ``token`` and return the ciphertext.

        The return value is currently the encrypted token itself, not an opaque
        reference into a vault. Pass it back to ``retrieve_token`` to recover
        the plaintext.

        Args:
            token: Plaintext token.
            token_type: Currently unused; kept for interface compatibility.
        """
        # TODO: in production, persist in a secure vault (e.g. HashiCorp Vault)
        # and return a reference ID instead of the ciphertext.
        return self.encryptor.encrypt(token)

    def retrieve_token(self, token_ref: str) -> str:
        """Decrypt a value previously returned by ``store_token``."""
        return self.encryptor.decrypt(token_ref)

    def mask_token(self, token: str, visible_chars: int = 4) -> str:
        """Mask ``token`` for logging, showing at most its last ``visible_chars`` chars.

        The token is fully masked when it is no longer than ``visible_chars``, or
        when ``visible_chars`` is zero or negative. Slicing with ``-0`` or a
        negative count would otherwise expose the whole token or its start.
        """
        if visible_chars <= 0 or len(token) <= visible_chars:
            return "*" * len(token)
        hidden = len(token) - visible_chars
        return "*" * hidden + token[-visible_chars:]


class SignatureValidator:
    """Generate and validate HMAC-SHA256 webhook signatures (lowercase hex digests)."""

    @staticmethod
    def generate_signature(payload: str, secret: str) -> str:
        """Return the lowercase hex HMAC-SHA256 of ``payload`` keyed with ``secret``.

        Both arguments are UTF-8 encoded before hashing.
        """
        mac = hmac.new(secret.encode(), payload.encode(), hashlib.sha256)
        return mac.hexdigest()

    @staticmethod
    def validate_signature(payload: str, signature: str, secret: str) -> bool:
        """Check ``signature`` against the expected HMAC in constant time.

        ``signature`` is typically attacker-controlled (a request header), so it
        is validated defensively. A missing or non-string/bytes value, or one
        containing arbitrary Unicode, yields ``False`` rather than an exception.
        The comparison is case-sensitive against the lowercase hex digest.
        """
        if isinstance(signature, str):
            # hmac.compare_digest raises TypeError for str inputs containing
            # non-ASCII characters, so compare bytes instead. "surrogatepass"
            # lets any str encode without raising.
            provided = signature.encode("utf-8", "surrogatepass")
        elif isinstance(signature, (bytes, bytearray)):
            provided = bytes(signature)
        else:
            return False

        expected = SignatureValidator.generate_signature(payload, secret)
        return hmac.compare_digest(expected.encode("ascii"), provided)

    @staticmethod
    def validate_laravel_callback(payload: Dict[str, Any], signature: str) -> bool:
        """Validate a callback from Laravel using ``settings.laravel_callback_secret``.

        The payload is re-serialized with ``json.dumps(payload, sort_keys=True)``
        (Python's default separators and ASCII escaping) before signing.
        NOTE(review): this only verifies if the sender signs byte-identical JSON;
        PHP's json_encode defaults differ (compact separators, escaped slashes).
        Confirm the Laravel side matches.
        """
        payload_str = json.dumps(payload, sort_keys=True)
        return SignatureValidator.validate_signature(
            payload_str,
            signature,
            settings.laravel_callback_secret
        )


# Module-level shared instances, also used by the helper functions below.
# NOTE: these are constructed at import time, so importing this module reads
# ``settings`` and runs Encryptor's key setup (PBKDF2 when the key is not
# 32 bytes). TokenManager builds its own Encryptor, so that setup runs twice.
pii_redactor = PIIRedactor()
encryptor = Encryptor()
token_manager = TokenManager()


def redact_for_logging(data: Any) -> Any:
    """Return ``data`` with PII redacted, for safe logging.

    Strings and dicts go through the module-level ``pii_redactor``. Lists and
    tuples are redacted element-wise (recursively) and returned as a new
    list/tuple. Any other value is returned unchanged.
    """
    if isinstance(data, str):
        return pii_redactor.redact_text(data)
    if isinstance(data, dict):
        return pii_redactor.redact_dict(data)
    if isinstance(data, (list, tuple)):
        # Recurse into sequences so PII inside them is not logged verbatim.
        redacted_items = [redact_for_logging(item) for item in data]
        return redacted_items if isinstance(data, list) else tuple(redacted_items)
    return data


def sanitize_error_message(error: Exception) -> str:
    """Render ``error`` as text with the configured PII patterns redacted.

    Only text matching ``pii_redactor``'s regexes is scrubbed. Other secrets
    that happen to appear in the message are left as-is. When redaction is
    disabled, the plain ``str(error)`` is returned.
    """
    return pii_redactor.redact_text(str(error))