# tasks.py - Define Celery tasks with Valkey as broker
from celery import Celery
import time
import os
import asyncio
import datetime
import requests
from threading import Thread, Event
import json
import sys
import logging
from typing import Dict, Any, Optional, List

# Make sure pinepods is in the Python path
sys.path.append('/pinepods')

database_type = str(os.getenv('DB_TYPE', 'mariadb'))


class Web_Key:
    """Lazy holder for the server's web API key.

    The key is fetched from the database on first use and cached on the
    instance so repeated broadcasts don't re-query.
    """

    def __init__(self):
        # Cached key; None until get_web_key() is called.
        self.web_key = None

    def get_web_key(self, cnx):
        """Fetch (and cache) the web key using an open DB connection."""
        # Import only when needed to avoid circular imports
        from database_functions.functions import get_web_key as get_key
        self.web_key = get_key(cnx, database_type)
        return self.web_key


base_webkey = Web_Key()

# Set up logging
logger = logging.getLogger("celery_tasks")

# Import the WebSocket manager directly from clientapi
try:
    from clients.clientapi import manager as websocket_manager
    print("Successfully imported WebSocket manager from clientapi")
except ImportError as e:
    logger.error(f"Failed to import WebSocket manager: {e}")
    websocket_manager = None

# Create a dedicated event loop thread for async operations
_event_loop = None
_event_loop_thread = None


def start_background_loop():
    """Start a daemon thread running an asyncio event loop (idempotent).

    The loop is published through the module-level ``_event_loop`` global so
    other code can schedule coroutines onto it.  A ``threading.Event`` is used
    to wait until the loop is actually created instead of the previous fixed
    ``time.sleep(0.1)``, which raced against slow thread start-up.
    """
    global _event_loop, _event_loop_thread
    # Only start if not already running
    if _event_loop is None:
        loop_ready = Event()

        # Function to run event loop in background thread
        def run_event_loop():
            global _event_loop
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            _event_loop = loop
            loop_ready.set()  # signal the parent that the loop exists
            loop.run_forever()

        # Start the background thread
        _event_loop_thread = Thread(target=run_event_loop, daemon=True)
        _event_loop_thread.start()
        # Wait (bounded) for the loop to be created before returning.
        loop_ready.wait(timeout=5)
        print("Started background event loop for WebSocket broadcasts")


# Start the event loop when this module is imported
start_background_loop()

# Use the existing Valkey connection for Celery
valkey_host = os.environ.get("VALKEY_HOST", "localhost")
valkey_port = os.environ.get("VALKEY_PORT", "6379")
broker_url = f"redis://{valkey_host}:{valkey_port}/0"
backend_url = f"redis://{valkey_host}:{valkey_port}/0"

# Initialize Celery with Valkey as broker and result backend
celery_app = Celery('pinepods', broker=broker_url, backend=backend_url)

# Configure Celery for best performance with downloads
celery_app.conf.update(
    worker_concurrency=3,          # Limit to 3 concurrent downloads per worker
    task_acks_late=True,           # Only acknowledge tasks after they're done
    task_time_limit=1800,          # 30 minutes time limit
    task_soft_time_limit=1500,     # 25 minutes soft time limit
    worker_prefetch_multiplier=1,  # Don't prefetch more tasks than workers
)


