added cargo files

This commit is contained in:
2026-03-03 10:57:43 -05:00
parent 478a90e01b
commit 169df46bc2
813 changed files with 227273 additions and 9 deletions

View File

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
"""
Database Migration Runner for PinePods
This script can be run standalone to apply database migrations.
Useful for updating existing installations.
"""
import os
import sys
import logging
import argparse
from pathlib import Path
# Set up logging: timestamped INFO-level output for standalone CLI runs.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Add pinepods to path so `database_functions` imports resolve when this
# script is executed directly (it lives one directory below the project root).
pinepods_path = Path(__file__).parent.parent
sys.path.insert(0, str(pinepods_path))
def run_migrations(target_version=None, validate_only=False):
    """Run (or validate) database migrations.

    Args:
        target_version: optional version string; when given, migrations with a
            greater version are not applied. Previously this parameter was
            accepted but silently ignored — it is now forwarded to the manager.
        validate_only: when True, only verify checksums of already-applied
            migrations and skip execution entirely.

    Returns:
        bool: True on success, False on any failure.
    """
    try:
        # Import migration system lazily so the script can adjust sys.path first
        import database_functions.migration_definitions
        from database_functions.migrations import get_migration_manager

        # Register all migrations
        database_functions.migration_definitions.register_all_migrations()

        # Get migration manager
        manager = get_migration_manager()

        if validate_only:
            logger.info("Validating existing migrations...")
            success = manager.validate_migrations()
            if success:
                logger.info("All migrations validated successfully")
            else:
                logger.error("Migration validation failed")
            return success

        # Show current state
        applied = manager.get_applied_migrations()
        logger.info(f"Currently applied migrations: {len(applied)}")
        for version in applied:
            logger.info(f"  - {version}")

        # Run migrations, honoring the requested target version (bug fix:
        # the old code called run_all_migrations() which dropped the target).
        logger.info("Starting migration process...")
        success = manager.run_migrations(target_version)
        if success:
            logger.info("All migrations completed successfully")
        else:
            logger.error("Migration process failed")
        return success
    except Exception as e:
        logger.error(f"Migration failed: {e}")
        return False
def list_migrations():
    """Log every registered migration with its applied/pending status."""
    try:
        import database_functions.migration_definitions
        from database_functions.migrations import get_migration_manager

        # Make sure every migration is registered before inspecting them.
        database_functions.migration_definitions.register_all_migrations()

        mgr = get_migration_manager()
        done = set(mgr.get_applied_migrations())

        logger.info("Available migrations:")
        for ver, mig in sorted(mgr.migrations.items()):
            state = "APPLIED" if ver in done else "PENDING"
            logger.info(f"  {ver} - {mig.name} [{state}]")
            logger.info(f"      {mig.description}")
            if mig.requires:
                logger.info(f"      Requires: {', '.join(mig.requires)}")
        return True
    except Exception as e:
        logger.error(f"Failed to list migrations: {e}")
        return False
def main():
    """Parse CLI arguments and dispatch to the requested command."""
    parser = argparse.ArgumentParser(description="PinePods Database Migration Tool")
    parser.add_argument(
        "command",
        choices=["migrate", "list", "validate"],
        help="Command to execute",
    )
    parser.add_argument("--target", help="Target migration version (migrate only)")
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging",
    )
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Dispatch to the chosen command; argparse already restricts the choices.
    command = args.command
    if command == "migrate":
        ok = run_migrations(args.target)
    elif command == "list":
        ok = list_migrations()
    elif command == "validate":
        ok = run_migrations(validate_only=True)
    else:
        logger.error(f"Unknown command: {args.command}")
        ok = False

    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,530 @@
"""
Database Migration System for PinePods
This module provides a robust, idempotent migration framework that tracks
applied migrations and ensures database schema changes are applied safely.
"""
import hashlib
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
# Add pinepods to path for imports
sys.path.append('/pinepods')
# Database imports
try:
import psycopg
POSTGRES_AVAILABLE = True
except ImportError:
POSTGRES_AVAILABLE = False
try:
import mariadb as mysql_connector
MYSQL_AVAILABLE = True
except ImportError:
try:
import mysql.connector
MYSQL_AVAILABLE = True
except ImportError:
MYSQL_AVAILABLE = False
logger = logging.getLogger(__name__)
@dataclass
class Migration:
    """Represents a single database migration.

    The work is supplied either as per-dialect SQL strings or as a Python
    callable; ``requires`` lists migration versions that must be applied first.
    """
    version: str
    name: str
    description: str
    postgres_sql: Optional[str] = None
    mysql_sql: Optional[str] = None
    python_func: Optional[Callable] = None
    # default_factory gives each instance its own list and makes the
    # annotation truthful (previously the default was None despite List[str]).
    requires: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Kept for backward compatibility: callers may still pass requires=None.
        if self.requires is None:
            self.requires = []
