added cargo files
This commit is contained in:
0
PinePods-0.8.2/database_functions/__init__.py
Normal file
0
PinePods-0.8.2/database_functions/__init__.py
Normal file
137
PinePods-0.8.2/database_functions/migrate.py
Executable file
137
PinePods-0.8.2/database_functions/migrate.py
Executable file
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database Migration Runner for PinePods
|
||||
|
||||
This script can be run standalone to apply database migrations.
|
||||
Useful for updating existing installations.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add pinepods to path
|
||||
pinepods_path = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(pinepods_path))
|
||||
|
||||
|
||||
def run_migrations(target_version=None, validate_only=False):
    """Run database migrations.

    Args:
        target_version: Optional version string; when given, only migrations
            up to and including this version are applied.
        validate_only: When True, only verify checksums of already-applied
            migrations; nothing is applied.

    Returns:
        bool: True on success, False on any failure.
    """
    try:
        # Import migration system lazily so sys.path is already adjusted.
        import database_functions.migration_definitions
        from database_functions.migrations import get_migration_manager

        # Register all migrations
        database_functions.migration_definitions.register_all_migrations()

        # Get migration manager
        manager = get_migration_manager()

        if validate_only:
            logger.info("Validating existing migrations...")
            success = manager.validate_migrations()
            if success:
                logger.info("All migrations validated successfully")
            else:
                logger.error("Migration validation failed")
            return success

        # Show current state
        applied = manager.get_applied_migrations()
        logger.info(f"Currently applied migrations: {len(applied)}")
        for version in applied:
            logger.info(f"  - {version}")

        # Run migrations.
        # BUG FIX: target_version was previously ignored — the old code called
        # run_all_migrations(), which takes no arguments. Pass it through so
        # `migrate --target X` actually stops at version X.
        logger.info("Starting migration process...")
        success = manager.run_migrations(target_version)

        if success:
            logger.info("All migrations completed successfully")
        else:
            logger.error("Migration process failed")

        return success

    except Exception as e:
        logger.error(f"Migration failed: {e}")
        return False
|
||||
|
||||
|
||||
def list_migrations():
    """Log every registered migration with its APPLIED/PENDING status.

    Returns:
        bool: True on success, False if listing failed.
    """
    try:
        import database_functions.migration_definitions
        from database_functions.migrations import get_migration_manager

        # Ensure every migration definition is registered before listing.
        database_functions.migration_definitions.register_all_migrations()

        manager = get_migration_manager()
        applied = set(manager.get_applied_migrations())

        logger.info("Available migrations:")
        for version in sorted(manager.migrations):
            migration = manager.migrations[version]
            status = "APPLIED" if version in applied else "PENDING"
            logger.info(f"  {version} - {migration.name} [{status}]")
            logger.info(f"    {migration.description}")
            if migration.requires:
                logger.info(f"    Requires: {', '.join(migration.requires)}")

        return True

    except Exception as e:
        logger.error(f"Failed to list migrations: {e}")
        return False
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the migration tool.

    Parses the command (migrate / list / validate) plus options, dispatches
    to the matching helper, and exits 0 on success, 1 on failure.
    """
    parser = argparse.ArgumentParser(description="PinePods Database Migration Tool")
    parser.add_argument(
        "command",
        choices=["migrate", "list", "validate"],
        help="Command to execute"
    )
    parser.add_argument(
        "--target",
        help="Target migration version (migrate only)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Dispatch table instead of an if/elif chain; argparse's `choices`
    # already restricts args.command to these keys.
    handlers = {
        "migrate": lambda: run_migrations(args.target),
        "list": list_migrations,
        "validate": lambda: run_migrations(validate_only=True),
    }
    handler = handlers.get(args.command)
    if handler is None:
        logger.error(f"Unknown command: {args.command}")
        success = False
    else:
        success = handler()

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
|
||||
3789
PinePods-0.8.2/database_functions/migration_definitions.py
Normal file
3789
PinePods-0.8.2/database_functions/migration_definitions.py
Normal file
File diff suppressed because it is too large
Load Diff
530
PinePods-0.8.2/database_functions/migrations.py
Normal file
530
PinePods-0.8.2/database_functions/migrations.py
Normal file
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Database Migration System for PinePods
|
||||
|
||||
This module provides a robust, idempotent migration framework that tracks
|
||||
applied migrations and ensures database schema changes are applied safely.
|
||||
"""
|
||||
|
||||
import hashlib
import logging
import os
import sys
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
# Add pinepods to path for imports
|
||||
sys.path.append('/pinepods')
|
||||
|
||||
# Database imports
|
||||
try:
|
||||
import psycopg
|
||||
POSTGRES_AVAILABLE = True
|
||||
except ImportError:
|
||||
POSTGRES_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import mariadb as mysql_connector
|
||||
MYSQL_AVAILABLE = True
|
||||
except ImportError:
|
||||
try:
|
||||
import mysql.connector
|
||||
MYSQL_AVAILABLE = True
|
||||
except ImportError:
|
||||
MYSQL_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class Migration:
    """Represents a single database migration.

    Exactly one of ``postgres_sql``/``mysql_sql`` (per backend) and/or
    ``python_func`` supplies the migration body; ``requires`` lists versions
    that must already be applied.
    """
    version: str
    name: str
    description: str
    postgres_sql: Optional[str] = None
    mysql_sql: Optional[str] = None
    python_func: Optional[Callable] = None
    # FIX: use default_factory instead of a None sentinel typed List[str] —
    # the field is always a real list after construction.
    requires: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Backward compatibility: tolerate callers that pass requires=None.
        if self.requires is None:
            self.requires = []
