""" Unified Provider Session Manager (Supabase) -------------------------------------------- Manages persistent browser sessions for all providers via Supabase. This ensures sessions survive redeploys and restarts. Providers supported: - huggingchat - zai - gemini - (add more as needed) """ import json import logging from datetime import datetime, timedelta from typing import Optional, Dict, Any, List from supabase import create_client, Client from config import SUPABASE_URL, SUPABASE_KEY logger = logging.getLogger("kai_api.provider_sessions") # Session limits per provider DEFAULT_MAX_CONVERSATIONS = { "huggingchat": 50, "zai": 100, "gemini": 100, "copilot": 999999, # Unlimited for Copilot } # Session duration per provider (hours) DEFAULT_SESSION_DURATION = { "huggingchat": 24, "zai": 48, "gemini": 48, "copilot": 720, # 30 days for Copilot } class ProviderSessionManager: """Manages provider sessions via Supabase.""" def __init__(self): self.supabase: Optional[Client] = None try: self.supabase = create_client(SUPABASE_URL, SUPABASE_KEY) logger.info("✅ ProviderSessionManager: Connected to Supabase") except Exception as e: logger.error(f"❌ ProviderSessionManager: Failed to connect to Supabase: {e}") def is_available(self) -> bool: """Check if Supabase connection is available.""" return self.supabase is not None def get_session(self, provider: str) -> Optional[Dict[str, Any]]: """ Get session data for a provider. Returns None if no session or session expired. """ if not self.supabase: return None try: response = self.supabase.table("kaiapi_provider_sessions").select("*").eq("provider", provider).execute() if not response.data: return None session = response.data[0] # Check if expired expires_at = session.get("expires_at") if expires_at: expires_dt = datetime.fromisoformat(expires_at.replace('Z', '+00:00')) if datetime.now().astimezone() > expires_dt: logger.info(f"Session for {provider} expired") self.delete_session(provider) return None # Check if exceeded max conversations conv_count = session.get("conversation_count", 0) max_conv = session.get("max_conversations", DEFAULT_MAX_CONVERSATIONS.get(provider, 50)) if conv_count >= max_conv: logger.info(f"Session for {provider} reached {max_conv} conversations") self.delete_session(provider) return None logger.info(f"✅ Found valid session for {provider} ({conv_count}/{max_conv} conversations)") return session except Exception as e: logger.error(f"Failed to get session for {provider}: {e}") return None def save_session( self, provider: str, cookies: List[Dict], conversation_count: int = 0, extra_data: Optional[Dict] = None ) -> bool: """ Save session data for a provider. """ if not self.supabase: logger.warning("Supabase not available, cannot save session") return False try: # Build session data session_data = { "cookies": cookies, } if extra_data: session_data.update(extra_data) # Calculate expiration duration_hours = DEFAULT_SESSION_DURATION.get(provider, 24) expires_at = datetime.now().astimezone() + timedelta(hours=duration_hours) # Get max conversations max_conv = DEFAULT_MAX_CONVERSATIONS.get(provider, 50) # Upsert using the stored function result = self.supabase.rpc( "upsert_provider_session", { "p_provider": provider, "p_session_data": session_data, "p_conversation_count": conversation_count, "p_max_conversations": max_conv, "p_expires_at": expires_at.isoformat() } ).execute() logger.info(f"✅ Saved session for {provider} (expires: {expires_at})") return True except Exception as e: logger.error(f"Failed to save session for {provider}: {e}") return False def increment_conversation(self, provider: str) -> bool: """ Increment conversation count for a provider. """ if not self.supabase: return False try: self.supabase.rpc( "increment_conversation_count", {"p_provider": provider} ).execute() return True except Exception as e: logger.error(f"Failed to increment conversation for {provider}: {e}") return False def delete_session(self, provider: str) -> bool: """ Delete session for a provider. """ if not self.supabase: return False try: self.supabase.table("kaiapi_provider_sessions").delete().eq("provider", provider).execute() logger.info(f"Deleted session for {provider}") return True except Exception as e: logger.error(f"Failed to delete session for {provider}: {e}") return False def clear_all_sessions(self) -> bool: """ Clear all provider sessions. """ if not self.supabase: return False try: self.supabase.table("kaiapi_provider_sessions").delete().neq("id", "00000000-0000-0000-0000-000000000000").execute() logger.info("Cleared all provider sessions") return True except Exception as e: logger.error(f"Failed to clear sessions: {e}") return False def get_all_sessions(self) -> List[Dict[str, Any]]: """ Get all sessions. """ if not self.supabase: return [] try: response = self.supabase.table("kaiapi_provider_sessions").select("*").execute() return response.data except Exception as e: logger.error(f"Failed to get all sessions: {e}") return [] def needs_login(self, provider: str) -> bool: """ Check if provider needs login (no valid session). """ session = self.get_session(provider) return session is None # Global instance _session_manager: Optional[ProviderSessionManager] = None def get_provider_session_manager() -> ProviderSessionManager: """Get the global provider session manager instance.""" global _session_manager if _session_manager is None: _session_manager = ProviderSessionManager() return _session_manager