class DatabaseMigrationManager:
    """Manages database migrations with support for PostgreSQL and MySQL/MariaDB.

    Migrations are registered with :meth:`register_migration` and tracked in a
    ``schema_migrations`` table so each one is applied at most once. For
    pre-existing installs, :meth:`detect_existing_schema` back-fills the
    tracking table by probing for known tables.
    """

    def __init__(self, db_type: str, connection_params: Dict[str, Any]):
        """Create a manager.

        Args:
            db_type: 'postgresql'/'postgres'/'mariadb'/'mysql' (normalized to
                'postgresql' or 'mysql' internally).
            connection_params: keyword arguments for the driver's connect().

        Raises:
            ValueError: if db_type is not one of the supported values.
        """
        self.db_type = db_type.lower()
        self.connection_params = connection_params
        self.migrations: Dict[str, Migration] = {}
        self._connection = None
        # Validate database type
        if self.db_type not in ['postgresql', 'postgres', 'mariadb', 'mysql']:
            raise ValueError(f"Unsupported database type: {db_type}")
        # Normalize database type to exactly 'postgresql' or 'mysql'
        if self.db_type in ['postgres', 'postgresql']:
            self.db_type = 'postgresql'
        elif self.db_type in ['mysql', 'mariadb']:
            self.db_type = 'mysql'

    def get_connection(self):
        """Get (and cache on the instance) a database connection.

        Raises:
            ImportError: if the required driver module is not installed.
        """
        if self._connection:
            return self._connection
        if self.db_type == 'postgresql':
            if not POSTGRES_AVAILABLE:
                raise ImportError("psycopg not available for PostgreSQL connections")
            self._connection = psycopg.connect(**self.connection_params)
        elif self.db_type == 'mysql':
            if not MYSQL_AVAILABLE:
                raise ImportError("MariaDB/MySQL connector not available for MySQL connections")
            # Use MariaDB connector parameters
            mysql_params = self.connection_params.copy()
            # Convert mysql.connector parameter names to mariadb parameter names
            if 'connection_timeout' in mysql_params:
                mysql_params['connect_timeout'] = mysql_params.pop('connection_timeout')
            if 'charset' in mysql_params:
                mysql_params.pop('charset')  # MariaDB connector doesn't use charset parameter
            if 'collation' in mysql_params:
                mysql_params.pop('collation')  # MariaDB connector doesn't use collation parameter
            self._connection = mysql_connector.connect(**mysql_params)
        return self._connection

    def close_connection(self):
        """Close and forget the cached database connection, if any."""
        if self._connection:
            self._connection.close()
            self._connection = None

    def register_migration(self, migration: Migration):
        """Register a migration to be tracked (keyed by its version string)."""
        self.migrations[migration.version] = migration
        logger.info(f"Registered migration {migration.version}: {migration.name}")

    def create_migration_table(self):
        """Create the migrations tracking table if it doesn't exist.

        Raises:
            Exception: re-raises any DB error after rolling back.
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            if self.db_type == 'postgresql':
                # Quoted identifier keeps the exact case on PostgreSQL.
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS "schema_migrations" (
                        version VARCHAR(255) PRIMARY KEY,
                        name VARCHAR(255) NOT NULL,
                        description TEXT,
                        checksum VARCHAR(64) NOT NULL,
                        applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        execution_time_ms INTEGER
                    )
                """)
            else:  # mysql
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_migrations (
                        version VARCHAR(255) PRIMARY KEY,
                        name VARCHAR(255) NOT NULL,
                        description TEXT,
                        checksum VARCHAR(64) NOT NULL,
                        applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        execution_time_ms INTEGER
                    )
                """)
            conn.commit()
            logger.info("Migration tracking table created/verified")
        except Exception as e:
            logger.error(f"Failed to create migration table: {e}")
            conn.rollback()
            raise
        finally:
            cursor.close()

    def get_applied_migrations(self) -> List[str]:
        """Get list of applied migration versions, in application order.

        Returns an empty list when the tracking table does not exist yet.
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'
            cursor.execute(f"SELECT version FROM {table_name} ORDER BY applied_at")
            return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            # If table doesn't exist, return empty list
            logger.warning(f"Could not get applied migrations: {e}")
            return []
        finally:
            cursor.close()

    def calculate_migration_checksum(self, migration: Migration) -> str:
        """Calculate a SHA-256 checksum over the migration's content.

        Covers the SQL for the active dialect plus the bytecode of the Python
        function (so code edits are detected by validate_migrations()).
        """
        content = ""
        if migration.postgres_sql and self.db_type == 'postgresql':
            content += migration.postgres_sql
        elif migration.mysql_sql and self.db_type == 'mysql':
            content += migration.mysql_sql
        if migration.python_func:
            # NOTE(review): co_code can differ between Python versions, which
            # would flag a checksum mismatch after an interpreter upgrade.
            content += migration.python_func.__code__.co_code.hex()
        return hashlib.sha256(content.encode()).hexdigest()

    def record_migration(self, migration: Migration, execution_time_ms: int):
        """Record a migration as applied (idempotent via ON CONFLICT / IGNORE).

        Raises:
            Exception: re-raises any DB error after rolling back.
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            checksum = self.calculate_migration_checksum(migration)
            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'
            if self.db_type == 'postgresql':
                cursor.execute(f"""
                    INSERT INTO {table_name} (version, name, description, checksum, execution_time_ms)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (version) DO NOTHING
                """, (migration.version, migration.name, migration.description, checksum, execution_time_ms))
            else:  # mysql
                cursor.execute(f"""
                    INSERT IGNORE INTO {table_name} (version, name, description, checksum, execution_time_ms)
                    VALUES (%s, %s, %s, %s, %s)
                """, (migration.version, migration.name, migration.description, checksum, execution_time_ms))
            conn.commit()
            logger.info(f"Recorded migration {migration.version} as applied")
        except Exception as e:
            logger.error(f"Failed to record migration {migration.version}: {e}")
            conn.rollback()
            raise
        finally:
            cursor.close()

    def check_dependencies(self, migration: Migration, applied_migrations: List[str]) -> bool:
        """Return True iff every version in migration.requires is applied."""
        for required_version in migration.requires:
            if required_version not in applied_migrations:
                logger.error(f"Migration {migration.version} requires {required_version} but it's not applied")
                return False
        return True

    def execute_migration(self, migration: Migration) -> bool:
        """Execute a single migration and record it on success.

        Runs dialect-specific SQL first (if any), then the Python function
        (the main path for these migrations). Returns False on any failure.
        """
        start_time = datetime.now()
        conn = self.get_connection()
        try:
            # Choose appropriate SQL based on database type
            sql = None
            if self.db_type == 'postgresql' and migration.postgres_sql:
                sql = migration.postgres_sql
            elif self.db_type == 'mysql' and migration.mysql_sql:
                sql = migration.mysql_sql
            # Execute SQL if available
            if sql:
                cursor = conn.cursor()
                try:
                    # Split and execute multiple statements
                    # NOTE(review): naive ';' splitting breaks on semicolons
                    # inside string literals or procedure bodies — confirm the
                    # registered SQL avoids those.
                    statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()]
                    for statement in statements:
                        cursor.execute(statement)
                    conn.commit()
                    logger.info(f"Executed SQL for migration {migration.version}")
                except Exception as e:
                    conn.rollback()
                    raise
                finally:
                    cursor.close()
            # Execute Python function if available (this is the main path for our migrations)
            if migration.python_func:
                try:
                    migration.python_func(conn, self.db_type)
                    conn.commit()
                    logger.info(f"Executed Python function for migration {migration.version}")
                except Exception as e:
                    conn.rollback()
                    raise
            # Record successful migration
            execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
            self.record_migration(migration, execution_time)
            logger.info(f"Successfully applied migration {migration.version}: {migration.name}")
            return True
        except Exception as e:
            logger.error(f"Failed to execute migration {migration.version}: {e}")
            try:
                conn.rollback()
            except:
                pass  # Connection might already be closed
            return False

    def detect_existing_schema(self) -> List[str]:
        """Detect which migrations have already been applied based on existing schema.

        Used when the tracking table is empty but the database was populated by
        an older installer; probes information_schema for indicator tables.
        Returns the list of detected version strings (empty on any error).
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        applied = []
        try:
            # Check for tables that indicate migrations have been applied
            checks = {
                "001": ['"Users"', '"OIDCProviders"', '"APIKeys"', '"RssKeys"'],
                "002": ['"AppSettings"', '"EmailSettings"'],
                "003": ['"UserStats"', '"UserSettings"'],
                "005": ['"Podcasts"', '"Episodes"', '"YouTubeVideos"'],
                "006": ['"UserEpisodeHistory"', '"UserVideoHistory"'],
                "007": ['"EpisodeQueue"', '"SavedEpisodes"', '"DownloadedEpisodes"']
            }
            for version, tables in checks.items():
                all_exist = True
                for table in tables:
                    if self.db_type == 'postgresql':
                        cursor.execute("""
                            SELECT EXISTS (
                                SELECT FROM information_schema.tables
                                WHERE table_schema = 'public' AND table_name = %s
                            )
                        """, (table.strip('"'),))
                    else:  # mysql
                        # NOTE(review): the MySQL branch passes the name with
                        # its double quotes intact (e.g. '"Users"'), which
                        # would never match information_schema — confirm
                        # whether MySQL deployments rely on this path.
                        cursor.execute("""
                            SELECT COUNT(*)
                            FROM information_schema.tables
                            WHERE table_schema = DATABASE() AND table_name = %s
                        """, (table,))
                    exists = cursor.fetchone()[0]
                    if not exists:
                        all_exist = False
                        break
                if all_exist:
                    applied.append(version)
                    logger.info(f"Detected existing schema for migration {version}")
            # Migration 004 is harder to detect, assume it's applied if 001-003 are
            if "001" in applied and "003" in applied and "004" not in applied:
                # Check if background_tasks user exists
                table_name = '"Users"' if self.db_type == 'postgresql' else 'Users'
                cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE Username = %s", ('background_tasks',))
                if cursor.fetchone()[0] > 0:
                    applied.append("004")
                    logger.info("Detected existing schema for migration 004")
            # Check for gpodder tables - if ANY exist, ALL gpodder migrations are applied
            # (since they were created by the Go gpodder-api service and haven't changed)
            gpodder_indicator_tables = ['"GpodderSyncMigrations"', '"GpodderSyncDeviceState"',
                                        '"GpodderSyncSubscriptions"', '"GpodderSyncSettings"',
                                        '"GpodderSessions"', '"GpodderSyncState"']
            gpodder_migration_versions = ["100", "101", "102", "103", "104"]
            gpodder_tables_exist = False
            for table in gpodder_indicator_tables:
                table_name = table.strip('"')
                if self.db_type == 'postgresql':
                    cursor.execute("""
                        SELECT EXISTS (
                            SELECT FROM information_schema.tables
                            WHERE table_schema = 'public' AND table_name = %s
                        )
                    """, (table_name,))
                else:  # mysql
                    cursor.execute("""
                        SELECT COUNT(*)
                        FROM information_schema.tables
                        WHERE table_schema = DATABASE() AND table_name = %s
                    """, (table_name,))
                if cursor.fetchone()[0]:
                    gpodder_tables_exist = True
                    break
            if gpodder_tables_exist:
                for version in gpodder_migration_versions:
                    if version not in applied:
                        applied.append(version)
                        logger.info(f"Detected existing gpodder tables, marking migration {version} as applied")
            # Check for PeopleEpisodes_backup table separately (migration 104)
            backup_table = "PeopleEpisodes_backup"
            if self.db_type == 'postgresql':
                cursor.execute("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_schema = 'public' AND table_name = %s
                    )
                """, (backup_table,))
            else:  # mysql
                cursor.execute("""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = DATABASE() AND table_name = %s
                """, (backup_table,))
            if cursor.fetchone()[0] and "104" not in applied:
                applied.append("104")
                logger.info("Detected existing PeopleEpisodes_backup table, marking migration 104 as applied")
            return applied
        except Exception as e:
            logger.warning(f"Error detecting existing schema: {e}")
            return []
        finally:
            cursor.close()

    def run_migrations(self, target_version: Optional[str] = None) -> bool:
        """Run all pending migrations up to target version.

        Args:
            target_version: when given, versions greater than this (string
                comparison on version keys) are skipped.

        Returns:
            True when everything applied cleanly; False on the first failure.
        """
        try:
            # Create migration table
            self.create_migration_table()
            # Get applied migrations
            applied_migrations = self.get_applied_migrations()
            logger.info(f"Found {len(applied_migrations)} applied migrations")
            # If no migrations are recorded but we have existing schema, detect what's there
            if not applied_migrations:
                detected_migrations = self.detect_existing_schema()
                if detected_migrations:
                    logger.info(f"Detected existing schema, marking {len(detected_migrations)} migrations as applied")
                    # Record detected migrations without executing them
                    for version in detected_migrations:
                        if version in self.migrations:
                            migration = self.migrations[version]
                            self.record_migration(migration, 0)  # 0ms execution time for pre-existing
                    # Refresh applied migrations list
                    applied_migrations = self.get_applied_migrations()
            # Sort migrations by version
            pending_migrations = []
            for version, migration in sorted(self.migrations.items()):
                if version not in applied_migrations:
                    if target_version and version > target_version:
                        continue
                    pending_migrations.append(migration)
            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True
            logger.info(f"Found {len(pending_migrations)} pending migrations")
            # Execute pending migrations
            for migration in pending_migrations:
                # Check dependencies
                if not self.check_dependencies(migration, applied_migrations):
                    logger.error(f"Dependency check failed for migration {migration.version}")
                    return False
                # Execute migration
                if not self.execute_migration(migration):
                    logger.error(f"Failed to execute migration {migration.version}")
                    return False
                # Add to applied list for dependency checking
                applied_migrations.append(migration.version)
            logger.info("All migrations applied successfully")
            return True
        except Exception as e:
            logger.error(f"Migration run failed: {e}")
            return False
        finally:
            self.close_connection()

    def validate_migrations(self) -> bool:
        """Validate that applied migrations' checksums still match their code.

        Returns True when every applied version's stored checksum equals the
        freshly computed one; logs and returns False on any mismatch or error.
        """
        try:
            conn = self.get_connection()
            cursor = conn.cursor()
            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'
            cursor.execute(f"SELECT version, checksum FROM {table_name}")
            applied_checksums = dict(cursor.fetchall())
            validation_errors = []
            for version, stored_checksum in applied_checksums.items():
                if version in self.migrations:
                    current_checksum = self.calculate_migration_checksum(self.migrations[version])
                    if current_checksum != stored_checksum:
                        validation_errors.append(f"Migration {version} checksum mismatch")
            if validation_errors:
                for error in validation_errors:
                    logger.error(error)
                return False
            logger.info("All migration checksums validated successfully")
            return True
        except Exception as e:
            logger.error(f"Migration validation failed: {e}")
            return False
        finally:
            # NOTE(review): if get_connection() itself raises, `cursor` is
            # unbound here and this finally raises NameError — confirm and
            # guard if that path is reachable.
            cursor.close()
# Migration manager instance (singleton pattern)
_migration_manager: Optional[DatabaseMigrationManager] = None


def get_migration_manager() -> DatabaseMigrationManager:
    """Return the process-wide migration manager, creating it on first use.

    Connection settings are read from DB_* environment variables; the
    defaults target a local PostgreSQL instance.
    """
    global _migration_manager
    if _migration_manager is not None:
        return _migration_manager

    env = os.environ.get
    db_type = env("DB_TYPE", "postgresql")
    if db_type.lower() in ('postgresql', 'postgres'):
        connection_params = {
            'host': env("DB_HOST", "127.0.0.1"),
            'port': int(env("DB_PORT", "5432")),
            'user': env("DB_USER", "postgres"),
            'password': env("DB_PASSWORD", "password"),
            'dbname': env("DB_NAME", "pinepods_database"),
        }
    else:  # mysql/mariadb
        connection_params = {
            'host': env("DB_HOST", "127.0.0.1"),
            'port': int(env("DB_PORT", "3306")),
            'user': env("DB_USER", "root"),
            'password': env("DB_PASSWORD", "password"),
            'database': env("DB_NAME", "pinepods_database"),
            'charset': 'utf8mb4',
            'collation': 'utf8mb4_general_ci',
        }
    _migration_manager = DatabaseMigrationManager(db_type, connection_params)
    return _migration_manager
def register_migration(version: str, name: str, description: str, **kwargs):
    """Decorator that wraps a function as a Migration and registers it.

    Extra keyword arguments (e.g. requires=[...]) are forwarded to the
    Migration constructor. The decorated function is returned unchanged.
    """
    def decorator(func):
        get_migration_manager().register_migration(
            Migration(
                version=version,
                name=name,
                description=description,
                python_func=func,
                **kwargs,
            )
        )
        return func
    return decorator
def run_all_migrations() -> bool:
    """Run every registered migration through the shared manager."""
    return get_migration_manager().run_migrations()

View File

@@ -0,0 +1,795 @@
# tasks.py - Define Celery tasks with Valkey as broker
from celery import Celery
import time
import os
import asyncio
import datetime
import requests
from threading import Thread
import json
import sys
import logging
from typing import Dict, Any, Optional, List
# Make sure pinepods is in the Python path
sys.path.append('/pinepods')
# Database backend selected at import time; defaults to MariaDB.
database_type = str(os.getenv('DB_TYPE', 'mariadb'))


class Web_Key:
    # Holder for the application web key; callers typically check .web_key
    # before triggering a refetch (get_web_key always queries the database).
    def __init__(self):
        # Cached key value; None until get_web_key() is called.
        self.web_key = None

    def get_web_key(self, cnx):
        """Fetch the web key from the database, cache it, and return it."""
        # Import only when needed to avoid circular imports
        from database_functions.functions import get_web_key as get_key
        self.web_key = get_key(cnx, database_type)
        return self.web_key


base_webkey = Web_Key()
# Set up logging
logger = logging.getLogger("celery_tasks")
# Import the WebSocket manager directly from clientapi
try:
from clients.clientapi import manager as websocket_manager
print("Successfully imported WebSocket manager from clientapi")
except ImportError as e:
logger.error(f"Failed to import WebSocket manager: {e}")
websocket_manager = None
# Create a dedicated event loop thread for async operations
_event_loop = None
_event_loop_thread = None
def start_background_loop():
global _event_loop, _event_loop_thread
# Only start if not already running
if _event_loop is None:
# Function to run event loop in background thread
def run_event_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
global _event_loop
_event_loop = loop
loop.run_forever()
# Start the background thread
_event_loop_thread = Thread(target=run_event_loop, daemon=True)
_event_loop_thread.start()
# Wait a moment for the loop to start
time.sleep(0.1)
print("Started background event loop for WebSocket broadcasts")
# Start the event loop when this module is imported
start_background_loop()
# Use the existing Valkey connection for Celery.
# Valkey speaks the Redis protocol, so Celery's redis:// transport works as-is.
valkey_host = os.environ.get("VALKEY_HOST", "localhost")
valkey_port = os.environ.get("VALKEY_PORT", "6379")
broker_url = f"redis://{valkey_host}:{valkey_port}/0"
backend_url = f"redis://{valkey_host}:{valkey_port}/0"

# Initialize Celery with Valkey as broker and result backend
celery_app = Celery('pinepods',
                    broker=broker_url,
                    backend=backend_url)

# Configure Celery for best performance with downloads
celery_app.conf.update(
    worker_concurrency=3,          # Limit to 3 concurrent downloads per worker
    task_acks_late=True,           # Only acknowledge tasks after they're done
    task_time_limit=1800,          # 30 minutes hard time limit
    task_soft_time_limit=1500,     # 25 minutes soft time limit
    worker_prefetch_multiplier=1,  # Don't prefetch more tasks than workers
)
# Task status tracking in Valkey for all types of tasks
class TaskManager:
    """Tracks Celery task state in Valkey and pushes updates to clients.

    Task records are stored as JSON under ``task:<task_id>`` (24h TTL, reduced
    to 1h after completion) and each user's active task ids under
    ``user_tasks:<user_id>`` (7-day TTL).
    """

    def __init__(self):
        # Imported lazily so importing this module doesn't require a live
        # Valkey client at import time.
        from database_functions.valkey_client import valkey_client
        self.valkey_client = valkey_client

    def register_task(self, task_id: str, task_type: str, user_id: int, item_id: Optional[int] = None,
                      details: Optional[Dict[str, Any]] = None):
        """Register any Celery task for tracking.

        Creates the task record in PENDING state, links it to the user's
        active-task list, and attempts a best-effort broadcast.
        """
        task_data = {
            "task_id": task_id,
            "user_id": user_id,
            "type": task_type,
            "item_id": item_id,
            "progress": 0.0,
            "status": "PENDING",
            "details": details or {},
            "started_at": datetime.datetime.now().isoformat()
        }
        self.valkey_client.set(f"task:{task_id}", json.dumps(task_data))
        # Set TTL for 24 hours
        self.valkey_client.expire(f"task:{task_id}", 86400)
        # Add to user's active tasks list
        self._add_to_user_tasks(user_id, task_id)
        # Try to broadcast the update if the WebSocket module is available;
        # broadcast failures never fail the task itself.
        try:
            self._broadcast_update(task_id)
        except Exception as e:
            logger.error(f"Error broadcasting task update: {e}")

    def update_task(self, task_id: str, progress: float = None, status: str = None,
                    details: Dict[str, Any] = None):
        """Update any task's status and progress.

        Only the fields actually passed are changed; ``details`` is merged
        into the existing details dict. Unknown/expired task ids are ignored.
        """
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            if progress is not None:
                task["progress"] = progress
            if status:
                task["status"] = status
            if details:
                if "details" not in task:
                    task["details"] = {}
                task["details"].update(details)
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Try to broadcast the update (best effort)
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")

    def complete_task(self, task_id: str, success: bool = True, result: Any = None):
        """Mark any task as complete (SUCCESS) or failed (FAILED)."""
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            task["progress"] = 100.0 if success else 0.0
            task["status"] = "SUCCESS" if success else "FAILED"
            task["completed_at"] = datetime.datetime.now().isoformat()
            if result is not None:
                task["result"] = result
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Try to broadcast the final update (best effort)
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")
            # Keep completed tasks for 1 hour before expiring
            self.valkey_client.expire(f"task:{task_id}", 3600)
            # Remove from user's active tasks list after completion.
            # NOTE(review): failed tasks stay in the active list — presumably
            # so the UI can still show them; confirm that is intended.
            if success:
                self._remove_from_user_tasks(task.get("user_id"), task_id)

    def get_task(self, task_id: str) -> Dict[str, Any]:
        """Get any task's details; empty dict when unknown or expired."""
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            return json.loads(task_json)
        return {}

    def get_user_tasks(self, user_id: int) -> List[Dict[str, Any]]:
        """Get all active tasks for a user (all types).

        Expired task records are skipped (their ids may still be listed).
        """
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        result = []
        if tasks_list_json:
            task_ids = json.loads(tasks_list_json)
            for task_id in task_ids:
                task_info = self.get_task(task_id)
                if task_info:
                    result.append(task_info)
        return result

    def _add_to_user_tasks(self, user_id: int, task_id: str):
        """Add a task to the user's active tasks list (deduplicated)."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id not in tasks_list:
                tasks_list.append(task_id)
        else:
            tasks_list = [task_id]
        self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))
        # Set TTL for 7 days
        self.valkey_client.expire(f"user_tasks:{user_id}", 604800)

    def _remove_from_user_tasks(self, user_id: int, task_id: str):
        """Remove a task from the user's active tasks list, if present."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id in tasks_list:
                tasks_list.remove(task_id)
                self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))

    # Modified _broadcast_update method to avoid circular imports
    def _broadcast_update(self, task_id: str):
        """Broadcast a task update to the user's connected clients.

        Opens its own short-lived DB connection, resolves the internal web
        key, then delegates to the websocket broadcaster. All failures are
        printed and swallowed — broadcasting is strictly best-effort.
        """
        # Get task info
        task_info = self.get_task(task_id)
        if not task_info or "user_id" not in task_info:
            return
        user_id = task_info["user_id"]
        cnx = None
        try:
            cnx = get_direct_db_connection()
            # Import broadcaster - delay import to avoid circular dependency
            sys.path.insert(0, '/pinepods/database_functions')
            try:
                from websocket_broadcaster import broadcaster
            except ImportError:
                try:
                    from database_functions.websocket_broadcaster import broadcaster
                except ImportError as e:
                    print(f"Cannot import broadcaster from any location: {e}")
                    return
            # Get web key
            web_key = None
            try:
                # Get web key using class method to avoid direct import
                if not base_webkey.web_key:
                    base_webkey.get_web_key(cnx)
                web_key = base_webkey.web_key
            except Exception as e:
                print(f"Error getting web key: {str(e)}")
                # Fallback to a direct approach if needed
                try:
                    from database_functions.functions import get_web_key
                    web_key = get_web_key(cnx, database_type)
                except Exception as e2:
                    print(f"Fallback web key retrieval failed: {str(e2)}")
                    return
            # Progress and status details for better debugging
            progress = task_info.get("progress", 0)
            status = task_info.get("status", "unknown")
            print(f"Broadcasting task update for user {user_id}, task {task_id}, progress: {progress}, status: {status}")
            # Broadcast the update
            result = broadcaster.broadcast_task_update(user_id, task_info, web_key)
            if result:
                print(f"Successfully broadcast task update for task {task_id}, progress: {progress}%")
            else:
                print(f"Failed to broadcast task update for task {task_id}, progress: {progress}%")
        except Exception as e:
            print(f"Error in task broadcast setup: {str(e)}")
        finally:
            if cnx:
                # Close direct connection
                close_direct_db_connection(cnx)