|
||||
|
||||
|
||||
class DatabaseMigrationManager:
    """Manages database migrations with support for PostgreSQL and MySQL/MariaDB.

    Responsibilities:
      * maintain the ``schema_migrations`` bookkeeping table,
      * apply registered :class:`Migration` objects in version order,
      * verify checksums of already-applied migrations,
      * detect pre-existing schema on installs that predate this framework.
    """

    def __init__(self, db_type: str, connection_params: Dict[str, Any]):
        """Create a manager for *db_type* using *connection_params*.

        Raises:
            ValueError: if *db_type* is not one of the supported databases.
        """
        self.db_type = db_type.lower()
        self.connection_params = connection_params
        self.migrations: Dict[str, Migration] = {}
        self._connection = None

        # Validate database type before normalizing it.
        if self.db_type not in ['postgresql', 'postgres', 'mariadb', 'mysql']:
            raise ValueError(f"Unsupported database type: {db_type}")

        # Normalize to the two canonical values used throughout this class.
        if self.db_type in ['postgres', 'postgresql']:
            self.db_type = 'postgresql'
        elif self.db_type in ['mysql', 'mariadb']:
            self.db_type = 'mysql'

    def get_connection(self):
        """Return the cached database connection, creating it on first use."""
        if self._connection:
            return self._connection

        if self.db_type == 'postgresql':
            if not POSTGRES_AVAILABLE:
                raise ImportError("psycopg not available for PostgreSQL connections")
            self._connection = psycopg.connect(**self.connection_params)
        elif self.db_type == 'mysql':
            if not MYSQL_AVAILABLE:
                raise ImportError("MariaDB/MySQL connector not available for MySQL connections")
            # The mariadb connector uses different parameter names than
            # mysql.connector, so translate/drop the incompatible ones.
            mysql_params = self.connection_params.copy()
            if 'connection_timeout' in mysql_params:
                mysql_params['connect_timeout'] = mysql_params.pop('connection_timeout')
            if 'charset' in mysql_params:
                mysql_params.pop('charset')  # MariaDB connector doesn't use charset parameter
            if 'collation' in mysql_params:
                mysql_params.pop('collation')  # MariaDB connector doesn't use collation parameter
            self._connection = mysql_connector.connect(**mysql_params)

        return self._connection

    def close_connection(self):
        """Close and forget the cached connection, if any."""
        if self._connection:
            self._connection.close()
            self._connection = None

    def register_migration(self, migration: Migration):
        """Register *migration* so run_migrations() can apply it."""
        self.migrations[migration.version] = migration
        logger.info(f"Registered migration {migration.version}: {migration.name}")

    def create_migration_table(self):
        """Create the migrations tracking table if it doesn't exist."""
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            # Identical schema on both backends; Postgres table name is quoted
            # to match the project's quoted-identifier convention.
            if self.db_type == 'postgresql':
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS "schema_migrations" (
                        version VARCHAR(255) PRIMARY KEY,
                        name VARCHAR(255) NOT NULL,
                        description TEXT,
                        checksum VARCHAR(64) NOT NULL,
                        applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        execution_time_ms INTEGER
                    )
                """)
            else:  # mysql
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS schema_migrations (
                        version VARCHAR(255) PRIMARY KEY,
                        name VARCHAR(255) NOT NULL,
                        description TEXT,
                        checksum VARCHAR(64) NOT NULL,
                        applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        execution_time_ms INTEGER
                    )
                """)
            conn.commit()
            logger.info("Migration tracking table created/verified")
        except Exception as e:
            logger.error(f"Failed to create migration table: {e}")
            conn.rollback()
            raise
        finally:
            cursor.close()

    def get_applied_migrations(self) -> List[str]:
        """Return applied migration versions in application order.

        Returns an empty list when the tracking table does not exist yet.
        """
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'
            cursor.execute(f"SELECT version FROM {table_name} ORDER BY applied_at")
            return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            # If table doesn't exist, return empty list
            logger.warning(f"Could not get applied migrations: {e}")
            return []
        finally:
            cursor.close()

    def calculate_migration_checksum(self, migration: Migration) -> str:
        """Return a SHA-256 checksum over the migration body for this backend."""
        content = ""
        if migration.postgres_sql and self.db_type == 'postgresql':
            content += migration.postgres_sql
        elif migration.mysql_sql and self.db_type == 'mysql':
            content += migration.mysql_sql

        # Include the compiled bytecode of any Python migration function so
        # edits to its logic are detected.
        if migration.python_func:
            content += migration.python_func.__code__.co_code.hex()

        return hashlib.sha256(content.encode()).hexdigest()

    def record_migration(self, migration: Migration, execution_time_ms: int):
        """Record *migration* as applied (idempotent on version conflicts)."""
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            checksum = self.calculate_migration_checksum(migration)
            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'

            if self.db_type == 'postgresql':
                cursor.execute(f"""
                    INSERT INTO {table_name} (version, name, description, checksum, execution_time_ms)
                    VALUES (%s, %s, %s, %s, %s)
                    ON CONFLICT (version) DO NOTHING
                """, (migration.version, migration.name, migration.description, checksum, execution_time_ms))
            else:  # mysql
                cursor.execute(f"""
                    INSERT IGNORE INTO {table_name} (version, name, description, checksum, execution_time_ms)
                    VALUES (%s, %s, %s, %s, %s)
                """, (migration.version, migration.name, migration.description, checksum, execution_time_ms))

            conn.commit()
            logger.info(f"Recorded migration {migration.version} as applied")
        except Exception as e:
            logger.error(f"Failed to record migration {migration.version}: {e}")
            conn.rollback()
            raise
        finally:
            cursor.close()

    def check_dependencies(self, migration: Migration, applied_migrations: List[str]) -> bool:
        """Return True if every version in migration.requires is applied."""
        for required_version in migration.requires:
            if required_version not in applied_migrations:
                logger.error(f"Migration {migration.version} requires {required_version} but it's not applied")
                return False
        return True

    def execute_migration(self, migration: Migration) -> bool:
        """Execute a single migration (SQL and/or Python body) and record it.

        Returns True on success; on failure rolls back and returns False.
        """
        start_time = datetime.now()
        conn = self.get_connection()

        try:
            # Choose appropriate SQL based on database type
            sql = None
            if self.db_type == 'postgresql' and migration.postgres_sql:
                sql = migration.postgres_sql
            elif self.db_type == 'mysql' and migration.mysql_sql:
                sql = migration.mysql_sql

            # Execute SQL if available
            if sql:
                cursor = conn.cursor()
                try:
                    # Split and execute multiple statements.
                    # NOTE(review): naive ';' splitting breaks on semicolons
                    # inside string literals or procedure bodies — acceptable
                    # only while migration SQL avoids those.
                    statements = [stmt.strip() for stmt in sql.split(';') if stmt.strip()]
                    for statement in statements:
                        cursor.execute(statement)
                    conn.commit()
                    logger.info(f"Executed SQL for migration {migration.version}")
                except Exception as e:
                    conn.rollback()
                    raise
                finally:
                    cursor.close()

            # Execute Python function if available (this is the main path for our migrations)
            if migration.python_func:
                try:
                    migration.python_func(conn, self.db_type)
                    conn.commit()
                    logger.info(f"Executed Python function for migration {migration.version}")
                except Exception as e:
                    conn.rollback()
                    raise

            # Record successful migration
            execution_time = int((datetime.now() - start_time).total_seconds() * 1000)
            self.record_migration(migration, execution_time)

            logger.info(f"Successfully applied migration {migration.version}: {migration.name}")
            return True

        except Exception as e:
            logger.error(f"Failed to execute migration {migration.version}: {e}")
            try:
                conn.rollback()
            except Exception:
                pass  # Connection might already be closed
            return False

    def detect_existing_schema(self) -> List[str]:
        """Detect which migrations have already been applied based on existing schema.

        Used on installations that predate the migration framework: if the
        tables a migration creates already exist, the migration is assumed
        applied.
        """
        conn = self.get_connection()
        cursor = conn.cursor()
        applied = []

        def table_exists(name: str) -> bool:
            # Check information_schema for a bare (unquoted) table name.
            if self.db_type == 'postgresql':
                cursor.execute("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_schema = 'public' AND table_name = %s
                    )
                """, (name,))
            else:  # mysql
                cursor.execute("""
                    SELECT COUNT(*)
                    FROM information_schema.tables
                    WHERE table_schema = DATABASE() AND table_name = %s
                """, (name,))
            return bool(cursor.fetchone()[0])

        try:
            # Check for tables that indicate migrations have been applied
            checks = {
                "001": ['"Users"', '"OIDCProviders"', '"APIKeys"', '"RssKeys"'],
                "002": ['"AppSettings"', '"EmailSettings"'],
                "003": ['"UserStats"', '"UserSettings"'],
                "005": ['"Podcasts"', '"Episodes"', '"YouTubeVideos"'],
                "006": ['"UserEpisodeHistory"', '"UserVideoHistory"'],
                "007": ['"EpisodeQueue"', '"SavedEpisodes"', '"DownloadedEpisodes"']
            }

            for version, tables in checks.items():
                # BUG FIX: the MySQL branch previously compared against the
                # quoted name (e.g. '"Users"'), which never matches
                # information_schema. Strip the quotes for both backends, as
                # the gpodder checks below already did.
                if all(table_exists(table.strip('"')) for table in tables):
                    applied.append(version)
                    logger.info(f"Detected existing schema for migration {version}")

            # Migration 004 is harder to detect, assume it's applied if 001-003 are
            if "001" in applied and "003" in applied and "004" not in applied:
                # Check if background_tasks user exists
                table_name = '"Users"' if self.db_type == 'postgresql' else 'Users'
                cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE Username = %s", ('background_tasks',))
                if cursor.fetchone()[0] > 0:
                    applied.append("004")
                    logger.info("Detected existing schema for migration 004")

            # Check for gpodder tables - if ANY exist, ALL gpodder migrations are applied
            # (since they were created by the Go gpodder-api service and haven't changed)
            gpodder_indicator_tables = ['"GpodderSyncMigrations"', '"GpodderSyncDeviceState"',
                                        '"GpodderSyncSubscriptions"', '"GpodderSyncSettings"',
                                        '"GpodderSessions"', '"GpodderSyncState"']
            gpodder_migration_versions = ["100", "101", "102", "103", "104"]

            gpodder_tables_exist = any(
                table_exists(table.strip('"')) for table in gpodder_indicator_tables
            )

            if gpodder_tables_exist:
                for version in gpodder_migration_versions:
                    if version not in applied:
                        applied.append(version)
                        logger.info(f"Detected existing gpodder tables, marking migration {version} as applied")

            # Check for PeopleEpisodes_backup table separately (migration 104)
            if table_exists("PeopleEpisodes_backup") and "104" not in applied:
                applied.append("104")
                logger.info("Detected existing PeopleEpisodes_backup table, marking migration 104 as applied")

            return applied

        except Exception as e:
            logger.warning(f"Error detecting existing schema: {e}")
            return []
        finally:
            cursor.close()

    def run_migrations(self, target_version: Optional[str] = None) -> bool:
        """Run all pending migrations, optionally stopping at *target_version*.

        Creates the tracking table, baselines pre-existing schema on first
        run, then applies pending migrations in sorted version order.
        """
        try:
            # Create migration table
            self.create_migration_table()

            # Get applied migrations
            applied_migrations = self.get_applied_migrations()
            logger.info(f"Found {len(applied_migrations)} applied migrations")

            # If no migrations are recorded but we have existing schema, detect what's there
            if not applied_migrations:
                detected_migrations = self.detect_existing_schema()
                if detected_migrations:
                    logger.info(f"Detected existing schema, marking {len(detected_migrations)} migrations as applied")
                    # Record detected migrations without executing them
                    for version in detected_migrations:
                        if version in self.migrations:
                            migration = self.migrations[version]
                            self.record_migration(migration, 0)  # 0ms execution time for pre-existing

                    # Refresh applied migrations list
                    applied_migrations = self.get_applied_migrations()

            # Sort migrations by version (string comparison; versions are
            # zero-padded, e.g. "001" < "100").
            pending_migrations = []
            for version, migration in sorted(self.migrations.items()):
                if version not in applied_migrations:
                    if target_version and version > target_version:
                        continue
                    pending_migrations.append(migration)

            if not pending_migrations:
                logger.info("No pending migrations to apply")
                return True

            logger.info(f"Found {len(pending_migrations)} pending migrations")

            # Execute pending migrations
            for migration in pending_migrations:
                # Check dependencies
                if not self.check_dependencies(migration, applied_migrations):
                    logger.error(f"Dependency check failed for migration {migration.version}")
                    return False

                # Execute migration
                if not self.execute_migration(migration):
                    logger.error(f"Failed to execute migration {migration.version}")
                    return False

                # Add to applied list for dependency checking
                applied_migrations.append(migration.version)

            logger.info("All migrations applied successfully")
            return True

        except Exception as e:
            logger.error(f"Migration run failed: {e}")
            return False
        finally:
            self.close_connection()

    def validate_migrations(self) -> bool:
        """Validate that applied migrations haven't changed (checksum match)."""
        # BUG FIX: cursor must be bound before the try block — if
        # get_connection() raised, the finally clause previously hit a
        # NameError that masked the original exception.
        cursor = None
        try:
            conn = self.get_connection()
            cursor = conn.cursor()

            table_name = '"schema_migrations"' if self.db_type == 'postgresql' else 'schema_migrations'
            cursor.execute(f"SELECT version, checksum FROM {table_name}")
            applied_checksums = dict(cursor.fetchall())

            validation_errors = []
            for version, stored_checksum in applied_checksums.items():
                if version in self.migrations:
                    current_checksum = self.calculate_migration_checksum(self.migrations[version])
                    if current_checksum != stored_checksum:
                        validation_errors.append(f"Migration {version} checksum mismatch")

            if validation_errors:
                for error in validation_errors:
                    logger.error(error)
                return False

            logger.info("All migration checksums validated successfully")
            return True

        except Exception as e:
            logger.error(f"Migration validation failed: {e}")
            return False
        finally:
            if cursor is not None:
                cursor.close()
