kai-api-gateway / provider_sessions.py
KiWA001's picture
fix: use lowercase table names for PostgreSQL compatibility
ba9d716
"""
Unified Provider Session Manager (Supabase)
--------------------------------------------
Manages persistent browser sessions for all providers via Supabase.
This ensures sessions survive redeploys and restarts.
Providers supported:
- huggingchat
- zai
- gemini
- (add more as needed)
"""
import json
import logging
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from supabase import create_client, Client
from config import SUPABASE_URL, SUPABASE_KEY
logger = logging.getLogger("kai_api.provider_sessions")

# Per-provider cap on how many conversations one stored session may serve.
# Providers not listed here fall back to a default of 50 at the call sites.
DEFAULT_MAX_CONVERSATIONS = dict(
    huggingchat=50,
    zai=100,
    gemini=100,
    copilot=999999,  # effectively unlimited for Copilot
)

# Per-provider session lifetime, in hours, used to compute `expires_at`.
# Providers not listed here fall back to a default of 24 at the call site.
DEFAULT_SESSION_DURATION = dict(
    huggingchat=24,
    zai=48,
    gemini=48,
    copilot=720,  # 30 days for Copilot
)
class ProviderSessionManager:
    """Manages provider sessions via Supabase.

    Each provider's session is a row in the ``kaiapi_provider_sessions``
    table, looked up by its ``provider`` column. Writes go through the
    ``upsert_provider_session`` and ``increment_conversation_count`` stored
    functions.

    All public methods are best-effort. When the Supabase client could not
    be created, or a call fails, the problem is logged and a falsy value
    (``None``, ``False`` or ``[]``) is returned instead of raising.
    """

    # Table holding the per-provider session rows.
    _TABLE = "kaiapi_provider_sessions"

    def __init__(self):
        """Connect to Supabase; on failure ``self.supabase`` stays ``None``."""
        self.supabase: Optional[Client] = None
        try:
            self.supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
            logger.info("✅ ProviderSessionManager: Connected to Supabase")
        except Exception as e:
            # Deliberately non-fatal: every other method checks self.supabase
            # and no-ops when the client is missing.
            logger.error("❌ ProviderSessionManager: Failed to connect to Supabase: %s", e)

    def is_available(self) -> bool:
        """Check if Supabase connection is available."""
        return self.supabase is not None

    @staticmethod
    def _parse_timestamp(value: Any) -> datetime:
        """Parse a timestamp returned by Supabase into an aware datetime.

        ``datetime.fromisoformat`` before Python 3.11 rejects a trailing ``Z``
        and fractional seconds that are not exactly 3 or 6 digits. PostgreSQL
        trims trailing zeros (e.g. ``.12345``), so both cases are normalized
        first. A naive result is interpreted as local time, the same zone
        ``save_session`` uses when it computes ``expires_at``.

        Raises:
            ValueError: if the text is not an ISO-8601 timestamp.
        """
        import re

        if isinstance(value, datetime):
            parsed = value
        else:
            text = str(value).strip().replace("Z", "+00:00")
            # Pad or truncate the fractional-seconds part to exactly 6 digits.
            match = re.match(r"^(.*?\d{2}:\d{2}:\d{2})\.(\d+)(.*)$", text)
            if match:
                head, fraction, tail = match.groups()
                text = f"{head}.{fraction[:6].ljust(6, '0')}{tail}"
            parsed = datetime.fromisoformat(text)
        if parsed.tzinfo is None:
            # Naive -> aware, so it can be compared with an aware "now".
            parsed = parsed.astimezone()
        return parsed

    def get_session(self, provider: str) -> Optional[Dict[str, Any]]:
        """
        Get session data for a provider.

        Returns the stored session row. Returns None when:
        - the client is unavailable;
        - there is no row;
        - the lookup fails;
        - the session is no longer valid, because ``expires_at`` has passed
          or ``conversation_count`` has reached ``max_conversations``.

        Invalid sessions are deleted before None is returned.
        """
        if not self.supabase:
            return None
        try:
            response = (
                self.supabase.table(self._TABLE)
                .select("*")
                .eq("provider", provider)
                .execute()
            )
            if not response.data:
                return None
            session = response.data[0]

            # Check if expired.
            expires_at = session.get("expires_at")
            if expires_at:
                expires_dt = self._parse_timestamp(expires_at)
                if datetime.now().astimezone() > expires_dt:
                    logger.info("Session for %s expired", provider)
                    self.delete_session(provider)
                    return None

            # Check if the conversation limit was reached. A NULL column comes
            # back as None (dict.get's default does not apply), and comparing
            # None would raise, so fall back explicitly.
            conv_count = session.get("conversation_count")
            if conv_count is None:
                conv_count = 0
            max_conv = session.get("max_conversations")
            if max_conv is None:
                max_conv = DEFAULT_MAX_CONVERSATIONS.get(provider, 50)
            if conv_count >= max_conv:
                logger.info("Session for %s reached %s conversations", provider, max_conv)
                self.delete_session(provider)
                return None

            logger.info(
                "✅ Found valid session for %s (%s/%s conversations)",
                provider, conv_count, max_conv,
            )
            return session
        except Exception as e:
            logger.error("Failed to get session for %s: %s", provider, e)
            return None

    def save_session(
        self,
        provider: str,
        cookies: List[Dict],
        conversation_count: int = 0,
        extra_data: Optional[Dict] = None
    ) -> bool:
        """
        Save session data for a provider.

        Args:
            provider: Provider key (e.g. "zai").
            cookies: Browser cookies, stored under ``session_data["cookies"]``.
            conversation_count: Conversations already served by this session.
            extra_data: Extra keys merged into ``session_data`` after
                ``cookies`` (so an extra "cookies" key would override it).

        The expiry and conversation limit come from DEFAULT_SESSION_DURATION
        and DEFAULT_MAX_CONVERSATIONS, with defaults of 24 hours and 50.

        Returns:
            True if the upsert RPC succeeded, False otherwise.
        """
        if not self.supabase:
            logger.warning("Supabase not available, cannot save session")
            return False
        try:
            session_data: Dict[str, Any] = {"cookies": cookies}
            if extra_data:
                session_data.update(extra_data)

            duration_hours = DEFAULT_SESSION_DURATION.get(provider, 24)
            expires_at = datetime.now().astimezone() + timedelta(hours=duration_hours)
            max_conv = DEFAULT_MAX_CONVERSATIONS.get(provider, 50)

            # Upsert using the stored function.
            self.supabase.rpc(
                "upsert_provider_session",
                {
                    "p_provider": provider,
                    "p_session_data": session_data,
                    "p_conversation_count": conversation_count,
                    "p_max_conversations": max_conv,
                    "p_expires_at": expires_at.isoformat(),
                },
            ).execute()
            logger.info("✅ Saved session for %s (expires: %s)", provider, expires_at)
            return True
        except Exception as e:
            logger.error("Failed to save session for %s: %s", provider, e)
            return False

    def increment_conversation(self, provider: str) -> bool:
        """
        Increment conversation count for a provider.

        Calls the ``increment_conversation_count`` stored function.
        Returns True on success, False if unavailable or the call failed.
        """
        if not self.supabase:
            return False
        try:
            self.supabase.rpc(
                "increment_conversation_count",
                {"p_provider": provider},
            ).execute()
            return True
        except Exception as e:
            logger.error("Failed to increment conversation for %s: %s", provider, e)
            return False

    def delete_session(self, provider: str) -> bool:
        """
        Delete session for a provider.

        Returns True on success, False if unavailable or the call failed.
        """
        if not self.supabase:
            return False
        try:
            self.supabase.table(self._TABLE).delete().eq("provider", provider).execute()
            logger.info("Deleted session for %s", provider)
            return True
        except Exception as e:
            logger.error("Failed to delete session for %s: %s", provider, e)
            return False

    def clear_all_sessions(self) -> bool:
        """
        Clear all provider sessions.

        Returns True on success, False if unavailable or the call failed.
        """
        if not self.supabase:
            return False
        try:
            # The delete is given a filter that every real row satisfies
            # (id != nil UUID); presumably the query builder refuses an
            # unfiltered DELETE. Assumes `id` is a UUID column.
            self.supabase.table(self._TABLE).delete().neq(
                "id", "00000000-0000-0000-0000-000000000000"
            ).execute()
            logger.info("Cleared all provider sessions")
            return True
        except Exception as e:
            logger.error("Failed to clear sessions: %s", e)
            return False

    def get_all_sessions(self) -> List[Dict[str, Any]]:
        """
        Get all sessions.

        Returns every stored row without validity checks, or [] if
        unavailable, the call failed, or the response carried no data.
        """
        if not self.supabase:
            return []
        try:
            response = self.supabase.table(self._TABLE).select("*").execute()
            # Honor the List return type even if the response data is None.
            return response.data or []
        except Exception as e:
            logger.error("Failed to get all sessions: %s", e)
            return []

    def needs_login(self, provider: str) -> bool:
        """
        Check if provider needs login (no valid session).

        Note: this goes through get_session, so it may delete an expired or
        exhausted session as a side effect.
        """
        return self.get_session(provider) is None
# Process-wide ProviderSessionManager, created lazily on first request.
_session_manager: Optional[ProviderSessionManager] = None


def get_provider_session_manager() -> ProviderSessionManager:
    """Return the shared ProviderSessionManager, building it on first use.

    Later calls reuse the same instance instead of creating a new Supabase
    client each time.
    """
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = ProviderSessionManager()
    return _session_manager