# Initialize a general task manager (module-level singleton; connects to
# Valkey via the client imported in TaskManager.__init__).
task_manager = TaskManager()
# For backwards compatibility, keep the download_manager name too
download_manager = task_manager


# Function to get all active tasks including both downloads and other task types
def get_all_active_tasks(user_id: int) -> List[Dict[str, Any]]:
    """Get all active tasks for a user (all types)."""
    return task_manager.get_user_tasks(user_id)
# ----------------------
# IMPROVED CONNECTION HANDLING
# ----------------------
def get_direct_db_connection():
    """
    Create a direct database connection instead of using the pool.
    This is more reliable for Celery workers to avoid pool exhaustion.

    Returns:
        A live psycopg or MariaDB/MySQL connection, selected by the
        module-level ``database_type``.
    """
    db_host = os.environ.get("DB_HOST", "127.0.0.1")
    # Bug fix: coerce the port to int — psycopg tolerates a string inside the
    # conninfo, but the mariadb connector's port argument must be an integer.
    db_port = int(os.environ.get("DB_PORT", "3306"))
    db_user = os.environ.get("DB_USER", "root")
    db_password = os.environ.get("DB_PASSWORD", "password")
    db_name = os.environ.get("DB_NAME", "pypods_database")
    print(f"Creating direct database connection for task")
    if database_type == "postgresql":
        import psycopg
        conninfo = f"host={db_host} port={db_port} user={db_user} password={db_password} dbname={db_name}"
        return psycopg.connect(conninfo)
    else:  # Default to MariaDB/MySQL
        try:
            import mariadb as mysql_connector
        except ImportError:
            # Bug fix: the fallback previously did `import mysql.connector`,
            # which binds the name `mysql` — the connect() call below would
            # then raise NameError. Alias it explicitly.
            import mysql.connector as mysql_connector
        return mysql_connector.connect(
            host=db_host,
            port=db_port,
            user=db_user,
            password=db_password,
            database=db_name
        )