|
||||
|
||||
|
||||
# Migration manager instance (singleton pattern)
|
||||
_migration_manager: Optional[DatabaseMigrationManager] = None
|
||||
|
||||
|
||||
def get_migration_manager() -> DatabaseMigrationManager:
    """Return the process-wide migration manager, creating it on first use.

    Connection settings come from the DB_* environment variables; defaults
    target a local PostgreSQL instance.
    """
    global _migration_manager

    # Guard clause: reuse the existing singleton if already built.
    if _migration_manager is not None:
        return _migration_manager

    env = os.environ
    db_type = env.get("DB_TYPE", "postgresql")

    if db_type.lower() in ['postgresql', 'postgres']:
        connection_params = {
            'host': env.get("DB_HOST", "127.0.0.1"),
            'port': int(env.get("DB_PORT", "5432")),
            'user': env.get("DB_USER", "postgres"),
            'password': env.get("DB_PASSWORD", "password"),
            'dbname': env.get("DB_NAME", "pinepods_database"),
        }
    else:  # mysql/mariadb
        connection_params = {
            'host': env.get("DB_HOST", "127.0.0.1"),
            'port': int(env.get("DB_PORT", "3306")),
            'user': env.get("DB_USER", "root"),
            'password': env.get("DB_PASSWORD", "password"),
            'database': env.get("DB_NAME", "pinepods_database"),
            'charset': 'utf8mb4',
            'collation': 'utf8mb4_general_ci',
        }

    _migration_manager = DatabaseMigrationManager(db_type, connection_params)
    return _migration_manager
|
||||
|
||||
|
||||
def register_migration(version: str, name: str, description: str, **kwargs):
    """Decorator that wraps a function as a Migration and registers it.

    Extra keyword arguments (e.g. ``requires``, ``postgres_sql``) are passed
    straight to the Migration constructor. The decorated function is returned
    unchanged.
    """
    def decorator(func):
        get_migration_manager().register_migration(
            Migration(
                version=version,
                name=name,
                description=description,
                python_func=func,
                **kwargs,
            )
        )
        return func
    return decorator
|
||||
|
||||
|
||||
def run_all_migrations() -> bool:
    """Apply every registered migration via the global manager.

    Returns:
        bool: True if all migrations applied cleanly, False otherwise.
    """
    return get_migration_manager().run_migrations()
|
||||
795
PinePods-0.8.2/database_functions/tasks.py
Normal file
795
PinePods-0.8.2/database_functions/tasks.py
Normal file
@@ -0,0 +1,795 @@
|
||||
# tasks.py - Define Celery tasks with Valkey as broker
|
||||
from celery import Celery
|
||||
import time
|
||||
import os
|
||||
import asyncio
|
||||
import datetime
|
||||
import requests
|
||||
from threading import Thread
|
||||
import json
|
||||
import sys
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
# Make sure pinepods is in the Python path
|
||||
sys.path.append('/pinepods')
|
||||
|
||||
database_type = str(os.getenv('DB_TYPE', 'mariadb'))
|
||||
|
||||
class Web_Key:
    """Lazy holder for the application's web API key.

    The key is fetched on demand through a database connection and cached on
    the instance for subsequent reads.
    """

    def __init__(self):
        # No key until get_web_key() is called with a live connection.
        self.web_key = None

    def get_web_key(self, cnx):
        """Fetch the web key via *cnx*, cache it, and return it."""
        # Import only when needed to avoid circular imports
        from database_functions.functions import get_web_key as get_key
        fetched = get_key(cnx, database_type)
        self.web_key = fetched
        return fetched
|
||||
|
||||
# Shared singleton: tasks lazily fetch and cache the web API key through it.
base_webkey = Web_Key()

# Set up logging
logger = logging.getLogger("celery_tasks")

# Import the WebSocket manager directly from clientapi.
# Best effort: worker processes may not have the API app importable, in
# which case broadcasts are disabled (websocket_manager = None).
try:
    from clients.clientapi import manager as websocket_manager
    print("Successfully imported WebSocket manager from clientapi")
except ImportError as e:
    logger.error(f"Failed to import WebSocket manager: {e}")
    websocket_manager = None

# Create a dedicated event loop thread for async operations.
# Both globals are populated by start_background_loop() below.
_event_loop = None
_event_loop_thread = None
|
||||
|
||||
def start_background_loop():
    """Start (once) a daemon thread running a dedicated asyncio event loop.

    The loop is published through the module-global ``_event_loop`` so that
    synchronous Celery task code can schedule coroutines (e.g. WebSocket
    broadcasts) onto it. Idempotent: a second call while the loop exists is
    a no-op.
    """
    global _event_loop, _event_loop_thread

    # Only start if not already running
    if _event_loop is None:
        # Function to run event loop in background thread
        def run_event_loop():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            # Publish the loop for other threads; the outer `global`
            # declaration does not extend into this nested function, hence
            # the re-declaration here.
            global _event_loop
            _event_loop = loop
            loop.run_forever()

        # Start the background thread (daemon=True so it never blocks exit).
        _event_loop_thread = Thread(target=run_event_loop, daemon=True)
        _event_loop_thread.start()

        # Wait a moment for the loop to start.
        # NOTE(review): this sleep is a heuristic handshake; a threading.Event
        # would be deterministic — confirm _event_loop is set before relying
        # on it immediately after this call.
        time.sleep(0.1)
        print("Started background event loop for WebSocket broadcasts")
|
||||
|
||||
# Start the event loop when this module is imported
start_background_loop()

# Use the existing Valkey connection for Celery.
# Valkey speaks the Redis protocol, so redis:// URLs work unchanged.
valkey_host = os.environ.get("VALKEY_HOST", "localhost")
valkey_port = os.environ.get("VALKEY_PORT", "6379")
broker_url = f"redis://{valkey_host}:{valkey_port}/0"
backend_url = f"redis://{valkey_host}:{valkey_port}/0"

# Initialize Celery with Valkey as broker and result backend
celery_app = Celery('pinepods',
                    broker=broker_url,
                    backend=backend_url)

# Configure Celery for best performance with downloads
celery_app.conf.update(
    worker_concurrency=3,  # Limit to 3 concurrent downloads per worker
    task_acks_late=True,  # Only acknowledge tasks after they're done
    task_time_limit=1800,  # 30 minutes time limit
    task_soft_time_limit=1500,  # 25 minutes soft time limit
    worker_prefetch_multiplier=1,  # Don't prefetch more tasks than workers
)
|
||||
|
||||
# Task status tracking in Valkey for all types of tasks
|
||||
class TaskManager:
|
||||
    def __init__(self):
        """Bind the shared Valkey client used for all task bookkeeping."""
        # Deferred import: valkey_client's module also imports from this
        # package, so a top-level import would create a circular dependency.
        from database_functions.valkey_client import valkey_client
        self.valkey_client = valkey_client