# Task status tracking in Valkey for all types of tasks
class TaskManager:
    """Tracks Celery task state in Valkey and broadcasts updates.

    Task records live under ``task:{task_id}`` (JSON, 24h TTL) and each
    user's active-task index under ``user_tasks:{user_id}`` (7 day TTL).
    """

    def __init__(self):
        from database_functions.valkey_client import valkey_client
        self.valkey_client = valkey_client

    def register_task(self, task_id: str, task_type: str, user_id: int,
                      item_id: Optional[int] = None,
                      details: Optional[Dict[str, Any]] = None):
        """Register any Celery task for tracking."""
        task_data = {
            "task_id": task_id,
            "user_id": user_id,
            "type": task_type,
            "item_id": item_id,
            "progress": 0.0,
            "status": "PENDING",
            "details": details or {},
            "started_at": datetime.datetime.now().isoformat()
        }
        self.valkey_client.set(f"task:{task_id}", json.dumps(task_data))
        # Set TTL for 24 hours
        self.valkey_client.expire(f"task:{task_id}", 86400)
        # Add to user's active tasks list
        self._add_to_user_tasks(user_id, task_id)
        # Try to broadcast the update if the WebSocket module is available
        try:
            self._broadcast_update(task_id)
        except Exception as e:
            logger.error(f"Error broadcasting task update: {e}")

    def update_task(self, task_id: str, progress: Optional[float] = None,
                    status: Optional[str] = None,
                    details: Optional[Dict[str, Any]] = None):
        """Update any task's status and progress.

        Only the fields supplied are changed; ``details`` is merged into the
        existing details dict rather than replacing it.
        """
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            if progress is not None:
                task["progress"] = progress
            if status:
                task["status"] = status
            if details:
                if "details" not in task:
                    task["details"] = {}
                task["details"].update(details)
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Try to broadcast the update
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")

    def complete_task(self, task_id: str, success: bool = True,
                      result: Any = None):
        """Mark any task as complete or failed."""
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            task["progress"] = 100.0 if success else 0.0
            task["status"] = "SUCCESS" if success else "FAILED"
            task["completed_at"] = datetime.datetime.now().isoformat()
            if result is not None:
                task["result"] = result
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Try to broadcast the final update
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")
            # Keep completed tasks for 1 hour before expiring
            self.valkey_client.expire(f"task:{task_id}", 3600)
            # Remove from user's active tasks list after completion.
            # NOTE: failed tasks intentionally stay in the list until TTL
            # so the UI can still surface the failure.
            if success:
                self._remove_from_user_tasks(task.get("user_id"), task_id)

    def get_task(self, task_id: str) -> Dict[str, Any]:
        """Get any task's details (empty dict if unknown/expired)."""
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            return json.loads(task_json)
        return {}

    def get_user_tasks(self, user_id: int) -> List[Dict[str, Any]]:
        """Get all active tasks for a user (all types)."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        result = []
        if tasks_list_json:
            task_ids = json.loads(tasks_list_json)
            for task_id in task_ids:
                task_info = self.get_task(task_id)
                if task_info:
                    result.append(task_info)
        return result

    def _add_to_user_tasks(self, user_id: int, task_id: str):
        """Add a task to the user's active tasks list."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id not in tasks_list:
                tasks_list.append(task_id)
        else:
            tasks_list = [task_id]
        self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))
        # Set TTL for 7 days
        self.valkey_client.expire(f"user_tasks:{user_id}", 604800)

    def _remove_from_user_tasks(self, user_id: int, task_id: str):
        """Remove a task from the user's active tasks list."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id in tasks_list:
                tasks_list.remove(task_id)
                self.valkey_client.set(f"user_tasks:{user_id}",
                                       json.dumps(tasks_list))

    # Modified _broadcast_update method to avoid circular imports
    def _broadcast_update(self, task_id: str):
        """Broadcast task update via HTTP endpoint.

        Opens its own direct DB connection (to look up the web key), imports
        the broadcaster lazily to dodge circular imports, and never raises:
        all failures are printed and swallowed so a broadcast problem cannot
        break task bookkeeping.
        """
        # Get task info
        task_info = self.get_task(task_id)
        if not task_info or "user_id" not in task_info:
            return
        user_id = task_info["user_id"]
        cnx = None
        try:
            cnx = get_direct_db_connection()
            # Import broadcaster - delay import to avoid circular dependency
            sys.path.insert(0, '/pinepods/database_functions')
            try:
                from websocket_broadcaster import broadcaster
            except ImportError:
                try:
                    from database_functions.websocket_broadcaster import broadcaster
                except ImportError as e:
                    print(f"Cannot import broadcaster from any location: {e}")
                    return
            # Get web key
            web_key = None
            try:
                # Get web key using class method to avoid direct import
                if not base_webkey.web_key:
                    base_webkey.get_web_key(cnx)
                web_key = base_webkey.web_key
            except Exception as e:
                print(f"Error getting web key: {str(e)}")
                # Fallback to a direct approach if needed
                try:
                    from database_functions.functions import get_web_key
                    web_key = get_web_key(cnx, database_type)
                except Exception as e2:
                    print(f"Fallback web key retrieval failed: {str(e2)}")
                    return
            # Progress and status details for better debugging
            progress = task_info.get("progress", 0)
            status = task_info.get("status", "unknown")
            print(f"Broadcasting task update for user {user_id}, task {task_id}, progress: {progress}, status: {status}")
            # Broadcast the update
            result = broadcaster.broadcast_task_update(user_id, task_info, web_key)
            if result:
                print(f"Successfully broadcast task update for task {task_id}, progress: {progress}%")
            else:
                print(f"Failed to broadcast task update for task {task_id}, progress: {progress}%")
        except Exception as e:
            print(f"Error in task broadcast setup: {str(e)}")
        finally:
            if cnx:
                # Close direct connection
                close_direct_db_connection(cnx)


# Initialize a general task manager
task_manager = TaskManager()

# For backwards compatibility, keep the download_manager name too
download_manager = task_manager


# Function to get all active tasks including both downloads and other task types
def get_all_active_tasks(user_id: int) -> List[Dict[str, Any]]:
    """Get all active tasks for a user (all types)."""
    return task_manager.get_user_tasks(user_id)


# ----------------------
# IMPROVED CONNECTION HANDLING
# ----------------------
def get_direct_db_connection():
    """
    Create a direct database connection instead of using the pool.

    This is more reliable for Celery workers to avoid pool exhaustion.
    Connection parameters come from the DB_* environment variables; the
    backend is chosen by the module-level ``database_type``.
    """
    db_host = os.environ.get("DB_HOST", "127.0.0.1")
    db_port = os.environ.get("DB_PORT", "3306")
    db_user = os.environ.get("DB_USER", "root")
    db_password = os.environ.get("DB_PASSWORD", "password")
    db_name = os.environ.get("DB_NAME", "pypods_database")
    print(f"Creating direct database connection for task")
    if database_type == "postgresql":
        import psycopg
        conninfo = f"host={db_host} port={db_port} user={db_user} password={db_password} dbname={db_name}"
        return psycopg.connect(conninfo)
    else:
        # Default to MariaDB/MySQL
        try:
            import mariadb as mysql_connector
        except ImportError:
            import mysql.connector as mysql_connector
        # The mariadb connector requires an integer port; passing the raw
        # environment string raises TypeError.
        return mysql_connector.connect(
            host=db_host,
            port=int(db_port),
            user=db_user,
            password=db_password,
            database=db_name
        )


def close_direct_db_connection(cnx):
    """Close a direct database connection (best-effort, never raises)."""
    if cnx:
        try:
            cnx.close()
            print("Direct database connection closed")
        except Exception as e:
            print(f"Error closing direct connection: {str(e)}")


# Minimal changes to download_podcast_task that should work right away
@celery_app.task(bind=True, max_retries=3)
def download_podcast_task(self, episode_id: int, user_id: int, database_type: str):
    """
    Celery task to download a podcast episode.

    Uses retries with exponential backoff for handling transient failures.
    """
    task_id = self.request.id
    print(f"DOWNLOAD TASK STARTED: ID={task_id}, Episode={episode_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the episode title and podcast name
        if database_type == "postgresql":
            # First try to get both the episode title and podcast name
            query = '''
                SELECT e."episodetitle", p."podcastname"
                FROM "Episodes" e
                JOIN "Podcasts" p ON e."podcastid" = p."podcastid"
                WHERE e."episodeid" = %s
            '''
        else:
            query = '''
                SELECT e.EpisodeTitle, p.PodcastName
                FROM Episodes e
                JOIN Podcasts p ON e.PodcastID = p.PodcastID
                WHERE e.EpisodeID = %s
            '''
        cursor.execute(query, (episode_id,))
        result = cursor.fetchone()
        cursor.close()

        # Extract episode title and podcast name
        episode_title = None
        podcast_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "episodetitle" in result:
                    # PostgreSQL lowercase
                    episode_title = result["episodetitle"]
                    podcast_name = result.get("podcastname")
                else:
                    # MariaDB uppercase
                    episode_title = result["EpisodeTitle"]
                    podcast_name = result.get("PodcastName")
            else:
                # Tuple result
                episode_title = result[0] if len(result) > 0 else None
                podcast_name = result[1] if len(result) > 1 else None

        # Format a nice display title
        display_title = "Unknown Episode"
        if episode_title and episode_title != "None" and episode_title.strip():
            display_title = episode_title
        elif podcast_name:
            display_title = f"{podcast_name} - Episode"
        else:
            display_title = f"Episode #{episode_id}"
        print(f"Using display title for episode {episode_id}: {display_title}")

        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="podcast_download",
            user_id=user_id,
            item_id=episode_id,
            details={
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )

        # Define a progress callback with the display title
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": status_message
            })

        # Close the connection used for title lookup
        close_direct_db_connection(cnx)
        # Get a fresh connection for the download
        cnx = get_direct_db_connection()

        # Import the download function
        from database_functions.functions import download_podcast
        print(f"Starting download for episode: {episode_id} ({display_title}), user: {user_id}, task: {task_id}")

        # Execute the download with progress reporting
        success = download_podcast(
            cnx,
            database_type,
            episode_id,
            user_id,
            task_id,
            progress_callback=progress_callback
        )

        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading podcast {episode_id}: {str(exc)}")
        # Mark task as failed
        task_manager.complete_task(
            task_id,
            False,
            {
                "episode_id": episode_id,
                "episode_title": f"Episode #{episode_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)


@celery_app.task(bind=True, max_retries=3)
def download_youtube_video_task(self, video_id: int, user_id: int, database_type: str):
    """
    Celery task to download a YouTube video.

    Uses retries with exponential backoff for handling transient failures.
    """
    task_id = self.request.id
    print(f"YOUTUBE DOWNLOAD TASK STARTED: ID={task_id}, Video={video_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the video title and channel name
        if database_type == "postgresql":
            # First try to get both the video title and channel name
            query = '''
                SELECT v."videotitle", p."podcastname"
                FROM "YouTubeVideos" v
                JOIN "Podcasts" p ON v."podcastid" = p."podcastid"
                WHERE v."videoid" = %s
            '''
        else:
            query = '''
                SELECT v.VideoTitle, p.PodcastName
                FROM YouTubeVideos v
                JOIN Podcasts p ON v.PodcastID = p.PodcastID
                WHERE v.VideoID = %s
            '''
        cursor.execute(query, (video_id,))
        result = cursor.fetchone()
        cursor.close()

        # Extract video title and channel name
        video_title = None
        channel_name = None
        if result:
            if isinstance(result, dict):
                # Dictionary result
                if "videotitle" in result:
                    # PostgreSQL lowercase
                    video_title = result["videotitle"]
                    channel_name = result.get("podcastname")
                else:
                    # MariaDB uppercase
                    video_title = result["VideoTitle"]
                    channel_name = result.get("PodcastName")
            else:
                # Tuple result
                video_title = result[0] if len(result) > 0 else None
                channel_name = result[1] if len(result) > 1 else None

        # Format a nice display title
        display_title = "Unknown Video"
        if video_title and video_title != "None" and video_title.strip():
            display_title = video_title
        elif channel_name:
            display_title = f"{channel_name} - Video"
        else:
            display_title = f"YouTube Video #{video_id}"
        print(f"Using display title for video {video_id}: {display_title}")

        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="youtube_download",
            user_id=user_id,
            item_id=video_id,
            details={
                "item_id": video_id,
                "item_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )

        # Close the connection used for title lookup
        close_direct_db_connection(cnx)
        # Get a fresh connection for the download
        cnx = get_direct_db_connection()

        # Import the download function
        from database_functions.functions import download_youtube_video
        print(f"Starting download for YouTube video: {video_id} ({display_title}), user: {user_id}, task: {task_id}")

        # Define a progress callback with the display title
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": status_message
            })

        # Check if the download_youtube_video function accepts progress_callback parameter
        import inspect
        try:
            signature = inspect.signature(download_youtube_video)
            has_progress_callback = 'progress_callback' in signature.parameters
        except (TypeError, ValueError):
            has_progress_callback = False

        # Execute the download with progress callback if supported, otherwise without it
        if has_progress_callback:
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id,
                progress_callback=progress_callback
            )
        else:
            # Call without the progress_callback parameter
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id
            )
            # Since we can't use progress callbacks directly, manually update progress after completion
            task_manager.update_task(task_id, 100 if success else 0,
                                     "SUCCESS" if success else "FAILED", {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": f"{'Download complete' if success else 'Download failed'}"
            })

        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading YouTube video {video_id}: {str(exc)}")
        # Mark task as failed but include video title in the details
        task_manager.complete_task(
            task_id,
            False,
            {
                "item_id": video_id,
                "item_title": f"YouTube Video #{video_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)


@celery_app.task
def queue_podcast_downloads(podcast_id: int, user_id: int, database_type: str,
                            is_youtube: bool = False):
    """
    Task to queue individual download tasks for all episodes/videos in a podcast.

    This adds downloads to the queue in small batches to prevent overwhelming
    the system.
    """
    cnx = None
    try:
        # Get a direct connection
        cnx = get_direct_db_connection()
        from database_functions.functions import (
            get_episode_ids_for_podcast,
            get_video_ids_for_podcast,
            check_downloaded
        )
        if is_youtube:
            item_ids = get_video_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(item_ids)} YouTube videos for download")
            # Process YouTube items in batches
            batch_size = 5
            for i in range(0, len(item_ids), batch_size):
                batch = item_ids[i:i+batch_size]
                for item_id in batch:
                    if not check_downloaded(cnx, database_type, user_id, item_id, is_youtube):
                        download_youtube_video_task.delay(item_id, user_id, database_type)
                # Add a small delay between batches
                if i + batch_size < len(item_ids):
                    time.sleep(2)
        else:
            # Get episode IDs (should return dicts with id and title)
            episodes = get_episode_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(episodes)} podcast episodes for download")
            # Process episodes in batches
            batch_size = 5
            for i in range(0, len(episodes), batch_size):
                batch = episodes[i:i+batch_size]
                for episode in batch:
                    # Handle both possible formats (dict or simple ID)
                    if isinstance(episode, dict) and "id" in episode:
                        episode_id = episode["id"]
                    else:
                        # Fall back to treating it as just an ID
                        episode_id = episode
                    if not check_downloaded(cnx, database_type, user_id, episode_id, is_youtube):
                        # Pass just the ID, the task will look up the title
                        download_podcast_task.delay(episode_id, user_id, database_type)
                # Add a small delay between batches
                if i + batch_size < len(episodes):
                    time.sleep(2)
        return f"Queued {len(episodes if not is_youtube else item_ids)} items for download"
    finally:
        if cnx:
            close_direct_db_connection(cnx)


# Helper task to clean up old download records
@celery_app.task
def cleanup_old_downloads():
    """
    Periodic task to clean up old download records from Valkey.
    """
    from database_functions.valkey_client import valkey_client
    # This would need to be implemented with a scan operation
    # For simplicity, we rely on Redis/Valkey TTL mechanisms
    print("Running download cleanup task")


# Example task for refreshing podcast feeds
@celery_app.task(bind=True, max_retries=2)
def refresh_feed_task(self, user_id: int, database_type: str):
    """
    Celery task to refresh podcast feeds for a user.
    """
    task_id = self.request.id
    cnx = None
    try:
        # Register task
        task_manager.register_task(
            task_id=task_id,
            task_type="feed_refresh",
            user_id=user_id,
            details={"description": "Refreshing podcast feeds"}
        )
        # Get a direct database connection
        cnx = get_direct_db_connection()

        # Get list of podcasts to refresh
        # Then update progress as each one completes
        # Here you would have your actual feed refresh implementation
        # with periodic progress updates
        task_manager.update_task(task_id, 10, "PROGRESS",
                                 {"status_text": "Fetching podcast list"})

        # Simulate feed refresh process with progress updates
        # Replace with your actual implementation
        total_podcasts = 10  # Example count
        for i in range(total_podcasts):
            # Update progress for each podcast
            progress = (i + 1) / total_podcasts * 100
            task_manager.update_task(
                task_id,
                progress,
                "PROGRESS",
                {"status_text": f"Refreshing podcast {i+1}/{total_podcasts}"}
            )
            # Simulated work - replace with actual refresh logic
            time.sleep(0.5)

        # Complete the task
        task_manager.complete_task(task_id, True, {"refreshed_count": total_podcasts})
        return True
    except Exception as exc:
        print(f"Error refreshing feeds for user {user_id}: {str(exc)}")
        task_manager.complete_task(task_id, False, {"error": str(exc)})
        self.retry(exc=exc, countdown=30)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)


# Simple debug task
@celery_app.task
def debug_task(x, y):
    """Simple debug task that prints output."""
    result = x + y
    print(f"CELERY DEBUG TASK EXECUTED: {x} + {y} = {result}")
    return result