def close_direct_db_connection(cnx):
    """Close a direct database connection, logging (not raising) any error."""
    if not cnx:
        return
    try:
        cnx.close()
        print("Direct database connection closed")
    except Exception as e:
        print(f"Error closing direct connection: {str(e)}")
# Minimal changes to download_podcast_task that should work right away
@celery_app.task(bind=True, max_retries=3)
def download_podcast_task(self, episode_id: int, user_id: int, database_type: str):
    """
    Celery task to download a podcast episode.
    Uses retries with exponential backoff for handling transient failures.

    Args:
        episode_id: primary key of the episode to download.
        user_id: requesting user's id (used for task tracking/broadcasts).
        database_type: 'postgresql' or MySQL/MariaDB; selects the SQL dialect.

    Returns:
        bool: result of download_podcast() — True on success.
    """
    task_id = self.request.id
    print(f"DOWNLOAD TASK STARTED: ID={task_id}, Episode={episode_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the episode title and podcast name
        if database_type == "postgresql":
            # First try to get both the episode title and podcast name
            query = '''
                SELECT e."episodetitle", p."podcastname"
                FROM "Episodes" e
                JOIN "Podcasts" p ON e."podcastid" = p."podcastid"
                WHERE e."episodeid" = %s
            '''
        else:
            query = '''
                SELECT e.EpisodeTitle, p.PodcastName
                FROM Episodes e
                JOIN Podcasts p ON e.PodcastID = p.PodcastID
                WHERE e.EpisodeID = %s
            '''
        cursor.execute(query, (episode_id,))
        result = cursor.fetchone()
        cursor.close()
        # Extract episode title and podcast name; the row may arrive as a dict
        # or a tuple depending on the driver/cursor configuration.
        episode_title = None
        podcast_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "episodetitle" in result:  # PostgreSQL lowercase
                    episode_title = result["episodetitle"]
                    podcast_name = result.get("podcastname")
                else:  # MariaDB uppercase
                    episode_title = result["EpisodeTitle"]
                    podcast_name = result.get("PodcastName")
            else:
                # Tuple result
                episode_title = result[0] if len(result) > 0 else None
                podcast_name = result[1] if len(result) > 1 else None
        # Format a nice display title (also guards against the literal
        # string "None" stored as a title)
        display_title = "Unknown Episode"
        if episode_title and episode_title != "None" and episode_title.strip():
            display_title = episode_title
        elif podcast_name:
            display_title = f"{podcast_name} - Episode"
        else:
            display_title = f"Episode #{episode_id}"
        print(f"Using display title for episode {episode_id}: {display_title}")
        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="podcast_download",
            user_id=user_id,
            item_id=episode_id,
            details={
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )
        # Define a progress callback with the display title; download_podcast
        # invokes it to push progress/status into the task manager.
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": status_message
            })
        # Close the connection used for title lookup
        close_direct_db_connection(cnx)
        # Get a fresh connection for the download
        cnx = get_direct_db_connection()
        # Import the download function (deferred to avoid circular imports)
        from database_functions.functions import download_podcast
        print(f"Starting download for episode: {episode_id} ({display_title}), user: {user_id}, task: {task_id}")
        # Execute the download with progress reporting
        success = download_podcast(
            cnx,
            database_type,
            episode_id,
            user_id,
            task_id,
            progress_callback=progress_callback
        )
        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading podcast {episode_id}: {str(exc)}")
        # Mark task as failed.
        # NOTE(review): the task is marked FAILED even though a retry is
        # scheduled below; a retry run re-registers the same task id, so the
        # status flaps FAILED -> PENDING. Confirm clients tolerate this.
        task_manager.complete_task(
            task_id,
            False,
            {
                "episode_id": episode_id,
                "episode_title": f"Episode #{episode_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        # NOTE(review): 5 * 2**retries actually yields 5s/10s/20s, not the
        # 5/25/125 the comment above claims — confirm which was intended.
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
@celery_app.task(bind=True, max_retries=3)
def download_youtube_video_task(self, video_id: int, user_id: int, database_type: str):
    """
    Celery task to download a YouTube video.
    Uses retries with exponential backoff for handling transient failures.

    Args:
        video_id: ID of the YouTubeVideos row to download.
        user_id: ID of the user requesting the download.
        database_type: "postgresql" selects the quoted lowercase schema;
            any other value uses the MariaDB/MySQL unquoted schema.

    Returns:
        The boolean result of download_youtube_video. On exception the task
        is marked failed in the task manager and retried (max 3 attempts,
        5s/25s/125s backoff).
    """
    task_id = self.request.id
    print(f"YOUTUBE DOWNLOAD TASK STARTED: ID={task_id}, Video={video_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the video title and channel name
        # NOTE(review): if execute/fetchone raises, this cursor is never
        # closed (only the connection is, in finally) — consider a finally.
        if database_type == "postgresql":
            # First try to get both the video title and channel name
            query = '''
                SELECT v."videotitle", p."podcastname"
                FROM "YouTubeVideos" v
                JOIN "Podcasts" p ON v."podcastid" = p."podcastid"
                WHERE v."videoid" = %s
            '''
        else:
            query = '''
                SELECT v.VideoTitle, p.PodcastName
                FROM YouTubeVideos v
                JOIN Podcasts p ON v.PodcastID = p.PodcastID
                WHERE v.VideoID = %s
            '''
        cursor.execute(query, (video_id,))
        result = cursor.fetchone()
        cursor.close()
        # Extract video title and channel name. The cursor may return either
        # a dict (dictionary cursor) or a tuple depending on the driver config.
        video_title = None
        channel_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "videotitle" in result:  # PostgreSQL lowercase
                    video_title = result["videotitle"]
                    channel_name = result.get("podcastname")
                else:  # MariaDB uppercase
                    video_title = result["VideoTitle"]
                    channel_name = result.get("PodcastName")
            else:
                # Tuple result
                video_title = result[0] if len(result) > 0 else None
                channel_name = result[1] if len(result) > 1 else None
        # Format a nice display title, falling back from title -> channel -> ID.
        # The "None" string check guards against a stringified NULL in the DB.
        display_title = "Unknown Video"
        if video_title and video_title != "None" and video_title.strip():
            display_title = video_title
        elif channel_name:
            display_title = f"{channel_name} - Video"
        else:
            display_title = f"YouTube Video #{video_id}"
        print(f"Using display title for video {video_id}: {display_title}")
        # Register task with more details so the UI can show progress text
        task_manager.register_task(
            task_id=task_id,
            task_type="youtube_download",
            user_id=user_id,
            item_id=video_id,
            details={
                "item_id": video_id,
                "item_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )
        # Close the connection used for title lookup
        close_direct_db_connection(cnx)
        # Get a fresh connection for the download (kept open for its duration)
        cnx = get_direct_db_connection()
        # Import the download function
        from database_functions.functions import download_youtube_video
        print(f"Starting download for YouTube video: {video_id} ({display_title}), user: {user_id}, task: {task_id}")
        # Define a progress callback with the display title captured in closure
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": status_message
            })
        # Check if the download_youtube_video function accepts progress_callback
        # parameter — keeps compatibility with older implementations.
        import inspect
        try:
            signature = inspect.signature(download_youtube_video)
            has_progress_callback = 'progress_callback' in signature.parameters
        except (TypeError, ValueError):
            has_progress_callback = False
        # Execute the download with progress callback if supported, otherwise without it
        if has_progress_callback:
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id,
                progress_callback=progress_callback
            )
        else:
            # Call without the progress_callback parameter
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id
            )
            # Since we can't use progress callbacks directly, manually update progress after completion
            task_manager.update_task(task_id, 100 if success else 0,
                                   "SUCCESS" if success else "FAILED",
                                   {
                                       "item_id": video_id,
                                       "item_title": display_title,
                                       "status_text": f"{'Download complete' if success else 'Download failed'}"
                                   })
        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading YouTube video {video_id}: {str(exc)}")
        # Mark task as failed but include video title in the details
        # NOTE(review): the task is marked failed even though retry() below
        # will schedule another attempt — confirm the task manager tolerates
        # re-registration on retry.
        task_manager.complete_task(
            task_id,
            False,
            {
                "item_id": video_id,
                "item_title": f"YouTube Video #{video_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
@celery_app.task
def queue_podcast_downloads(podcast_id: int, user_id: int, database_type: str, is_youtube: bool = False):
    """
    Queue an individual download task for every episode/video of a podcast.

    Items are dispatched in small batches with a short pause between batches
    so a large subscription does not flood the worker queue. Items already
    downloaded for this user are skipped.
    """
    cnx = None
    try:
        cnx = get_direct_db_connection()
        from database_functions.functions import (
            get_episode_ids_for_podcast,
            get_video_ids_for_podcast,
            check_downloaded
        )
        # Fetch the work list; the episode helper may return dicts or bare IDs.
        if is_youtube:
            items = get_video_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(items)} YouTube videos for download")
        else:
            items = get_episode_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(items)} podcast episodes for download")
        batch_size = 5
        for start in range(0, len(items), batch_size):
            for entry in items[start:start + batch_size]:
                # Accept both {"id": ...} dicts and plain IDs.
                entry_id = entry["id"] if isinstance(entry, dict) and "id" in entry else entry
                if check_downloaded(cnx, database_type, user_id, entry_id, is_youtube):
                    continue
                # Dispatch the type-appropriate task; it looks up the title itself.
                if is_youtube:
                    download_youtube_video_task.delay(entry_id, user_id, database_type)
                else:
                    download_podcast_task.delay(entry_id, user_id, database_type)
            # Brief pause between batches to avoid overwhelming the system.
            if start + batch_size < len(items):
                time.sleep(2)
        return f"Queued {len(items)} items for download"
    finally:
        if cnx:
            close_direct_db_connection(cnx)
# Helper task to clean up old download records
@celery_app.task
def cleanup_old_downloads():
    """
    Periodic task to clean up old download records from Valkey.

    Currently a stub: record expiry is delegated to Redis/Valkey TTL
    mechanisms, so this only imports the client and logs that it ran.
    """
    from database_functions.valkey_client import valkey_client
    # A full implementation would SCAN for stale keys; TTLs cover expiry for now.
    print("Running download cleanup task")
# Example task for refreshing podcast feeds
@celery_app.task(bind=True, max_retries=2)
def refresh_feed_task(self, user_id: int, database_type: str):
    """
    Celery task to refresh podcast feeds for a user.

    Registers itself with the task manager, reports incremental progress,
    and on failure marks the task failed and retries after 30s (max 2 retries).

    NOTE(review): the refresh loop below is still a simulation placeholder
    (fixed count of 10 + sleep); replace with the real feed-refresh logic.

    Fix: removed the redundant inner ``try/except Exception as e: raise e``
    wrapper — it only re-raised into the outer handler and added nothing.
    """
    task_id = self.request.id
    cnx = None
    try:
        # Register task so progress is visible to the UI.
        task_manager.register_task(
            task_id=task_id,
            task_type="feed_refresh",
            user_id=user_id,
            details={"description": "Refreshing podcast feeds"}
        )
        # Get a direct database connection
        cnx = get_direct_db_connection()
        task_manager.update_task(task_id, 10, "PROGRESS", {"status_text": "Fetching podcast list"})
        # Simulated feed refresh process with progress updates —
        # replace with the actual implementation.
        total_podcasts = 10  # Example count
        for i in range(total_podcasts):
            # Update progress for each podcast (percentage 0-100).
            progress = (i + 1) / total_podcasts * 100
            task_manager.update_task(
                task_id,
                progress,
                "PROGRESS",
                {"status_text": f"Refreshing podcast {i+1}/{total_podcasts}"}
            )
            # Simulated work - replace with actual refresh logic
            time.sleep(0.5)
        # Complete the task
        task_manager.complete_task(task_id, True, {"refreshed_count": total_podcasts})
        return True
    except Exception as exc:
        print(f"Error refreshing feeds for user {user_id}: {str(exc)}")
        task_manager.complete_task(task_id, False, {"error": str(exc)})
        # retry() raises celery.exceptions.Retry, ending this invocation.
        self.retry(exc=exc, countdown=30)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
# Simple debug task
@celery_app.task
def debug_task(x, y):
    """Add two values and log the computation, to verify Celery wiring works."""
    total = x + y
    print(f"CELERY DEBUG TASK EXECUTED: {x} + {y} = {total}")
    return total

View File

@@ -0,0 +1,778 @@
#!/usr/bin/env python3
"""
Database Validator for PinePods
This script validates that an existing database matches the expected schema
by using the migration system as the source of truth.
Usage:
python validate_database.py --db-type mysql --db-host localhost --db-port 3306 --db-user root --db-password pass --db-name pinepods_database
python validate_database.py --db-type postgresql --db-host localhost --db-port 5432 --db-user postgres --db-password pass --db-name pinepods_database
"""
import argparse
import sys
import os
import tempfile
import logging
from typing import Dict, List, Set, Tuple, Any, Optional
from dataclasses import dataclass
import importlib.util
# Add the parent directory to path so we can import database_functions
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, parent_dir)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
import mysql.connector
MYSQL_AVAILABLE = True
except ImportError:
MYSQL_AVAILABLE = False
try:
import psycopg
POSTGRESQL_AVAILABLE = True
except ImportError:
POSTGRESQL_AVAILABLE = False
from database_functions.migrations import get_migration_manager
@dataclass
class TableInfo:
    """Schema snapshot of one table: its columns, indexes and constraints.

    Produced by the DatabaseInspector subclasses; the exact metadata keys
    inside each inner dict depend on the backend inspector that built it.
    """
    # Table name as reported by the database.
    name: str
    # Column name -> metadata dict (type, nullability, default, ...).
    columns: Dict[str, Dict[str, Any]]
    # Index name -> metadata dict (member columns, uniqueness, ...).
    indexes: Dict[str, Dict[str, Any]]
    # Constraint name -> metadata dict (type, referenced table/columns).
    constraints: Dict[str, Dict[str, Any]]