|
||||
|
||||
def register_task(self, task_id: str, task_type: str, user_id: int, item_id: Optional[int] = None,
|
||||
details: Optional[Dict[str, Any]] = None):
|
||||
"""Register any Celery task for tracking"""
|
||||
task_data = {
|
||||
"task_id": task_id,
|
||||
"user_id": user_id,
|
||||
"type": task_type,
|
||||
"item_id": item_id,
|
||||
"progress": 0.0,
|
||||
"status": "PENDING",
|
||||
"details": details or {},
|
||||
"started_at": datetime.datetime.now().isoformat()
|
||||
}
|
||||
|
||||
self.valkey_client.set(f"task:{task_id}", json.dumps(task_data))
|
||||
# Set TTL for 24 hours
|
||||
self.valkey_client.expire(f"task:{task_id}", 86400)
|
||||
|
||||
# Add to user's active tasks list
|
||||
self._add_to_user_tasks(user_id, task_id)
|
||||
|
||||
# Try to broadcast the update if the WebSocket module is available
|
||||
try:
|
||||
self._broadcast_update(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting task update: {e}")
|
||||
|
||||
def update_task(self, task_id: str, progress: float = None, status: str = None,
|
||||
details: Dict[str, Any] = None):
|
||||
"""Update any task's status and progress"""
|
||||
task_json = self.valkey_client.get(f"task:{task_id}")
|
||||
if task_json:
|
||||
task = json.loads(task_json)
|
||||
if progress is not None:
|
||||
task["progress"] = progress
|
||||
if status:
|
||||
task["status"] = status
|
||||
if details:
|
||||
if "details" not in task:
|
||||
task["details"] = {}
|
||||
task["details"].update(details)
|
||||
|
||||
self.valkey_client.set(f"task:{task_id}", json.dumps(task))
|
||||
|
||||
# Try to broadcast the update
|
||||
try:
|
||||
self._broadcast_update(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting task update: {e}")
|
||||
|
||||
def complete_task(self, task_id: str, success: bool = True, result: Any = None):
|
||||
"""Mark any task as complete or failed"""
|
||||
task_json = self.valkey_client.get(f"task:{task_id}")
|
||||
if task_json:
|
||||
task = json.loads(task_json)
|
||||
task["progress"] = 100.0 if success else 0.0
|
||||
task["status"] = "SUCCESS" if success else "FAILED"
|
||||
task["completed_at"] = datetime.datetime.now().isoformat()
|
||||
if result is not None:
|
||||
task["result"] = result
|
||||
|
||||
self.valkey_client.set(f"task:{task_id}", json.dumps(task))
|
||||
|
||||
# Try to broadcast the final update
|
||||
try:
|
||||
self._broadcast_update(task_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error broadcasting task update: {e}")
|
||||
|
||||
# Keep completed tasks for 1 hour before expiring
|
||||
self.valkey_client.expire(f"task:{task_id}", 3600)
|
||||
|
||||
# Remove from user's active tasks list after completion
|
||||
if success:
|
||||
self._remove_from_user_tasks(task.get("user_id"), task_id)
|
||||
|
||||
def get_task(self, task_id: str) -> Dict[str, Any]:
|
||||
"""Get any task's details"""
|
||||
task_json = self.valkey_client.get(f"task:{task_id}")
|
||||
if task_json:
|
||||
return json.loads(task_json)
|
||||
return {}
|
||||
|
||||
def get_user_tasks(self, user_id: int) -> List[Dict[str, Any]]:
|
||||
"""Get all active tasks for a user (all types)"""
|
||||
tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
|
||||
result = []
|
||||
|
||||
if tasks_list_json:
|
||||
task_ids = json.loads(tasks_list_json)
|
||||
for task_id in task_ids:
|
||||
task_info = self.get_task(task_id)
|
||||
if task_info:
|
||||
result.append(task_info)
|
||||
|
||||
return result
|
||||
|
||||
def _add_to_user_tasks(self, user_id: int, task_id: str):
|
||||
"""Add a task to the user's active tasks list"""
|
||||
tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
|
||||
if tasks_list_json:
|
||||
tasks_list = json.loads(tasks_list_json)
|
||||
if task_id not in tasks_list:
|
||||
tasks_list.append(task_id)
|
||||
else:
|
||||
tasks_list = [task_id]
|
||||
|
||||
self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))
|
||||
# Set TTL for 7 days
|
||||
self.valkey_client.expire(f"user_tasks:{user_id}", 604800)
|
||||
|
||||
def _remove_from_user_tasks(self, user_id: int, task_id: str):
    """Drop *task_id* from the user's active-task list, if present."""
    key = f"user_tasks:{user_id}"
    stored = self.valkey_client.get(key)
    if not stored:
        return
    task_ids = json.loads(stored)
    if task_id in task_ids:
        task_ids.remove(task_id)
        self.valkey_client.set(key, json.dumps(task_ids))
|
||||
|
||||
# Modified _broadcast_update method to avoid circular imports
def _broadcast_update(self, task_id: str):
    """Broadcast task update via HTTP endpoint.

    Pushes the current state of *task_id* to the websocket broadcaster so
    connected clients see live progress. Best-effort: every failure path
    logs and returns rather than raising, because broadcasting must never
    break the task that triggered it.
    """
    # Get task info; nothing to broadcast if the record expired or has no owner
    task_info = self.get_task(task_id)
    if not task_info or "user_id" not in task_info:
        return

    user_id = task_info["user_id"]
    cnx = None

    try:
        cnx = get_direct_db_connection()

        # Import broadcaster - delay import to avoid circular dependency.
        # Two lookup locations are tried: the container path first, then the
        # package-relative path (covers both deployment layouts).
        sys.path.insert(0, '/pinepods/database_functions')
        try:
            from websocket_broadcaster import broadcaster
        except ImportError:
            try:
                from database_functions.websocket_broadcaster import broadcaster
            except ImportError as e:
                print(f"Cannot import broadcaster from any location: {e}")
                return

        # Get web key (server credential the broadcaster authenticates with)
        web_key = None
        try:
            # Get web key using class method to avoid direct import;
            # base_webkey caches the key after the first DB lookup
            if not base_webkey.web_key:
                base_webkey.get_web_key(cnx)
            web_key = base_webkey.web_key
        except Exception as e:
            print(f"Error getting web key: {str(e)}")
            # Fallback to a direct approach if needed
            try:
                from database_functions.functions import get_web_key
                web_key = get_web_key(cnx, database_type)
            except Exception as e2:
                print(f"Fallback web key retrieval failed: {str(e2)}")
                return

        # Progress and status details for better debugging
        progress = task_info.get("progress", 0)
        status = task_info.get("status", "unknown")
        print(f"Broadcasting task update for user {user_id}, task {task_id}, progress: {progress}, status: {status}")

        # Broadcast the update; broadcaster returns a truthy value on success
        result = broadcaster.broadcast_task_update(user_id, task_info, web_key)
        if result:
            print(f"Successfully broadcast task update for task {task_id}, progress: {progress}%")
        else:
            print(f"Failed to broadcast task update for task {task_id}, progress: {progress}%")

    except Exception as e:
        print(f"Error in task broadcast setup: {str(e)}")
    finally:
        if cnx:
            # Close direct connection
            close_direct_db_connection(cnx)
|
||||
|
||||
# Initialize a general task manager — single module-level instance shared by
# every Celery task defined below
task_manager = TaskManager()

# For backwards compatibility, keep the download_manager name too
# (older callers import `download_manager`; both names are the same object)
download_manager = task_manager
|
||||
|
||||
# Function to get all active tasks including both downloads and other task types
def get_all_active_tasks(user_id: int) -> List[Dict[str, Any]]:
    """Module-level convenience wrapper around TaskManager.get_user_tasks."""
    active = task_manager.get_user_tasks(user_id)
    return active
|
||||
|
||||
# ----------------------
|
||||
# IMPROVED CONNECTION HANDLING
|
||||
# ----------------------
|
||||
|
||||
def get_direct_db_connection():
    """
    Create a direct database connection instead of using the pool.

    This is more reliable for Celery workers to avoid pool exhaustion.
    Connection parameters come from the DB_* environment variables; the
    backend is selected by the module-level `database_type`.

    Returns:
        A live DB-API connection: psycopg for PostgreSQL, otherwise
        mariadb (preferred) or mysql-connector-python.
    """
    db_host = os.environ.get("DB_HOST", "127.0.0.1")
    db_port = os.environ.get("DB_PORT", "3306")
    db_user = os.environ.get("DB_USER", "root")
    db_password = os.environ.get("DB_PASSWORD", "password")
    db_name = os.environ.get("DB_NAME", "pypods_database")

    print(f"Creating direct database connection for task")

    if database_type == "postgresql":
        import psycopg
        conninfo = f"host={db_host} port={db_port} user={db_user} password={db_password} dbname={db_name}"
        return psycopg.connect(conninfo)
    else:  # Default to MariaDB/MySQL
        try:
            import mariadb as mysql_connector
        except ImportError:
            # FIX: a plain `import mysql.connector` binds the name `mysql`,
            # not `mysql_connector`, so the connect() call below raised
            # NameError whenever the mariadb package was unavailable.
            import mysql.connector as mysql_connector
        return mysql_connector.connect(
            host=db_host,
            # mariadb requires an int port; mysql-connector accepts either
            port=int(db_port),
            user=db_user,
            password=db_password,
            database=db_name
        )
|
||||
|
||||
def close_direct_db_connection(cnx):
    """Safely close a connection created by get_direct_db_connection().

    Accepts None (no-op) and swallows close() errors so cleanup in a
    `finally` block can never raise.
    """
    if not cnx:
        return
    try:
        cnx.close()
    except Exception as e:
        print(f"Error closing direct connection: {str(e)}")
    else:
        print("Direct database connection closed")
|
||||
|
||||
# Minimal changes to download_podcast_task that should work right away
@celery_app.task(bind=True, max_retries=3)
def download_podcast_task(self, episode_id: int, user_id: int, database_type: str):
    """
    Celery task to download a podcast episode.
    Uses retries with exponential backoff for handling transient failures.

    Flow: look up a human-readable title for the episode, register the task
    with the shared task_manager (so the UI can show progress), then run the
    actual download with a progress callback. On any exception the task is
    marked failed and retried (5s/25s/125s backoff).
    """
    task_id = self.request.id
    print(f"DOWNLOAD TASK STARTED: ID={task_id}, Episode={episode_id}, User={user_id}")
    cnx = None

    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()

        # Get the episode title and podcast name
        if database_type == "postgresql":
            # First try to get both the episode title and podcast name
            query = '''
                SELECT e."episodetitle", p."podcastname"
                FROM "Episodes" e
                JOIN "Podcasts" p ON e."podcastid" = p."podcastid"
                WHERE e."episodeid" = %s
            '''
        else:
            query = '''
                SELECT e.EpisodeTitle, p.PodcastName
                FROM Episodes e
                JOIN Podcasts p ON e.PodcastID = p.PodcastID
                WHERE e.EpisodeID = %s
            '''

        cursor.execute(query, (episode_id,))
        result = cursor.fetchone()
        cursor.close()

        # Extract episode title and podcast name.
        # The cursor may return a dict or a tuple depending on backend/config,
        # so both shapes are handled.
        episode_title = None
        podcast_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "episodetitle" in result:  # PostgreSQL lowercase
                    episode_title = result["episodetitle"]
                    podcast_name = result.get("podcastname")
                else:  # MariaDB uppercase
                    episode_title = result["EpisodeTitle"]
                    podcast_name = result.get("PodcastName")
            else:
                # Tuple result
                episode_title = result[0] if len(result) > 0 else None
                podcast_name = result[1] if len(result) > 1 else None

        # Format a nice display title, falling back through podcast name to
        # a numeric placeholder (the string "None" is treated as missing)
        display_title = "Unknown Episode"
        if episode_title and episode_title != "None" and episode_title.strip():
            display_title = episode_title
        elif podcast_name:
            display_title = f"{podcast_name} - Episode"
        else:
            display_title = f"Episode #{episode_id}"

        print(f"Using display title for episode {episode_id}: {display_title}")

        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="podcast_download",
            user_id=user_id,
            item_id=episode_id,
            details={
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )

        # Define a progress callback with the display title
        # (closed over task_id/display_title; invoked by download_podcast)
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"

            task_manager.update_task(task_id, progress, status, {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": status_message
            })

        # Close the connection used for title lookup
        close_direct_db_connection(cnx)

        # Get a fresh connection for the download
        cnx = get_direct_db_connection()

        # Import the download function (deferred to avoid circular imports)
        from database_functions.functions import download_podcast

        print(f"Starting download for episode: {episode_id} ({display_title}), user: {user_id}, task: {task_id}")

        # Execute the download with progress reporting
        success = download_podcast(
            cnx,
            database_type,
            episode_id,
            user_id,
            task_id,
            progress_callback=progress_callback
        )

        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": completion_message
            }
        )

        return success
    except Exception as exc:
        print(f"Error downloading podcast {episode_id}: {str(exc)}")
        # Mark task as failed
        task_manager.complete_task(
            task_id,
            False,
            {
                "episode_id": episode_id,
                "episode_title": f"Episode #{episode_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
|
||||
|
||||
@celery_app.task(bind=True, max_retries=3)
def download_youtube_video_task(self, video_id: int, user_id: int, database_type: str):
    """
    Celery task to download a YouTube video.
    Uses retries with exponential backoff for handling transient failures.

    Mirrors download_podcast_task: resolve a display title, register the
    task, then run the download. Because older builds of
    download_youtube_video may not accept a progress_callback, the signature
    is inspected at runtime and the callback is only passed when supported.
    """
    task_id = self.request.id
    print(f"YOUTUBE DOWNLOAD TASK STARTED: ID={task_id}, Video={video_id}, User={user_id}")
    cnx = None

    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()

        # Get the video title and channel name
        if database_type == "postgresql":
            # First try to get both the video title and channel name
            query = '''
                SELECT v."videotitle", p."podcastname"
                FROM "YouTubeVideos" v
                JOIN "Podcasts" p ON v."podcastid" = p."podcastid"
                WHERE v."videoid" = %s
            '''
        else:
            query = '''
                SELECT v.VideoTitle, p.PodcastName
                FROM YouTubeVideos v
                JOIN Podcasts p ON v.PodcastID = p.PodcastID
                WHERE v.VideoID = %s
            '''

        cursor.execute(query, (video_id,))
        result = cursor.fetchone()
        cursor.close()

        # Extract video title and channel name (dict or tuple cursor rows)
        video_title = None
        channel_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "videotitle" in result:  # PostgreSQL lowercase
                    video_title = result["videotitle"]
                    channel_name = result.get("podcastname")
                else:  # MariaDB uppercase
                    video_title = result["VideoTitle"]
                    channel_name = result.get("PodcastName")
            else:
                # Tuple result
                video_title = result[0] if len(result) > 0 else None
                channel_name = result[1] if len(result) > 1 else None

        # Format a nice display title
        display_title = "Unknown Video"
        if video_title and video_title != "None" and video_title.strip():
            display_title = video_title
        elif channel_name:
            display_title = f"{channel_name} - Video"
        else:
            display_title = f"YouTube Video #{video_id}"

        print(f"Using display title for video {video_id}: {display_title}")

        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="youtube_download",
            user_id=user_id,
            item_id=video_id,
            details={
                "item_id": video_id,
                "item_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )

        # Close the connection used for title lookup
        close_direct_db_connection(cnx)

        # Get a fresh connection for the download
        cnx = get_direct_db_connection()

        # Import the download function (deferred to avoid circular imports)
        from database_functions.functions import download_youtube_video

        print(f"Starting download for YouTube video: {video_id} ({display_title}), user: {user_id}, task: {task_id}")

        # Define a progress callback with the display title
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"

            task_manager.update_task(task_id, progress, status, {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": status_message
            })

        # Check if the download_youtube_video function accepts progress_callback parameter
        import inspect
        try:
            signature = inspect.signature(download_youtube_video)
            has_progress_callback = 'progress_callback' in signature.parameters
        except (TypeError, ValueError):
            has_progress_callback = False

        # Execute the download with progress callback if supported, otherwise without it
        if has_progress_callback:
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id,
                progress_callback=progress_callback
            )
        else:
            # Call without the progress_callback parameter
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id
            )

        # Since we can't use progress callbacks directly, manually update progress after completion
        # NOTE(review): this update runs unconditionally, even when the
        # callback WAS supported above; complete_task below then overwrites
        # it — confirm whether it should sit inside the `else` branch only.
        task_manager.update_task(task_id, 100 if success else 0,
                                 "SUCCESS" if success else "FAILED",
                                 {
                                     "item_id": video_id,
                                     "item_title": display_title,
                                     "status_text": f"{'Download complete' if success else 'Download failed'}"
                                 })

        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": completion_message
            }
        )

        return success
    except Exception as exc:
        print(f"Error downloading YouTube video {video_id}: {str(exc)}")
        # Mark task as failed but include video title in the details
        task_manager.complete_task(
            task_id,
            False,
            {
                "item_id": video_id,
                "item_title": f"YouTube Video #{video_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
|
||||
|
||||
@celery_app.task
def queue_podcast_downloads(podcast_id: int, user_id: int, database_type: str, is_youtube: bool = False):
    """
    Task to queue individual download tasks for all episodes/videos in a podcast.
    This adds downloads to the queue in small batches to prevent overwhelming the system.

    Already-downloaded items (per check_downloaded) are skipped; a 2-second
    pause separates batches of 5. Returns a human-readable summary string.
    """
    cnx = None

    try:
        # Get a direct connection
        cnx = get_direct_db_connection()

        # Deferred import to avoid circular dependency at module load
        from database_functions.functions import (
            get_episode_ids_for_podcast,
            get_video_ids_for_podcast,
            check_downloaded
        )

        if is_youtube:
            item_ids = get_video_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(item_ids)} YouTube videos for download")

            # Process YouTube items in batches
            batch_size = 5
            for i in range(0, len(item_ids), batch_size):
                batch = item_ids[i:i+batch_size]
                for item_id in batch:
                    if not check_downloaded(cnx, database_type, user_id, item_id, is_youtube):
                        download_youtube_video_task.delay(item_id, user_id, database_type)

                # Add a small delay between batches
                if i + batch_size < len(item_ids):
                    time.sleep(2)
        else:
            # Get episode IDs (should return dicts with id and title)
            episodes = get_episode_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(episodes)} podcast episodes for download")

            # Process episodes in batches
            batch_size = 5
            for i in range(0, len(episodes), batch_size):
                batch = episodes[i:i+batch_size]

                for episode in batch:
                    # Handle both possible formats (dict or simple ID)
                    if isinstance(episode, dict) and "id" in episode:
                        episode_id = episode["id"]
                    else:
                        # Fall back to treating it as just an ID
                        episode_id = episode

                    if not check_downloaded(cnx, database_type, user_id, episode_id, is_youtube):
                        # Pass just the ID, the task will look up the title
                        download_podcast_task.delay(episode_id, user_id, database_type)

                # Add a small delay between batches
                if i + batch_size < len(episodes):
                    time.sleep(2)

        # Only one of episodes/item_ids is bound, matching the branch taken
        return f"Queued {len(episodes if not is_youtube else item_ids)} items for download"
    finally:
        if cnx:
            close_direct_db_connection(cnx)
|
||||
|
||||
# Helper task to clean up old download records
@celery_app.task
def cleanup_old_downloads():
    """
    Periodic task to clean up old download records from Valkey.

    Currently a stub: record expiry is handled by the TTLs set when tasks
    are stored, so this task only logs that it ran.
    """
    # NOTE(review): imported but unused — presumably kept for a future
    # scan-based cleanup implementation; confirm before removing.
    from database_functions.valkey_client import valkey_client

    # This would need to be implemented with a scan operation
    # For simplicity, we rely on Redis/Valkey TTL mechanisms
    print("Running download cleanup task")
|
||||
|
||||
# Example task for refreshing podcast feeds
@celery_app.task(bind=True, max_retries=2)
def refresh_feed_task(self, user_id: int, database_type: str):
    """
    Celery task to refresh podcast feeds for a user.

    Registers a "feed_refresh" task, reports incremental progress while each
    feed is processed (currently simulated), and marks the task complete.
    On any error the task is marked failed and retried after 30 seconds.
    """
    task_id = self.request.id
    cnx = None

    try:
        # Register task so the UI can track progress
        task_manager.register_task(
            task_id=task_id,
            task_type="feed_refresh",
            user_id=user_id,
            details={"description": "Refreshing podcast feeds"}
        )

        # Get a direct database connection
        cnx = get_direct_db_connection()

        # Here you would have your actual feed refresh implementation
        # with periodic progress updates.
        # (FIX: the previous inner `try/except Exception as e: raise e`
        # wrapper was a no-op — the outer handler already catches everything —
        # so it has been removed without changing behavior.)
        task_manager.update_task(task_id, 10, "PROGRESS", {"status_text": "Fetching podcast list"})

        # Simulate feed refresh process with progress updates
        # Replace with your actual implementation
        total_podcasts = 10  # Example count
        for i in range(total_podcasts):
            # Update progress for each podcast
            progress = (i + 1) / total_podcasts * 100
            task_manager.update_task(
                task_id,
                progress,
                "PROGRESS",
                {"status_text": f"Refreshing podcast {i+1}/{total_podcasts}"}
            )

            # Simulated work - replace with actual refresh logic
            time.sleep(0.5)

        # Complete the task
        task_manager.complete_task(task_id, True, {"refreshed_count": total_podcasts})
        return True

    except Exception as exc:
        print(f"Error refreshing feeds for user {user_id}: {str(exc)}")
        task_manager.complete_task(task_id, False, {"error": str(exc)})
        self.retry(exc=exc, countdown=30)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
|
||||
|
||||
# Simple debug task
@celery_app.task
def debug_task(x, y):
    """Minimal smoke-test task: add two values, log the work, return the sum."""
    result = x + y
    message = f"CELERY DEBUG TASK EXECUTED: {x} + {y} = {result}"
    print(message)
    return result