@dataclass
class ValidationResult:
    """Outcome of comparing the live schema against the migration-built one.

    Only missing tables and missing columns are treated as critical by
    ``_compare_schemas``; all other differences are reported as warnings.
    """
    # True when no critical issues (missing tables/columns) were found.
    is_valid: bool
    # Tables expected by migrations but absent from the live database.
    missing_tables: List[str]
    # Tables present in the live database but not produced by migrations.
    extra_tables: List[str]
    # Reserved for table-level diffs; currently never populated by _compare_schemas.
    table_differences: Dict[str, Dict[str, Any]]
    missing_indexes: List[Tuple[str, str]]  # (table, index)
    extra_indexes: List[Tuple[str, str]]  # (table, index)
    missing_constraints: List[Tuple[str, str]]  # (table, constraint)
    extra_constraints: List[Tuple[str, str]]  # (table, constraint)
    column_differences: Dict[str, Dict[str, Dict[str, Any]]]  # table -> column -> differences
class DatabaseInspector:
    """Abstract helper that extracts schema metadata from a live connection."""
    def __init__(self, connection):
        # DB-API connection owned by the caller; never closed here.
        self.connection = connection
    def get_tables(self) -> Set[str]:
        """Return the names of every user table; implemented by subclasses."""
        raise NotImplementedError
    def get_table_info(self, table_name: str) -> TableInfo:
        """Return column/index/constraint details for one table; subclass hook."""
        raise NotImplementedError
    def get_all_table_info(self) -> Dict[str, TableInfo]:
        """Return a mapping of table name -> TableInfo for every table."""
        return {name: self.get_table_info(name) for name in self.get_tables()}
class MySQLInspector(DatabaseInspector):
    """Schema inspector for MySQL/MariaDB connections."""
    def get_tables(self) -> Set[str]:
        """Return every table name in the currently selected database."""
        cur = self.connection.cursor()
        cur.execute("SHOW TABLES")
        names = {record[0] for record in cur.fetchall()}
        cur.close()
        return names
    def get_table_info(self, table_name: str) -> TableInfo:
        """Collect columns, indexes and foreign keys for *table_name*."""
        cur = self.connection.cursor(dictionary=True)
        # Columns: DESCRIBE yields one row per column with
        # Field/Type/Null/Key/Default/Extra keys.
        cur.execute(f"DESCRIBE `{table_name}`")
        columns = {
            record['Field']: {
                'type': record['Type'],
                'null': record['Null'],
                'key': record['Key'],
                'default': record['Default'],
                'extra': record['Extra']
            }
            for record in cur.fetchall()
        }
        # Indexes: SHOW INDEX emits one row per (index, column); fold the
        # rows into a single entry per index, accumulating its column list.
        cur.execute(f"SHOW INDEX FROM `{table_name}`")
        indexes = {}
        for record in cur.fetchall():
            entry = indexes.setdefault(record['Key_name'], {
                'columns': [],
                'unique': not record['Non_unique'],
                'type': record['Index_type']
            })
            entry['columns'].append(record['Column_name'])
        # Constraints: only foreign keys are collected (the
        # REFERENCED_TABLE_NAME filter excludes PK/UNIQUE rows).
        cur.execute(f"""
            SELECT kcu.CONSTRAINT_NAME, tc.CONSTRAINT_TYPE, kcu.COLUMN_NAME,
                   kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
            JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
                ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
                AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA
            WHERE kcu.TABLE_SCHEMA = DATABASE() AND kcu.TABLE_NAME = %s
            AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
        """, (table_name,))
        constraints = {
            record['CONSTRAINT_NAME']: {
                'type': 'FOREIGN KEY',
                'column': record['COLUMN_NAME'],
                'referenced_table': record['REFERENCED_TABLE_NAME'],
                'referenced_column': record['REFERENCED_COLUMN_NAME']
            }
            for record in cur.fetchall()
        }
        cur.close()
        return TableInfo(table_name, columns, indexes, constraints)