|
||||
778
PinePods-0.8.2/database_functions/validate_database.py
Normal file
778
PinePods-0.8.2/database_functions/validate_database.py
Normal file
@@ -0,0 +1,778 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Database Validator for PinePods
|
||||
|
||||
This script validates that an existing database matches the expected schema
|
||||
by using the migration system as the source of truth.
|
||||
|
||||
Usage:
|
||||
python validate_database.py --db-type mysql --db-host localhost --db-port 3306 --db-user root --db-password pass --db-name pinepods_database
|
||||
python validate_database.py --db-type postgresql --db-host localhost --db-port 5432 --db-user postgres --db-password pass --db-name pinepods_database
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import Dict, List, Set, Tuple, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
import importlib.util
|
||||
|
||||
# Add the parent directory to path so we can import database_functions
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, parent_dir)
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
try:
|
||||
import mysql.connector
|
||||
MYSQL_AVAILABLE = True
|
||||
except ImportError:
|
||||
MYSQL_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import psycopg
|
||||
POSTGRESQL_AVAILABLE = True
|
||||
except ImportError:
|
||||
POSTGRESQL_AVAILABLE = False
|
||||
|
||||
from database_functions.migrations import get_migration_manager
|
||||
|
||||
|
||||
@dataclass
class TableInfo:
    """Information about a database table's schema, as read from a live DB."""
    # Table name
    name: str
    # column name -> attribute dict (type, nullability, default, ...)
    columns: Dict[str, Dict[str, Any]]
    # index name -> details dict (column list, uniqueness, ...)
    indexes: Dict[str, Dict[str, Any]]
    # constraint name -> details dict (type, columns, referenced table/columns)
    constraints: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Result of comparing a live database schema against the expected one."""
    # True only when no differences of any kind were found
    is_valid: bool
    # Tables expected but absent from the live database
    missing_tables: List[str]
    # Tables present in the live database but not expected
    extra_tables: List[str]
    # table name -> summary of structural differences
    table_differences: Dict[str, Dict[str, Any]]
    missing_indexes: List[Tuple[str, str]]  # (table, index)
    extra_indexes: List[Tuple[str, str]]
    missing_constraints: List[Tuple[str, str]]  # (table, constraint)
    extra_constraints: List[Tuple[str, str]]
    column_differences: Dict[str, Dict[str, Dict[str, Any]]]  # table -> column -> differences
|
||||
|
||||
|
||||
class DatabaseInspector:
    """Abstract base class for reading schema details out of a live database.

    Subclasses implement get_tables() and get_table_info() for a specific
    backend; get_all_table_info() composes them.
    """

    def __init__(self, connection):
        # DB-API connection, owned by the caller
        self.connection = connection

    def get_tables(self) -> Set[str]:
        """Return the names of all tables (must be overridden)."""
        raise NotImplementedError

    def get_table_info(self, table_name: str) -> TableInfo:
        """Return column/index/constraint details for one table (must be overridden)."""
        raise NotImplementedError

    def get_all_table_info(self) -> Dict[str, TableInfo]:
        """Collect a TableInfo for every table reported by get_tables()."""
        return {name: self.get_table_info(name) for name in self.get_tables()}
|
||||
|
||||
|
||||
class MySQLInspector(DatabaseInspector):
    """MySQL/MariaDB database inspector.

    Reads schema metadata via SHOW statements and INFORMATION_SCHEMA so the
    validator can compare a live database against the migration-built one.
    """

    def get_tables(self) -> Set[str]:
        """Return the set of table names in the current database."""
        cursor = self.connection.cursor()
        cursor.execute("SHOW TABLES")
        tables = {row[0] for row in cursor.fetchall()}
        cursor.close()
        return tables

    def get_table_info(self, table_name: str) -> TableInfo:
        """Return columns, indexes and foreign-key constraints for *table_name*."""
        # dictionary=True gives rows keyed by column header (Field, Type, ...)
        cursor = self.connection.cursor(dictionary=True)

        # Get column information
        cursor.execute(f"DESCRIBE `{table_name}`")
        columns = {}
        for row in cursor.fetchall():
            columns[row['Field']] = {
                'type': row['Type'],
                'null': row['Null'],
                'key': row['Key'],
                'default': row['Default'],
                'extra': row['Extra']
            }

        # Get index information — SHOW INDEX emits one row per (index, column),
        # so multi-column indexes are accumulated across rows
        cursor.execute(f"SHOW INDEX FROM `{table_name}`")
        indexes = {}
        for row in cursor.fetchall():
            index_name = row['Key_name']
            if index_name not in indexes:
                indexes[index_name] = {
                    'columns': [],
                    'unique': not row['Non_unique'],
                    'type': row['Index_type']
                }
            indexes[index_name]['columns'].append(row['Column_name'])

        # Get constraint information (foreign keys, etc.) — the
        # REFERENCED_TABLE_NAME IS NOT NULL filter restricts this to FKs
        cursor.execute(f"""
            SELECT kcu.CONSTRAINT_NAME, tc.CONSTRAINT_TYPE, kcu.COLUMN_NAME,
                   kcu.REFERENCED_TABLE_NAME, kcu.REFERENCED_COLUMN_NAME
            FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
            JOIN INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
                ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
                AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA
            WHERE kcu.TABLE_SCHEMA = DATABASE() AND kcu.TABLE_NAME = %s
                AND kcu.REFERENCED_TABLE_NAME IS NOT NULL
        """, (table_name,))

        constraints = {}
        for row in cursor.fetchall():
            constraint_name = row['CONSTRAINT_NAME']
            constraints[constraint_name] = {
                'type': 'FOREIGN KEY',
                'column': row['COLUMN_NAME'],
                'referenced_table': row['REFERENCED_TABLE_NAME'],
                'referenced_column': row['REFERENCED_COLUMN_NAME']
            }

        cursor.close()
        return TableInfo(table_name, columns, indexes, constraints)
|
||||
|
||||
|
||||
class PostgreSQLInspector(DatabaseInspector):
    """PostgreSQL database inspector.

    Reads schema metadata from information_schema and the pg_catalog tables
    so the validator can compare a live database against the expected one.
    """

    def get_tables(self) -> Set[str]:
        """Return the set of ordinary table names in the public schema."""
        cursor = self.connection.cursor()
        cursor.execute("""
            SELECT table_name
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
        """)
        tables = {row[0] for row in cursor.fetchall()}
        cursor.close()
        return tables

    def get_table_info(self, table_name: str) -> TableInfo:
        """Return columns, indexes and constraints for *table_name*."""
        cursor = self.connection.cursor()

        # Get column information
        cursor.execute("""
            SELECT column_name, data_type, is_nullable, column_default,
                   character_maximum_length, numeric_precision, numeric_scale
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = %s
            ORDER BY ordinal_position
        """, (table_name,))

        columns = {}
        for row in cursor.fetchall():
            col_name, data_type, is_nullable, default, max_length, precision, scale = row
            # Rebuild a MySQL-style "type(length)" / "type(precision,scale)"
            # string so both inspectors produce comparable type descriptions
            type_str = data_type
            if max_length:
                type_str += f"({max_length})"
            elif precision:
                if scale:
                    type_str += f"({precision},{scale})"
                else:
                    type_str += f"({precision})"

            columns[col_name] = {
                'type': type_str,
                'null': is_nullable,
                'default': default,
                'max_length': max_length,
                'precision': precision,
                'scale': scale
            }

        # Get index information — unnest(indkey) WITH ORDINALITY preserves
        # the column order inside each (possibly multi-column) index
        cursor.execute("""
            SELECT i.relname as index_name,
                   array_agg(a.attname ORDER BY c.ordinality) as columns,
                   ix.indisunique as is_unique,
                   ix.indisprimary as is_primary
            FROM pg_class t
            JOIN pg_index ix ON t.oid = ix.indrelid
            JOIN pg_class i ON i.oid = ix.indexrelid
            JOIN unnest(ix.indkey) WITH ORDINALITY c(colnum, ordinality) ON true
            JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = c.colnum
            WHERE t.relname = %s AND t.relkind = 'r'
            GROUP BY i.relname, ix.indisunique, ix.indisprimary
        """, (table_name,))

        indexes = {}
        for row in cursor.fetchall():
            index_name, columns_list, is_unique, is_primary = row
            indexes[index_name] = {
                'columns': columns_list,
                'unique': is_unique,
                'primary': is_primary
            }

        # Get constraint information — LEFT JOINs keep non-FK constraints
        # (their referenced_table/columns come back NULL)
        cursor.execute("""
            SELECT con.conname as constraint_name,
                   con.contype as constraint_type,
                   array_agg(att.attname) as columns,
                   cl.relname as referenced_table,
                   array_agg(att2.attname) as referenced_columns
            FROM pg_constraint con
            JOIN pg_class t ON con.conrelid = t.oid
            JOIN pg_attribute att ON att.attrelid = t.oid AND att.attnum = ANY(con.conkey)
            LEFT JOIN pg_class cl ON con.confrelid = cl.oid
            LEFT JOIN pg_attribute att2 ON att2.attrelid = cl.oid AND att2.attnum = ANY(con.confkey)
            WHERE t.relname = %s
            GROUP BY con.conname, con.contype, cl.relname
        """, (table_name,))

        constraints = {}
        for row in cursor.fetchall():
            constraint_name, constraint_type, columns_list, ref_table, ref_columns = row
            constraints[constraint_name] = {
                'type': constraint_type,
                'columns': columns_list,
                'referenced_table': ref_table,
                'referenced_columns': ref_columns
            }

        cursor.close()
        return TableInfo(table_name, columns, indexes, constraints)
|
||||
|
||||
|
||||
class DatabaseValidator:
    """Main database validator class"""

    def __init__(self, db_type: str, db_config: Dict[str, Any]):
        """Store normalized connection settings for later validation runs."""
        normalized = db_type.lower()
        # Normalize mariadb to mysql since they use the same connector
        self.db_type = 'mysql' if normalized == 'mariadb' else normalized
        self.db_config = db_config
        self.logger = logging.getLogger(__name__)
|
||||
|
||||
def create_test_database(self) -> Tuple[Any, str]:
    """Create a throwaway database with all migrations applied.

    Dispatches to the backend-specific builder; raises ValueError for an
    unsupported self.db_type.
    """
    builders = {
        'mysql': self._create_mysql_test_db,
        'postgresql': self._create_postgresql_test_db,
    }
    builder = builders.get(self.db_type)
    if builder is None:
        raise ValueError(f"Unsupported database type: {self.db_type}")
    return builder()
|
||||
|
||||
def _create_mysql_test_db(self) -> Tuple[Any, str]:
    """Create MySQL test database, run all migrations into it, and return
    (connection-to-test-db, test-db-name). The caller is responsible for
    dropping the database and closing the returned connection.
    """
    if not MYSQL_AVAILABLE:
        raise ImportError("mysql-connector-python is required for MySQL validation")

    # Create temporary database name (random suffix avoids collisions)
    import uuid
    test_db_name = f"pinepods_test_{uuid.uuid4().hex[:8]}"

    # Connect to MySQL server
    config = self.db_config.copy()
    config.pop('database', None)  # Remove database from config
    config['use_pure'] = True  # Use pure Python implementation to avoid auth plugin issues

    conn = mysql.connector.connect(**config)
    cursor = conn.cursor()

    try:
        # Create test database
        cursor.execute(f"CREATE DATABASE `{test_db_name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
        cursor.execute(f"USE `{test_db_name}`")
        cursor.close()

        # Run all migrations
        self._run_migrations(conn, 'mysql')

        # Create a fresh connection to the test database for schema inspection
        config['database'] = test_db_name
        test_conn = mysql.connector.connect(**config)

        # Close the migration connection
        conn.close()

        return test_conn, test_db_name

    except Exception as e:
        # Best-effort cleanup of the server connection before re-raising
        if cursor:
            cursor.close()
        if conn:
            conn.close()
        raise e
|
||||
|
||||
def _create_postgresql_test_db(self) -> Tuple[Any, str]:
    """Create PostgreSQL test database, run all migrations into it, and
    return (connection-to-test-db, test-db-name). The caller is responsible
    for dropping the database and closing the returned connection.
    """
    if not POSTGRESQL_AVAILABLE:
        raise ImportError("psycopg is required for PostgreSQL validation")

    # Create temporary database name (random suffix avoids collisions)
    import uuid
    test_db_name = f"pinepods_test_{uuid.uuid4().hex[:8]}"

    # Connect to PostgreSQL server
    config = self.db_config.copy()
    config.pop('dbname', None)  # Remove database from config
    config['dbname'] = 'postgres'  # Connect to default database

    conn = psycopg.connect(**config)
    # autocommit is required: CREATE DATABASE cannot run inside a transaction
    conn.autocommit = True
    cursor = conn.cursor()

    try:
        # Create test database
        cursor.execute(f'CREATE DATABASE "{test_db_name}"')
        cursor.close()
        conn.close()

        # Connect to the new test database
        config['dbname'] = test_db_name
        test_conn = psycopg.connect(**config)
        test_conn.autocommit = True

        # Run all migrations
        self._run_migrations(test_conn, 'postgresql')

        return test_conn, test_db_name

    except Exception as e:
        # Best-effort cleanup of the bootstrap connection before re-raising
        cursor.close()
        conn.close()
        raise e
|
||||
|
||||
def _run_migrations(self, conn: Any, db_type: str) -> None:
    """Run all migrations on the test database using existing migration system.

    Args:
        conn: An open connection to the (empty) test database; the migration
            manager's private connection is overridden with it.
        db_type: Either 'mysql' or 'postgresql'; selects which placeholder
            environment to install for the migration manager.

    Raises:
        RuntimeError: if the migration manager reports failure.

    NOTE(review): the migration manager reads its configuration from
    environment variables, so this method temporarily installs dummy
    values and restores the caller's environment afterwards. The values
    are never used for an actual connection because ``manager._connection``
    is overridden below.
    """
    # Set environment variables for the migration manager
    import os
    # Saved key -> value (None means "was unset") so we can restore exactly.
    original_env = {}

    try:
        # Backup original environment
        for key in ['DB_TYPE', 'DB_HOST', 'DB_PORT', 'DB_USER', 'DB_PASSWORD', 'DB_NAME']:
            original_env[key] = os.environ.get(key)

        # Set placeholder environment for the test database; only DB_TYPE
        # matters for behavior since the connection itself is injected.
        if db_type == 'mysql':
            os.environ['DB_TYPE'] = 'mysql'
            os.environ['DB_HOST'] = 'localhost' # We'll override the connection
            os.environ['DB_PORT'] = '3306'
            os.environ['DB_USER'] = 'test'
            os.environ['DB_PASSWORD'] = 'test'
            os.environ['DB_NAME'] = 'test'
        else:
            os.environ['DB_TYPE'] = 'postgresql'
            os.environ['DB_HOST'] = 'localhost'
            os.environ['DB_PORT'] = '5432'
            os.environ['DB_USER'] = 'test'
            os.environ['DB_PASSWORD'] = 'test'
            os.environ['DB_NAME'] = 'test'

        # Import migration definitions — presumably registers the migrations
        # as a side effect of import; TODO confirm (migrate.py calls
        # register_all_migrations() explicitly after this import).
        import database_functions.migration_definitions

        # Get migration manager and point it at our test-database connection.
        # NOTE(review): reaches into a private attribute of the manager.
        manager = get_migration_manager()
        manager._connection = conn

        # Apply the full migration chain; fail loudly if anything didn't apply.
        success = manager.run_migrations()
        if not success:
            raise RuntimeError("Failed to apply migrations")

    finally:
        # Restore the caller's environment exactly as it was.
        for key, value in original_env.items():
            if value is not None:
                os.environ[key] = value
            elif key in os.environ:
                del os.environ[key]
def validate_database(self) -> ValidationResult:
    """Validate the actual database against the expected schema.

    Builds a temporary database by running every migration, extracts the
    schema from both it and the real database, and diffs the two.

    Returns:
        ValidationResult describing missing/extra tables, columns,
        indexes and constraints.
    """
    # Create test database with the canonical, fully-migrated schema.
    test_conn, test_db_name = self.create_test_database()

    try:
        # Connect to actual database
        actual_conn = self._connect_to_actual_database()

        try:
            # Get schema information from both databases
            if self.db_type == 'mysql':
                expected_inspector = MySQLInspector(test_conn)
                actual_inspector = MySQLInspector(actual_conn)
                expected_schema = expected_inspector.get_all_table_info()
                actual_schema = actual_inspector.get_all_table_info()
            else:
                # For PostgreSQL, create a fresh connection for the expected
                # schema since the migration manager closes the original one.
                fresh_test_conn = psycopg.connect(
                    host=self.db_config['host'],
                    port=self.db_config['port'],
                    user=self.db_config['user'],
                    password=self.db_config['password'],
                    dbname=test_db_name
                )
                fresh_test_conn.autocommit = True

                try:
                    expected_inspector = PostgreSQLInspector(fresh_test_conn)
                    actual_inspector = PostgreSQLInspector(actual_conn)
                    expected_schema = expected_inspector.get_all_table_info()
                    actual_schema = actual_inspector.get_all_table_info()
                finally:
                    fresh_test_conn.close()

            # Diagnostic dump of both schemas — DEBUG level instead of the
            # previous unconditional print() statements, so library callers
            # aren't spammed on stdout.
            self._log_schema_debug(expected_schema, actual_schema)

            # Compare schemas
            return self._compare_schemas(expected_schema, actual_schema)

        finally:
            actual_conn.close()

    finally:
        # Clean up test database - this will close test_conn
        self._cleanup_test_database(test_conn, test_db_name)

def _log_schema_debug(self, expected_schema, actual_schema) -> None:
    """Log a DEBUG-level summary of both schemas being compared."""
    self.logger.debug("Expected schema has %d tables:", len(expected_schema))
    for table in sorted(expected_schema):
        cols = list(expected_schema[table].columns.keys())
        preview = ', '.join(cols[:5]) + ('...' if len(cols) > 5 else '')
        self.logger.debug("  %s: %d columns - %s", table, len(cols), preview)

    self.logger.debug("Actual schema has %d tables:", len(actual_schema))
    for table in sorted(actual_schema):
        cols = list(actual_schema[table].columns.keys())
        preview = ', '.join(cols[:5]) + ('...' if len(cols) > 5 else '')
        self.logger.debug("  %s: %d columns - %s", table, len(cols), preview)

    # Extra detail for the Playlists table (historically a problem area,
    # judging by the dedicated debug output it had — TODO confirm).
    if 'Playlists' in expected_schema and 'Playlists' in actual_schema:
        exp_cols = set(expected_schema['Playlists'].columns.keys())
        act_cols = set(actual_schema['Playlists'].columns.keys())
        self.logger.debug("Playlists expected columns: %s", sorted(exp_cols))
        self.logger.debug("Playlists actual columns: %s", sorted(act_cols))
        self.logger.debug("Playlists missing from actual: %s", sorted(exp_cols - act_cols))
        self.logger.debug("Playlists extra in actual: %s", sorted(act_cols - exp_cols))
def _connect_to_actual_database(self) -> Any:
    """Open a connection to the real database that is being validated."""
    if self.db_type != 'mysql':
        return psycopg.connect(**self.db_config)
    # MySQL: enable autocommit, and force the pure-Python driver to avoid
    # auth plugin issues.
    options = dict(self.db_config, autocommit=True, use_pure=True)
    return mysql.connector.connect(**options)
def _cleanup_test_database(self, test_conn: Any, test_db_name: str):
    """Drop the temporary test database and close its connection.

    Cleanup is best-effort: any failure is logged as a warning rather
    than propagated, since a leaked test database is not fatal.
    """
    try:
        # The inspection connection must be closed before dropping the DB.
        if test_conn:
            test_conn.close()

        # Reconnect at server level (no test DB selected) to issue the DROP.
        config = self.db_config.copy()
        if self.db_type == 'mysql':
            config.pop('database', None)
            config['use_pure'] = True  # Use pure Python implementation to avoid auth plugin issues
            cleanup_conn = mysql.connector.connect(**config)
            drop_statement = f"DROP DATABASE IF EXISTS `{test_db_name}`"
        else:
            config.pop('dbname', None)
            config['dbname'] = 'postgres'
            cleanup_conn = psycopg.connect(**config)
            cleanup_conn.autocommit = True
            drop_statement = f'DROP DATABASE IF EXISTS "{test_db_name}"'

        cursor = cleanup_conn.cursor()
        cursor.execute(drop_statement)
        cursor.close()
        cleanup_conn.close()
    except Exception as e:
        self.logger.warning(f"Failed to clean up test database {test_db_name}: {e}")
def _compare_schemas(self, expected: Dict[str, TableInfo], actual: Dict[str, TableInfo]) -> ValidationResult:
    """Diff the expected (migration-built) schema against the actual one.

    Only missing tables and missing columns count as critical failures;
    extra tables/columns, type differences, and index/constraint deltas
    are recorded but do not fail validation.
    """
    expected_tables = set(expected)
    actual_tables = set(actual)

    missing_tables = list(expected_tables - actual_tables)
    extra_tables = list(actual_tables - expected_tables)

    table_differences = {}
    missing_indexes = []
    extra_indexes = []
    missing_constraints = []
    extra_constraints = []
    column_differences = {}

    # Walk the tables present in both schemas and record per-table deltas.
    for table_name in expected_tables & actual_tables:
        expected_table = expected[table_name]
        actual_table = actual[table_name]

        # Column-level differences (missing / extra / different).
        col_diffs = self._compare_columns(expected_table.columns, actual_table.columns)
        if col_diffs:
            column_differences[table_name] = col_diffs

        # Index name deltas.
        exp_idx = set(expected_table.indexes)
        act_idx = set(actual_table.indexes)
        missing_indexes.extend((table_name, name) for name in exp_idx - act_idx)
        extra_indexes.extend((table_name, name) for name in act_idx - exp_idx)

        # Constraint name deltas.
        exp_cons = set(expected_table.constraints)
        act_cons = set(actual_table.constraints)
        missing_constraints.extend((table_name, name) for name in exp_cons - act_cons)
        extra_constraints.extend((table_name, name) for name in act_cons - exp_cons)

    # Critical issues: missing tables plus missing columns in expected tables.
    critical_issues = list(missing_tables)
    for table, col_diffs in column_differences.items():
        # Tables the app doesn't expect are never validated.
        if table in extra_tables:
            continue
        for col, diff in col_diffs.items():
            if diff['status'] == 'missing':
                critical_issues.append(f"missing column {col} in table {table}")

    return ValidationResult(
        is_valid=not critical_issues,
        missing_tables=missing_tables,
        extra_tables=extra_tables,
        table_differences=table_differences,
        missing_indexes=missing_indexes,
        extra_indexes=extra_indexes,
        missing_constraints=missing_constraints,
        extra_constraints=extra_constraints,
        column_differences=column_differences
    )
def _compare_columns(self, expected: Dict[str, Dict[str, Any]], actual: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""Compare column definitions between expected and actual"""
|
||||
differences = {}
|
||||
|
||||
expected_cols = set(expected.keys())
|
||||
actual_cols = set(actual.keys())
|
||||
|
||||
# Missing columns
|
||||
for missing_col in expected_cols - actual_cols:
|
||||
differences[missing_col] = {'status': 'missing', 'expected': expected[missing_col]}
|
||||
|
||||
# Extra columns
|
||||
for extra_col in actual_cols - expected_cols:
|
||||
differences[extra_col] = {'status': 'extra', 'actual': actual[extra_col]}
|
||||
|
||||
# Different columns
|
||||
for col_name in expected_cols & actual_cols:
|
||||
expected_col = expected[col_name]
|
||||
actual_col = actual[col_name]
|
||||
|
||||
col_diffs = {}
|
||||
for key in expected_col:
|
||||
if key in actual_col and expected_col[key] != actual_col[key]:
|
||||
col_diffs[key] = {'expected': expected_col[key], 'actual': actual_col[key]}
|
||||
|
||||
if col_diffs:
|
||||
differences[col_name] = {'status': 'different', 'differences': col_diffs}
|
||||
|
||||
return differences
|
||||
|
||||
|
||||
def print_validation_report(result: ValidationResult):
    """Print a human-readable validation report to stdout.

    Critical problems (missing tables/columns) are listed first, then
    acceptable differences, then index/constraint deltas.
    """
    print("=" * 80)
    print("DATABASE VALIDATION REPORT")
    print("=" * 80)

    # Partition findings into blocking problems and mere warnings.
    critical_issues = list(result.missing_tables)
    warning_issues = []

    for table, col_diffs in result.column_differences.items():
        for col, diff in col_diffs.items():
            if diff['status'] == 'missing':
                critical_issues.append(f"Missing column {col} in table {table}")
            else:
                warning_issues.append((table, col, diff))

    # Extra tables are reported as warnings only.
    warning_issues += [('EXTRA_TABLE', table, None) for table in result.extra_tables]

    # Overall verdict line.
    if not result.is_valid:
        print("❌ DATABASE VALIDATION FAILED - Critical issues found")
    elif warning_issues:
        print("✅ DATABASE IS VALID - No critical issues found!")
        print("⚠️ Some warnings exist but don't affect functionality")
    else:
        print("✅ DATABASE IS PERFECT - All checks passed!")

    print()

    # Critical section: missing tables, then missing columns per table.
    if critical_issues:
        print("🔴 CRITICAL ISSUES (MUST BE FIXED):")
        if result.missing_tables:
            print(" Missing Tables:")
            for table in result.missing_tables:
                print(f" - {table}")

        for table, col_diffs in result.column_differences.items():
            missing_cols = [col for col, diff in col_diffs.items() if diff['status'] == 'missing']
            if missing_cols:
                print(f" Missing Columns in {table}:")
                for col in missing_cols:
                    print(f" - {col}")
        print()

    # Warning section: extra tables and non-critical column differences.
    if warning_issues:
        print("⚠️ WARNINGS (ACCEPTABLE DIFFERENCES):")

        if result.extra_tables:
            print(" Extra Tables (ignored):")
            for table in result.extra_tables:
                print(f" - {table}")

        for table, col_diffs in result.column_differences.items():
            table_warnings = []
            for col, diff in col_diffs.items():
                if diff['status'] == 'extra':
                    table_warnings.append(f"Extra column: {col}")
                elif diff['status'] == 'different':
                    details = [f"{key}: {values}" for key, values in diff['differences'].items()]
                    table_warnings.append(f"Different column {col}: {', '.join(details)}")

            if table_warnings:
                print(f" Table {table}:")
                for warning in table_warnings:
                    print(f" - {warning}")
        print()

    # Index/constraint deltas share one layout; emit them data-driven.
    for heading, pairs in (
        ("🟡 MISSING INDEXES:", result.missing_indexes),
        ("🟡 EXTRA INDEXES:", result.extra_indexes),
        ("🟡 MISSING CONSTRAINTS:", result.missing_constraints),
        ("🟡 EXTRA CONSTRAINTS:", result.extra_constraints),
    ):
        if pairs:
            print(heading)
            for table, item in pairs:
                print(f" - {table}.{item}")
            print()