class PostgreSQLInspector(DatabaseInspector):
    """PostgreSQL database inspector (reads information_schema and pg_catalog)."""
    def get_tables(self) -> Set[str]:
        """Return all base-table names in the public schema."""
        cursor = self.connection.cursor()
        cursor.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
        """)
        tables = {row[0] for row in cursor.fetchall()}
        cursor.close()
        return tables
    def get_table_info(self, table_name: str) -> TableInfo:
        """Collect columns, indexes and constraints for *table_name*."""
        cursor = self.connection.cursor()
        # Get column information; length/precision/scale are folded into a
        # display type string like "numeric(10,2)" below.
        cursor.execute("""
            SELECT column_name, data_type, is_nullable, column_default,
                   character_maximum_length, numeric_precision, numeric_scale
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position
        """, (table_name,))
        columns = {}
        for row in cursor.fetchall():
            col_name, data_type, is_nullable, default, max_length, precision, scale = row
            type_str = data_type
            # Prefer char length; otherwise numeric precision (with scale if set).
            if max_length:
                type_str += f"({max_length})"
            elif precision:
                if scale:
                    type_str += f"({precision},{scale})"
                else:
                    type_str += f"({precision})"
            columns[col_name] = {
                'type': type_str,
                'null': is_nullable,
                'default': default,
                'max_length': max_length,
                'precision': precision,
                'scale': scale
            }
        # Get index information from pg_catalog; unnest WITH ORDINALITY keeps
        # the indexed columns in key order within the aggregated array.
        cursor.execute("""
            SELECT i.relname as index_name,
                   array_agg(a.attname ORDER BY c.ordinality) as columns,
                   ix.indisunique as is_unique,
                   ix.indisprimary as is_primary
            FROM pg_class t
            JOIN pg_index ix ON t.oid = ix.indrelid
            JOIN pg_class i ON i.oid = ix.indexrelid
            JOIN unnest(ix.indkey) WITH ORDINALITY c(colnum, ordinality) ON true
            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = c.colnum
            WHERE t.relname = %s AND t.relkind = 'r'
            GROUP BY i.relname, ix.indisunique, ix.indisprimary
        """, (table_name,))
        indexes = {}
        for row in cursor.fetchall():
            index_name, columns_list, is_unique, is_primary = row
            indexes[index_name] = {
                'columns': columns_list,
                'unique': is_unique,
                'primary': is_primary
            }
        # Get constraint information; contype is pg's single-letter code
        # ('p' primary, 'f' foreign, 'u' unique, 'c' check). LEFT JOINs leave
        # referenced_table/columns NULL for non-FK constraints.
        cursor.execute("""
            SELECT con.conname as constraint_name,
                   con.contype as constraint_type,
                   array_agg(att.attname) as columns,
                   cl.relname as referenced_table,
                   array_agg(att2.attname) as referenced_columns
            FROM pg_constraint con
            JOIN pg_class t ON con.conrelid = t.oid
            JOIN pg_attribute att ON att.attrelid = t.oid AND att.attnum = ANY(con.conkey)
            LEFT JOIN pg_class cl ON con.confrelid = cl.oid
            LEFT JOIN pg_attribute att2 ON att2.attrelid = cl.oid AND att2.attnum = ANY(con.confkey)
            WHERE t.relname = %s
            GROUP BY con.conname, con.contype, cl.relname
        """, (table_name,))
        constraints = {}
        for row in cursor.fetchall():
            constraint_name, constraint_type, columns_list, ref_table, ref_columns = row
            constraints[constraint_name] = {
                'type': constraint_type,
                'columns': columns_list,
                'referenced_table': ref_table,
                'referenced_columns': ref_columns
            }
        cursor.close()
        return TableInfo(table_name, columns, indexes, constraints)
class DatabaseValidator:
    """Validate a live database against the schema produced by migrations.

    Strategy: build a throwaway database, apply every migration to it, then
    diff its schema against the real database's schema. The temporary
    database is always dropped afterwards.
    """
    def __init__(self, db_type: str, db_config: Dict[str, Any]):
        self.db_type = db_type.lower()
        # Normalize mariadb to mysql since they use the same connector
        if self.db_type == 'mariadb':
            self.db_type = 'mysql'
        self.db_config = db_config
        self.logger = logging.getLogger(__name__)
    def create_test_database(self) -> Tuple[Any, str]:
        """Create a temporary database and run all migrations.

        Returns:
            (connection to the test database, its generated name).
        """
        if self.db_type == 'mysql':
            return self._create_mysql_test_db()
        elif self.db_type == 'postgresql':
            return self._create_postgresql_test_db()
        else:
            raise ValueError(f"Unsupported database type: {self.db_type}")
    def _create_mysql_test_db(self) -> Tuple[Any, str]:
        """Create a MySQL test database, migrate it, return a fresh connection."""
        if not MYSQL_AVAILABLE:
            raise ImportError("mysql-connector-python is required for MySQL validation")
        # Create temporary database name (random suffix avoids collisions)
        import uuid
        test_db_name = f"pinepods_test_{uuid.uuid4().hex[:8]}"
        # Connect to MySQL server
        config = self.db_config.copy()
        config.pop('database', None)  # Remove database from config
        config['use_pure'] = True  # Use pure Python implementation to avoid auth plugin issues
        conn = mysql.connector.connect(**config)
        cursor = conn.cursor()
        try:
            # Create test database
            cursor.execute(f"CREATE DATABASE `{test_db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
            cursor.execute(f"USE `{test_db_name}`")
            cursor.close()
            # Run all migrations
            self._run_migrations(conn, 'mysql')
            # Create a fresh connection to the test database for schema inspection
            config['database'] = test_db_name
            test_conn = mysql.connector.connect(**config)
            # Close the migration connection
            conn.close()
            return test_conn, test_db_name
        except Exception as e:
            # NOTE(review): on failure the created test database is not
            # dropped here — cleanup only happens on the success path via
            # _cleanup_test_database; consider dropping it in this handler.
            if cursor:
                cursor.close()
            if conn:
                conn.close()
            raise e
    def _create_postgresql_test_db(self) -> Tuple[Any, str]:
        """Create a PostgreSQL test database, migrate it, return its connection."""
        if not POSTGRESQL_AVAILABLE:
            raise ImportError("psycopg is required for PostgreSQL validation")
        # Create temporary database name
        import uuid
        test_db_name = f"pinepods_test_{uuid.uuid4().hex[:8]}"
        # Connect to PostgreSQL server
        config = self.db_config.copy()
        config.pop('dbname', None)  # Remove database from config
        config['dbname'] = 'postgres'  # Connect to default database
        conn = psycopg.connect(**config)
        # CREATE DATABASE cannot run inside a transaction block
        conn.autocommit = True
        cursor = conn.cursor()
        try:
            # Create test database
            cursor.execute(f'CREATE DATABASE "{test_db_name}"')
            cursor.close()
            conn.close()
            # Connect to the new test database
            config['dbname'] = test_db_name
            test_conn = psycopg.connect(**config)
            test_conn.autocommit = True
            # Run all migrations
            self._run_migrations(test_conn, 'postgresql')
            return test_conn, test_db_name
        except Exception as e:
            cursor.close()
            conn.close()
            raise e
    def _run_migrations(self, conn: Any, db_type: str):
        """Run all migrations on the test database using existing migration system.

        The migration manager reads DB_* environment variables, so they are
        temporarily overridden with placeholders and the manager's connection
        is injected directly; the original environment is always restored.
        """
        # Set environment variables for the migration manager
        import os
        original_env = {}
        try:
            # Backup original environment
            for key in ['DB_TYPE', 'DB_HOST', 'DB_PORT', 'DB_USER', 'DB_PASSWORD', 'DB_NAME']:
                original_env[key] = os.environ.get(key)
            # Set environment for test database (placeholder values; the
            # actual connection is overridden below)
            if db_type == 'mysql':
                os.environ['DB_TYPE'] = 'mysql'
                os.environ['DB_HOST'] = 'localhost'  # We'll override the connection
                os.environ['DB_PORT'] = '3306'
                os.environ['DB_USER'] = 'test'
                os.environ['DB_PASSWORD'] = 'test'
                os.environ['DB_NAME'] = 'test'
            else:
                os.environ['DB_TYPE'] = 'postgresql'
                os.environ['DB_HOST'] = 'localhost'
                os.environ['DB_PORT'] = '5432'
                os.environ['DB_USER'] = 'test'
                os.environ['DB_PASSWORD'] = 'test'
                os.environ['DB_NAME'] = 'test'
            # Import and register migrations
            import database_functions.migration_definitions
            # Get migration manager and override its connection
            manager = get_migration_manager()
            manager._connection = conn
            # Run all migrations
            success = manager.run_migrations()
            if not success:
                raise RuntimeError("Failed to apply migrations")
        finally:
            # Restore original environment
            for key, value in original_env.items():
                if value is not None:
                    os.environ[key] = value
                elif key in os.environ:
                    del os.environ[key]
    def validate_database(self) -> ValidationResult:
        """Validate the actual database against the expected schema.

        Builds the reference schema in a temporary database, inspects both
        databases, and returns the comparison result. The temporary database
        is dropped in all cases.
        """
        # Create test database with perfect schema
        test_conn, test_db_name = self.create_test_database()
        try:
            # Connect to actual database
            actual_conn = self._connect_to_actual_database()
            try:
                # Get schema information from both databases
                if self.db_type == 'mysql':
                    expected_inspector = MySQLInspector(test_conn)
                    actual_inspector = MySQLInspector(actual_conn)
                    # Extract schemas
                    expected_schema = expected_inspector.get_all_table_info()
                    actual_schema = actual_inspector.get_all_table_info()
                else:
                    # For PostgreSQL, create fresh connection for expected schema since migration manager closes it
                    fresh_test_conn = psycopg.connect(
                        host=self.db_config['host'],
                        port=self.db_config['port'],
                        user=self.db_config['user'],
                        password=self.db_config['password'],
                        dbname=test_db_name
                    )
                    fresh_test_conn.autocommit = True
                    try:
                        expected_inspector = PostgreSQLInspector(fresh_test_conn)
                        actual_inspector = PostgreSQLInspector(actual_conn)
                        # Extract schemas
                        expected_schema = expected_inspector.get_all_table_info()
                        actual_schema = actual_inspector.get_all_table_info()
                    finally:
                        fresh_test_conn.close()
                # DEBUG: Print what we're actually comparing
                # NOTE(review): leftover debug output — consider routing
                # through self.logger.debug or removing once stable.
                print(f"\n🔍 DEBUG: Expected schema has {len(expected_schema)} tables:")
                for table in sorted(expected_schema.keys()):
                    cols = list(expected_schema[table].columns.keys())
                    print(f" {table}: {len(cols)} columns - {', '.join(cols[:5])}{'...' if len(cols) > 5 else ''}")
                print(f"\n🔍 DEBUG: Actual schema has {len(actual_schema)} tables:")
                for table in sorted(actual_schema.keys()):
                    cols = list(actual_schema[table].columns.keys())
                    print(f" {table}: {len(cols)} columns - {', '.join(cols[:5])}{'...' if len(cols) > 5 else ''}")
                # Check specifically for Playlists table
                if 'Playlists' in expected_schema and 'Playlists' in actual_schema:
                    exp_cols = set(expected_schema['Playlists'].columns.keys())
                    act_cols = set(actual_schema['Playlists'].columns.keys())
                    print(f"\n🔍 DEBUG: Playlists comparison:")
                    print(f" Expected columns: {sorted(exp_cols)}")
                    print(f" Actual columns: {sorted(act_cols)}")
                    print(f" Missing from actual: {sorted(exp_cols - act_cols)}")
                    print(f" Extra in actual: {sorted(act_cols - exp_cols)}")
                # Compare schemas
                result = self._compare_schemas(expected_schema, actual_schema)
                return result
            finally:
                actual_conn.close()
        finally:
            # Clean up test database - this will close test_conn
            self._cleanup_test_database(test_conn, test_db_name)
    def _connect_to_actual_database(self) -> Any:
        """Connect to the actual database being validated."""
        if self.db_type == 'mysql':
            config = self.db_config.copy()
            # Ensure autocommit is enabled for MySQL
            config['autocommit'] = True
            config['use_pure'] = True  # Use pure Python implementation to avoid auth plugin issues
            return mysql.connector.connect(**config)
        else:
            return psycopg.connect(**self.db_config)
    def _cleanup_test_database(self, test_conn: Any, test_db_name: str):
        """Drop the temporary test database; failures are logged, not raised."""
        try:
            # Close the test connection first
            if test_conn:
                test_conn.close()
            if self.db_type == 'mysql':
                config = self.db_config.copy()
                config.pop('database', None)
                config['use_pure'] = True  # Use pure Python implementation to avoid auth plugin issues
                cleanup_conn = mysql.connector.connect(**config)
                cursor = cleanup_conn.cursor()
                cursor.execute(f"DROP DATABASE IF EXISTS `{test_db_name}`")
                cursor.close()
                cleanup_conn.close()
            else:
                config = self.db_config.copy()
                config.pop('dbname', None)
                # DROP DATABASE must run from a different database
                config['dbname'] = 'postgres'
                cleanup_conn = psycopg.connect(**config)
                cleanup_conn.autocommit = True
                cursor = cleanup_conn.cursor()
                cursor.execute(f'DROP DATABASE IF EXISTS "{test_db_name}"')
                cursor.close()
                cleanup_conn.close()
        except Exception as e:
            self.logger.warning(f"Failed to clean up test database {test_db_name}: {e}")
    def _compare_schemas(self, expected: Dict[str, TableInfo], actual: Dict[str, TableInfo]) -> ValidationResult:
        """Compare expected and actual database schemas.

        Only missing tables and missing columns are treated as critical;
        extra tables/columns, index and constraint differences are recorded
        but do not fail validation.
        """
        expected_tables = set(expected.keys())
        actual_tables = set(actual.keys())
        missing_tables = list(expected_tables - actual_tables)
        extra_tables = list(actual_tables - expected_tables)
        # table_differences is part of ValidationResult but never populated here.
        table_differences = {}
        missing_indexes = []
        extra_indexes = []
        missing_constraints = []
        extra_constraints = []
        column_differences = {}
        # Compare common tables
        common_tables = expected_tables & actual_tables
        for table_name in common_tables:
            expected_table = expected[table_name]
            actual_table = actual[table_name]
            # Compare columns
            table_col_diffs = self._compare_columns(expected_table.columns, actual_table.columns)
            if table_col_diffs:
                column_differences[table_name] = table_col_diffs
            # Compare indexes (by name only, not definition)
            expected_indexes = set(expected_table.indexes.keys())
            actual_indexes = set(actual_table.indexes.keys())
            for missing_idx in expected_indexes - actual_indexes:
                missing_indexes.append((table_name, missing_idx))
            for extra_idx in actual_indexes - expected_indexes:
                extra_indexes.append((table_name, extra_idx))
            # Compare constraints (by name only)
            expected_constraints = set(expected_table.constraints.keys())
            actual_constraints = set(actual_table.constraints.keys())
            for missing_const in expected_constraints - actual_constraints:
                missing_constraints.append((table_name, missing_const))
            for extra_const in actual_constraints - expected_constraints:
                extra_constraints.append((table_name, extra_const))
        # Only fail on critical issues:
        # - Missing tables (CRITICAL)
        # - Missing columns (CRITICAL)
        # Extra tables, extra columns, and type differences are warnings only
        critical_issues = []
        critical_issues.extend(missing_tables)
        # Check for missing columns (critical) - but only in expected tables
        for table, col_diffs in column_differences.items():
            # Skip extra tables entirely - they shouldn't be validated
            if table in extra_tables:
                continue
            for col, diff in col_diffs.items():
                if diff['status'] == 'missing':
                    critical_issues.append(f"missing column {col} in table {table}")
        is_valid = len(critical_issues) == 0
        return ValidationResult(
            is_valid=is_valid,
            missing_tables=missing_tables,
            extra_tables=extra_tables,
            table_differences=table_differences,
            missing_indexes=missing_indexes,
            extra_indexes=extra_indexes,
            missing_constraints=missing_constraints,
            extra_constraints=extra_constraints,
            column_differences=column_differences
        )
    def _compare_columns(self, expected: Dict[str, Dict[str, Any]], actual: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
        """Compare column definitions between expected and actual.

        Returns a dict keyed by column name; each value has a 'status' of
        'missing', 'extra', or 'different' plus the relevant metadata.
        """
        differences = {}
        expected_cols = set(expected.keys())
        actual_cols = set(actual.keys())
        # Missing columns
        for missing_col in expected_cols - actual_cols:
            differences[missing_col] = {'status': 'missing', 'expected': expected[missing_col]}
        # Extra columns
        for extra_col in actual_cols - expected_cols:
            differences[extra_col] = {'status': 'extra', 'actual': actual[extra_col]}
        # Different columns: compare only keys present on both sides
        for col_name in expected_cols & actual_cols:
            expected_col = expected[col_name]
            actual_col = actual[col_name]
            col_diffs = {}
            for key in expected_col:
                if key in actual_col and expected_col[key] != actual_col[key]:
                    col_diffs[key] = {'expected': expected_col[key], 'actual': actual_col[key]}
            if col_diffs:
                differences[col_name] = {'status': 'different', 'differences': col_diffs}
        return differences
def print_validation_report(result: ValidationResult):
    """Print a human-readable validation report to stdout.

    Critical issues (missing tables/columns) are listed first, then
    acceptable warnings, then index/constraint name differences.
    """
    print("=" * 80)
    print("DATABASE VALIDATION REPORT")
    print("=" * 80)
    # Count critical vs warning issues
    critical_issues = []
    warning_issues = []
    # Missing tables are critical
    critical_issues.extend(result.missing_tables)
    # Missing columns are critical, others are warnings
    for table, col_diffs in result.column_differences.items():
        for col, diff in col_diffs.items():
            if diff['status'] == 'missing':
                critical_issues.append(f"Missing column {col} in table {table}")
            else:
                warning_issues.append((table, col, diff))
    # Extra tables are warnings (sentinel tuple shape; only truthiness is used)
    warning_issues.extend([('EXTRA_TABLE', table, None) for table in result.extra_tables])
    if result.is_valid:
        if warning_issues:
            print("✅ DATABASE IS VALID - No critical issues found!")
            print("⚠️ Some warnings exist but don't affect functionality")
        else:
            print("✅ DATABASE IS PERFECT - All checks passed!")
    else:
        print("❌ DATABASE VALIDATION FAILED - Critical issues found")
    print()
    # Show critical issues
    if critical_issues:
        print("🔴 CRITICAL ISSUES (MUST BE FIXED):")
        if result.missing_tables:
            print(" Missing Tables:")
            for table in result.missing_tables:
                print(f" - {table}")
        # Show missing columns
        for table, col_diffs in result.column_differences.items():
            missing_cols = [col for col, diff in col_diffs.items() if diff['status'] == 'missing']
            if missing_cols:
                print(f" Missing Columns in {table}:")
                for col in missing_cols:
                    print(f" - {col}")
        print()
    # Show warnings
    if warning_issues:
        print("⚠️ WARNINGS (ACCEPTABLE DIFFERENCES):")
        if result.extra_tables:
            print(" Extra Tables (ignored):")
            for table in result.extra_tables:
                print(f" - {table}")
        # Show column warnings (extra columns and metadata differences)
        for table, col_diffs in result.column_differences.items():
            table_warnings = []
            for col, diff in col_diffs.items():
                if diff['status'] == 'extra':
                    table_warnings.append(f"Extra column: {col}")
                elif diff['status'] == 'different':
                    details = []
                    for key, values in diff['differences'].items():
                        details.append(f"{key}: {values}")
                    table_warnings.append(f"Different column {col}: {', '.join(details)}")
            if table_warnings:
                print(f" Table {table}:")
                for warning in table_warnings:
                    print(f" - {warning}")
        print()
    # Index/constraint name drift is informational only
    if result.missing_indexes:
        print("🟡 MISSING INDEXES:")
        for table, index in result.missing_indexes:
            print(f" - {table}.{index}")
        print()
    if result.extra_indexes:
        print("🟡 EXTRA INDEXES:")
        for table, index in result.extra_indexes:
            print(f" - {table}.{index}")
        print()
    if result.missing_constraints:
        print("🟡 MISSING CONSTRAINTS:")
        for table, constraint in result.missing_constraints:
            print(f" - {table}.{constraint}")
        print()
    if result.extra_constraints:
        print("🟡 EXTRA CONSTRAINTS:")
        for table, constraint in result.extra_constraints:
            print(f" - {table}.{constraint}")
        print()
def main():
    """CLI entry point: parse connection arguments, validate, exit with status.

    Exit codes: 0 = valid, 1 = validation found critical issues,
    2 = validation could not run (error).
    """
    parser = argparse.ArgumentParser(description='Validate PinePods database schema')
    parser.add_argument('--db-type', required=True, choices=['mysql', 'mariadb', 'postgresql'], help='Database type')
    parser.add_argument('--db-host', required=True, help='Database host')
    parser.add_argument('--db-port', required=True, type=int, help='Database port')
    parser.add_argument('--db-user', required=True, help='Database user')
    parser.add_argument('--db-password', required=True, help='Database password')
    parser.add_argument('--db-name', required=True, help='Database name')
    parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose logging')
    args = parser.parse_args()
    # Set up logging
    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    # Connection parameters common to both drivers.
    db_config = {
        'host': args.db_host,
        'port': args.db_port,
        'user': args.db_user,
        'password': args.db_password,
    }
    if args.db_type in ('mysql', 'mariadb'):
        # mysql-connector takes 'database' plus explicit charset/collation.
        db_config['database'] = args.db_name
        db_config['charset'] = 'utf8mb4'
        db_config['collation'] = 'utf8mb4_unicode_ci'
    else:  # postgresql
        # psycopg expects the key 'dbname'.
        db_config['dbname'] = args.db_name
    try:
        # Create validator and run validation
        validator = DatabaseValidator(args.db_type, db_config)
        report = validator.validate_database()
        print_validation_report(report)
        # Exit with appropriate code
        sys.exit(0 if report.is_valid else 1)
    except Exception as e:
        logging.error(f"Validation failed with error: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(2)
if __name__ == '__main__':
main()