def main():
    """CLI entry point: parse connection arguments, validate, print report.

    Exit codes: 0 = schema valid, 1 = critical schema issues, 2 = error.
    """
    parser = argparse.ArgumentParser(description='Validate PinePods database schema')
    parser.add_argument('--db-type', required=True, choices=['mysql', 'mariadb', 'postgresql'], help='Database type')
    parser.add_argument('--db-host', required=True, help='Database host')
    parser.add_argument('--db-port', required=True, type=int, help='Database port')
    parser.add_argument('--db-user', required=True, help='Database user')
    parser.add_argument('--db-password', required=True, help='Database password')
    parser.add_argument('--db-name', required=True, help='Database name')
    parser.add_argument('--verbose', '-v', action='store_true', help='Enable verbose logging')

    args = parser.parse_args()

    # Set up logging
    level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(levelname)s - %(message)s')

    # Build database config
    if args.db_type in ['mysql', 'mariadb']:
        # MariaDB is wire-compatible with MySQL; normalize the type so the
        # validator (which only recognizes 'mysql'/'postgresql') doesn't
        # reject '--db-type mariadb' as unsupported.
        db_type = 'mysql'
        db_config = {
            'host': args.db_host,
            'port': args.db_port,
            'user': args.db_user,
            'password': args.db_password,
            'database': args.db_name,
            'charset': 'utf8mb4',
            'collation': 'utf8mb4_unicode_ci'
        }
    else:  # postgresql
        db_type = 'postgresql'
        db_config = {
            'host': args.db_host,
            'port': args.db_port,
            'user': args.db_user,
            'password': args.db_password,
            'dbname': args.db_name
        }

    try:
        # Create validator and run validation
        validator = DatabaseValidator(db_type, db_config)
        result = validator.validate_database()

        # Print report
        print_validation_report(result)

        # Exit code communicates validity to calling scripts.
        sys.exit(0 if result.is_valid else 1)

    except Exception as e:
        logging.error(f"Validation failed with error: {e}")
        if args.verbose:
            import traceback
            traceback.print_exc()
        sys.exit(2)
# Allow running this module directly as a standalone CLI tool.
if __name__ == '__main__':
    main()
Reference in New Issue
Block a user