Added Cargo files

This commit is contained in:
2026-03-03 10:57:43 -05:00
parent 478a90e01b
commit 169df46bc2
813 changed files with 227273 additions and 9 deletions

View File

@@ -0,0 +1,795 @@
# tasks.py - Define Celery tasks with Valkey as broker
from celery import Celery
import time
import os
import asyncio
import datetime
import requests
from threading import Thread
import json
import sys
import logging
from typing import Dict, Any, Optional, List
# Make sure pinepods is in the Python path so the absolute imports used
# throughout this module (database_functions.*, clients.*) resolve.
sys.path.append('/pinepods')
# Backend selector read once at import time: "mariadb" (default) or "postgresql".
database_type = str(os.getenv('DB_TYPE', 'mariadb'))
class Web_Key:
    """Lazily fetched, cached copy of the server's web API key."""

    def __init__(self):
        # Cache slot; stays None until get_web_key() is called.
        self.web_key = None

    def get_web_key(self, cnx):
        """Fetch the web key from the database, cache it, and return it."""
        # Import only when needed to avoid circular imports
        from database_functions.functions import get_web_key as fetch_key
        self.web_key = fetch_key(cnx, database_type)
        return self.web_key
# Module-level singleton holding the cached web API key.
base_webkey = Web_Key()
# Set up logging
logger = logging.getLogger("celery_tasks")
# Import the WebSocket manager directly from clientapi; fall back to None
# when running outside the API process.
try:
    from clients.clientapi import manager as websocket_manager
    print("Successfully imported WebSocket manager from clientapi")
except ImportError as e:
    logger.error(f"Failed to import WebSocket manager: {e}")
    # NOTE(review): websocket_manager is not referenced elsewhere in this
    # module — broadcasts go through websocket_broadcaster instead; confirm
    # whether this import is still needed.
    websocket_manager = None
# Create a dedicated event loop thread for async operations
_event_loop = None
_event_loop_thread = None
def start_background_loop():
    """Start (once) a daemon thread running a dedicated asyncio event loop.

    The loop object is published to the module-global ``_event_loop`` so other
    code can schedule coroutines onto it (e.g. via
    ``asyncio.run_coroutine_threadsafe``). Calling this function again after
    the loop has started is a no-op.

    Fix over the original: instead of ``time.sleep(0.1)`` and hoping the
    thread got far enough, we wait on a threading.Event that the thread sets
    once the loop reference is published — deterministic on slow hosts.
    """
    global _event_loop, _event_loop_thread
    # Only start if not already running
    if _event_loop is not None:
        return
    from threading import Event
    loop_ready = Event()

    def run_event_loop():
        global _event_loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        _event_loop = loop
        # Signal the caller that the loop reference is published before
        # blocking in run_forever().
        loop_ready.set()
        loop.run_forever()

    # Start the background thread (daemon: never blocks process exit)
    _event_loop_thread = Thread(target=run_event_loop, daemon=True)
    _event_loop_thread.start()
    # Wait for the loop to actually exist before returning.
    if not loop_ready.wait(timeout=5):
        print("WARNING: background event loop did not start within 5s")
    print("Started background event loop for WebSocket broadcasts")
# Start the event loop when this module is imported
start_background_loop()
# Use the existing Valkey connection for Celery.
valkey_host = os.environ.get("VALKEY_HOST", "localhost")
valkey_port = os.environ.get("VALKEY_PORT", "6379")
# Valkey is wire-compatible with Redis, hence the redis:// scheme.
broker_url = f"redis://{valkey_host}:{valkey_port}/0"
backend_url = f"redis://{valkey_host}:{valkey_port}/0"
# Initialize Celery with Valkey as broker and result backend
celery_app = Celery('pinepods',
                    broker=broker_url,
                    backend=backend_url)
# Configure Celery for best performance with downloads
celery_app.conf.update(
    worker_concurrency=3,  # Limit to 3 concurrent downloads per worker
    task_acks_late=True,  # Only acknowledge tasks after they're done
    task_time_limit=1800,  # 30 minutes hard time limit
    task_soft_time_limit=1500,  # 25 minutes soft time limit
    worker_prefetch_multiplier=1,  # Don't prefetch more tasks than workers
)
# Task status tracking in Valkey for all types of tasks
class TaskManager:
    """Track Celery task state in Valkey and broadcast updates to clients.

    Storage layout:
      - ``task:<task_id>``: JSON object with progress/status/details
        (24-hour TTL, shortened to 1 hour once the task completes).
      - ``user_tasks:<user_id>``: JSON list of the user's active task ids
        (7-day TTL).
    """

    def __init__(self):
        # Import lazily to avoid a circular dependency at module load time.
        from database_functions.valkey_client import valkey_client
        self.valkey_client = valkey_client

    def register_task(self, task_id: str, task_type: str, user_id: int, item_id: Optional[int] = None,
                      details: Optional[Dict[str, Any]] = None):
        """Register any Celery task for tracking.

        Args:
            task_id: Celery task id, used as the Valkey key suffix.
            task_type: Kind of work, e.g. "podcast_download", "youtube_download".
            user_id: Owner of the task.
            item_id: Optional id of the item being processed (episode/video).
            details: Optional extra display fields stored with the record.
        """
        task_data = {
            "task_id": task_id,
            "user_id": user_id,
            "type": task_type,
            "item_id": item_id,
            "progress": 0.0,
            "status": "PENDING",
            "details": details or {},
            # NOTE(review): naive local time — consider UTC for consistency.
            "started_at": datetime.datetime.now().isoformat()
        }
        self.valkey_client.set(f"task:{task_id}", json.dumps(task_data))
        # Set TTL for 24 hours
        self.valkey_client.expire(f"task:{task_id}", 86400)
        # Add to user's active tasks list
        self._add_to_user_tasks(user_id, task_id)
        # Broadcasting is best-effort: a failure must never fail the task.
        try:
            self._broadcast_update(task_id)
        except Exception as e:
            logger.error(f"Error broadcasting task update: {e}")

    def update_task(self, task_id: str, progress: Optional[float] = None, status: Optional[str] = None,
                    details: Optional[Dict[str, Any]] = None):
        """Update any task's status and progress.

        Only the provided fields are changed; ``details`` is merged into the
        stored details dict. Unknown task ids are silently ignored.
        """
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            # ``is not None`` so an explicit progress of 0 is still applied.
            if progress is not None:
                task["progress"] = progress
            if status:
                task["status"] = status
            if details:
                if "details" not in task:
                    task["details"] = {}
                task["details"].update(details)
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Best-effort broadcast of the new state.
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")

    def complete_task(self, task_id: str, success: bool = True, result: Any = None):
        """Mark any task as complete (SUCCESS) or failed (FAILED).

        Progress is forced to 100 on success / 0 on failure, and the record's
        TTL is shortened to one hour.
        """
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            task = json.loads(task_json)
            task["progress"] = 100.0 if success else 0.0
            task["status"] = "SUCCESS" if success else "FAILED"
            task["completed_at"] = datetime.datetime.now().isoformat()
            if result is not None:
                task["result"] = result
            self.valkey_client.set(f"task:{task_id}", json.dumps(task))
            # Try to broadcast the final update
            try:
                self._broadcast_update(task_id)
            except Exception as e:
                logger.error(f"Error broadcasting task update: {e}")
            # Keep completed tasks for 1 hour before expiring
            self.valkey_client.expire(f"task:{task_id}", 3600)
            # Only successful tasks leave the active list; failed tasks stay
            # listed until their TTL expires (presumably so they remain
            # visible in the UI — TODO confirm this is intended).
            if success:
                self._remove_from_user_tasks(task.get("user_id"), task_id)

    def get_task(self, task_id: str) -> Dict[str, Any]:
        """Return the stored record for ``task_id``, or {} if unknown/expired."""
        task_json = self.valkey_client.get(f"task:{task_id}")
        if task_json:
            return json.loads(task_json)
        return {}

    def get_user_tasks(self, user_id: int) -> List[Dict[str, Any]]:
        """Get all active tasks for a user (all types)."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        result = []
        if tasks_list_json:
            task_ids = json.loads(tasks_list_json)
            for task_id in task_ids:
                task_info = self.get_task(task_id)
                # Skip ids whose task records have already expired.
                if task_info:
                    result.append(task_info)
        return result

    def _add_to_user_tasks(self, user_id: int, task_id: str):
        """Add a task id to the user's active tasks list (deduplicated)."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id not in tasks_list:
                tasks_list.append(task_id)
        else:
            tasks_list = [task_id]
        self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))
        # Set TTL for 7 days
        self.valkey_client.expire(f"user_tasks:{user_id}", 604800)

    def _remove_from_user_tasks(self, user_id: int, task_id: str):
        """Remove a task id from the user's active tasks list, if present."""
        tasks_list_json = self.valkey_client.get(f"user_tasks:{user_id}")
        if tasks_list_json:
            tasks_list = json.loads(tasks_list_json)
            if task_id in tasks_list:
                tasks_list.remove(task_id)
            self.valkey_client.set(f"user_tasks:{user_id}", json.dumps(tasks_list))

    # Modified _broadcast_update method to avoid circular imports
    def _broadcast_update(self, task_id: str):
        """Broadcast task update via HTTP endpoint.

        Fetches the stored record, obtains the server web key (used to
        authenticate the broadcast), and hands the payload to the websocket
        broadcaster. All failures are logged and swallowed.
        """
        # Get task info
        task_info = self.get_task(task_id)
        if not task_info or "user_id" not in task_info:
            return
        user_id = task_info["user_id"]
        cnx = None
        try:
            cnx = get_direct_db_connection()
            # Import broadcaster - delay import to avoid circular dependency;
            # try both a bare-path and a package-qualified import.
            sys.path.insert(0, '/pinepods/database_functions')
            try:
                from websocket_broadcaster import broadcaster
            except ImportError:
                try:
                    from database_functions.websocket_broadcaster import broadcaster
                except ImportError as e:
                    print(f"Cannot import broadcaster from any location: {e}")
                    return
            # Get web key
            web_key = None
            try:
                # Get web key using class method to avoid direct import;
                # the singleton caches it after the first lookup.
                if not base_webkey.web_key:
                    base_webkey.get_web_key(cnx)
                web_key = base_webkey.web_key
            except Exception as e:
                print(f"Error getting web key: {str(e)}")
                # Fallback to a direct approach if needed
                try:
                    from database_functions.functions import get_web_key
                    web_key = get_web_key(cnx, database_type)
                except Exception as e2:
                    print(f"Fallback web key retrieval failed: {str(e2)}")
                    return
            # Progress and status details for better debugging
            progress = task_info.get("progress", 0)
            status = task_info.get("status", "unknown")
            print(f"Broadcasting task update for user {user_id}, task {task_id}, progress: {progress}, status: {status}")
            # Broadcast the update
            result = broadcaster.broadcast_task_update(user_id, task_info, web_key)
            if result:
                print(f"Successfully broadcast task update for task {task_id}, progress: {progress}%")
            else:
                print(f"Failed to broadcast task update for task {task_id}, progress: {progress}%")
        except Exception as e:
            print(f"Error in task broadcast setup: {str(e)}")
        finally:
            if cnx:
                # Close direct connection
                close_direct_db_connection(cnx)
# Initialize a general task manager (module-level singleton)
task_manager = TaskManager()
# For backwards compatibility, keep the download_manager name too
download_manager = task_manager
# Module-level convenience wrapper covering downloads and all other task types.
def get_all_active_tasks(user_id: int) -> List[Dict[str, Any]]:
    """Get all active tasks for a user (all types)."""
    return task_manager.get_user_tasks(user_id)
# ----------------------
# IMPROVED CONNECTION HANDLING
# ----------------------
def get_direct_db_connection():
    """
    Create a direct database connection instead of using the pool.

    This is more reliable for Celery workers and avoids pool exhaustion when
    many long-running download tasks hold connections.

    Returns:
        A psycopg connection when DB_TYPE is "postgresql", otherwise a
        mariadb (preferred) or mysql.connector connection.
    """
    db_host = os.environ.get("DB_HOST", "127.0.0.1")
    db_port = os.environ.get("DB_PORT", "3306")
    db_user = os.environ.get("DB_USER", "root")
    db_password = os.environ.get("DB_PASSWORD", "password")
    db_name = os.environ.get("DB_NAME", "pypods_database")
    print("Creating direct database connection for task")
    if database_type == "postgresql":
        import psycopg
        conninfo = f"host={db_host} port={db_port} user={db_user} password={db_password} dbname={db_name}"
        return psycopg.connect(conninfo)
    else:  # Default to MariaDB/MySQL
        try:
            import mariadb as mysql_connector
        except ImportError:
            # Bug fix: a bare "import mysql.connector" binds the name
            # "mysql", not "mysql_connector", so the fallback path raised
            # NameError at the connect() call below. Alias it explicitly.
            import mysql.connector as mysql_connector
        return mysql_connector.connect(
            host=db_host,
            # Bug fix: the mariadb driver requires an int port, but the env
            # var arrives as a string ("3306").
            port=int(db_port),
            user=db_user,
            password=db_password,
            database=db_name
        )
def close_direct_db_connection(cnx):
    """Close a direct database connection; log and swallow any close error."""
    if not cnx:
        # Nothing to do for a missing/None handle.
        return
    try:
        cnx.close()
    except Exception as e:
        print(f"Error closing direct connection: {str(e)}")
    else:
        print("Direct database connection closed")
# Minimal changes to download_podcast_task that should work right away
@celery_app.task(bind=True, max_retries=3)
def download_podcast_task(self, episode_id: int, user_id: int, database_type: str):
    """
    Celery task to download a podcast episode.

    Flow: look up a human-readable title, register a "podcast_download" task
    with TaskManager, run the download with a progress callback, then mark
    the task complete.

    Args:
        episode_id: ID of the Episodes row to download.
        user_id: ID of the requesting user.
        database_type: "postgresql" or MariaDB/MySQL (anything else).

    Uses retries with exponential backoff (5s, 25s, 125s) for handling
    transient failures.
    """
    task_id = self.request.id
    print(f"DOWNLOAD TASK STARTED: ID={task_id}, Episode={episode_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the episode title and podcast name; identifier quoting/casing
        # differs between the PostgreSQL and MariaDB schemas.
        if database_type == "postgresql":
            query = '''
                SELECT e."episodetitle", p."podcastname"
                FROM "Episodes" e
                JOIN "Podcasts" p ON e."podcastid" = p."podcastid"
                WHERE e."episodeid" = %s
            '''
        else:
            query = '''
                SELECT e.EpisodeTitle, p.PodcastName
                FROM Episodes e
                JOIN Podcasts p ON e.PodcastID = p.PodcastID
                WHERE e.EpisodeID = %s
            '''
        cursor.execute(query, (episode_id,))
        result = cursor.fetchone()
        cursor.close()
        # Extract episode title and podcast name; the cursor may return a
        # dict (dict cursor) or a plain tuple depending on driver config.
        episode_title = None
        podcast_name = None
        if result:
            if isinstance(result, dict):
                if "episodetitle" in result:  # PostgreSQL lowercase keys
                    episode_title = result["episodetitle"]
                    podcast_name = result.get("podcastname")
                else:  # MariaDB uppercase keys
                    episode_title = result["EpisodeTitle"]
                    podcast_name = result.get("PodcastName")
            else:
                # Tuple result: (title, podcast_name)
                episode_title = result[0] if len(result) > 0 else None
                podcast_name = result[1] if len(result) > 1 else None
        # Format a nice display title, falling back to podcast name or ID.
        # The "None" string check guards against a stringified NULL title.
        display_title = "Unknown Episode"
        if episode_title and episode_title != "None" and episode_title.strip():
            display_title = episode_title
        elif podcast_name:
            display_title = f"{podcast_name} - Episode"
        else:
            display_title = f"Episode #{episode_id}"
        print(f"Using display title for episode {episode_id}: {display_title}")
        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="podcast_download",
            user_id=user_id,
            item_id=episode_id,
            details={
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )
        # Progress callback passed into download_podcast(); maps the phase
        # to a human-readable status message.
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": status_message
            })
        # Close the title-lookup connection and open a fresh one for the
        # (potentially long) download.
        close_direct_db_connection(cnx)
        cnx = get_direct_db_connection()
        # Import the download function lazily to avoid circular imports.
        from database_functions.functions import download_podcast
        print(f"Starting download for episode: {episode_id} ({display_title}), user: {user_id}, task: {task_id}")
        # Execute the download with progress reporting
        success = download_podcast(
            cnx,
            database_type,
            episode_id,
            user_id,
            task_id,
            progress_callback=progress_callback
        )
        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "episode_id": episode_id,
                "episode_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading podcast {episode_id}: {str(exc)}")
        # Mark task as failed.
        # NOTE(review): the task is marked FAILED here and then retried; a
        # successful retry re-registers/completes it — confirm the UI copes
        # with that FAILED -> PENDING transition.
        task_manager.complete_task(
            task_id,
            False,
            {
                "episode_id": episode_id,
                "episode_title": f"Episode #{episode_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
@celery_app.task(bind=True, max_retries=3)
def download_youtube_video_task(self, video_id: int, user_id: int, database_type: str):
    """
    Celery task to download a YouTube video.

    Mirrors download_podcast_task: look up a display title, register a
    "youtube_download" task, run the download (with a progress callback if
    the downloader supports one), then mark the task complete.

    Args:
        video_id: ID of the YouTubeVideos row to download.
        user_id: ID of the requesting user.
        database_type: "postgresql" or MariaDB/MySQL (anything else).

    Uses retries with exponential backoff (5s, 25s, 125s) for handling
    transient failures.
    """
    task_id = self.request.id
    print(f"YOUTUBE DOWNLOAD TASK STARTED: ID={task_id}, Video={video_id}, User={user_id}")
    cnx = None
    try:
        # Get a direct connection to fetch the title first
        cnx = get_direct_db_connection()
        cursor = cnx.cursor()
        # Get the video title and channel name; identifier quoting/casing
        # differs between the PostgreSQL and MariaDB schemas.
        if database_type == "postgresql":
            query = '''
                SELECT v."videotitle", p."podcastname"
                FROM "YouTubeVideos" v
                JOIN "Podcasts" p ON v."podcastid" = p."podcastid"
                WHERE v."videoid" = %s
            '''
        else:
            query = '''
                SELECT v.VideoTitle, p.PodcastName
                FROM YouTubeVideos v
                JOIN Podcasts p ON v.PodcastID = p.PodcastID
                WHERE v.VideoID = %s
            '''
        cursor.execute(query, (video_id,))
        result = cursor.fetchone()
        cursor.close()
        # Extract video title and channel name; the cursor may return a
        # dict (dict cursor) or a plain tuple depending on driver config.
        video_title = None
        channel_name = None
        if result:
            if isinstance(result, dict):
                if "videotitle" in result:  # PostgreSQL lowercase keys
                    video_title = result["videotitle"]
                    channel_name = result.get("podcastname")
                else:  # MariaDB uppercase keys
                    video_title = result["VideoTitle"]
                    channel_name = result.get("PodcastName")
            else:
                # Tuple result: (title, channel_name)
                video_title = result[0] if len(result) > 0 else None
                channel_name = result[1] if len(result) > 1 else None
        # Format a nice display title, falling back to channel name or ID.
        display_title = "Unknown Video"
        if video_title and video_title != "None" and video_title.strip():
            display_title = video_title
        elif channel_name:
            display_title = f"{channel_name} - Video"
        else:
            display_title = f"YouTube Video #{video_id}"
        print(f"Using display title for video {video_id}: {display_title}")
        # Register task with more details
        task_manager.register_task(
            task_id=task_id,
            task_type="youtube_download",
            user_id=user_id,
            item_id=video_id,
            details={
                "item_id": video_id,
                "item_title": display_title,
                "status_text": f"Preparing to download {display_title}"
            }
        )
        # Close the title-lookup connection and open a fresh one for the
        # (potentially long) download.
        close_direct_db_connection(cnx)
        cnx = get_direct_db_connection()
        # Import the download function lazily to avoid circular imports.
        from database_functions.functions import download_youtube_video
        print(f"Starting download for YouTube video: {video_id} ({display_title}), user: {user_id}, task: {task_id}")
        # Progress callback; maps the phase to a human-readable message.
        def progress_callback(progress, status=None):
            status_message = ""
            if status == "DOWNLOADING":
                status_message = f"Downloading {display_title}"
            elif status == "PROCESSING":
                status_message = f"Processing {display_title}"
            elif status == "FINALIZING":
                status_message = f"Finalizing {display_title}"
            task_manager.update_task(task_id, progress, status, {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": status_message
            })
        # Feature-detect whether download_youtube_video accepts a
        # progress_callback parameter (older implementations may not).
        import inspect
        try:
            signature = inspect.signature(download_youtube_video)
            has_progress_callback = 'progress_callback' in signature.parameters
        except (TypeError, ValueError):
            has_progress_callback = False
        # Execute the download with progress callback if supported, otherwise without it
        if has_progress_callback:
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id,
                progress_callback=progress_callback
            )
        else:
            # Call without the progress_callback parameter
            success = download_youtube_video(
                cnx,
                database_type,
                video_id,
                user_id,
                task_id
            )
            # No incremental progress is possible in this path; jump the
            # progress straight to its final value after completion.
            task_manager.update_task(task_id, 100 if success else 0,
                                     "SUCCESS" if success else "FAILED",
                                     {
                                         "item_id": video_id,
                                         "item_title": display_title,
                                         "status_text": f"{'Download complete' if success else 'Download failed'}"
                                     })
        # Mark task as complete with a nice message
        completion_message = f"{'Successfully downloaded' if success else 'Failed to download'} {display_title}"
        task_manager.complete_task(
            task_id,
            success,
            {
                "item_id": video_id,
                "item_title": display_title,
                "status_text": completion_message
            }
        )
        return success
    except Exception as exc:
        print(f"Error downloading YouTube video {video_id}: {str(exc)}")
        # Mark task as failed but include video title in the details.
        # NOTE(review): marked FAILED here and then retried; a successful
        # retry re-registers the task — confirm the UI handles this.
        task_manager.complete_task(
            task_id,
            False,
            {
                "item_id": video_id,
                "item_title": f"YouTube Video #{video_id}",
                "status_text": f"Download failed: {str(exc)}"
            }
        )
        # Retry with exponential backoff (5s, 25s, 125s)
        countdown = 5 * (2 ** self.request.retries)
        self.retry(exc=exc, countdown=countdown)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
@celery_app.task
def queue_podcast_downloads(podcast_id: int, user_id: int, database_type: str, is_youtube: bool = False):
    """
    Queue individual download tasks for all episodes/videos in a podcast.

    Downloads are enqueued in small batches (5 at a time, with a 2-second
    pause between batches) to prevent overwhelming the system. Items that
    are already downloaded are skipped.

    Args:
        podcast_id: Podcast (or YouTube channel) whose items to queue.
        user_id: Requesting user.
        database_type: "postgresql" or MariaDB/MySQL.
        is_youtube: Queue YouTube videos instead of podcast episodes.

    Returns:
        A status string with the number of items actually queued.
        (Fix: the original reported the total item count, including
        already-downloaded items that were skipped.)
    """
    cnx = None
    try:
        # Get a direct connection
        cnx = get_direct_db_connection()
        from database_functions.functions import (
            get_episode_ids_for_podcast,
            get_video_ids_for_podcast,
            check_downloaded
        )
        batch_size = 5
        queued = 0
        if is_youtube:
            item_ids = get_video_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(item_ids)} YouTube videos for download")
            # Process YouTube items in batches
            for i in range(0, len(item_ids), batch_size):
                for item_id in item_ids[i:i + batch_size]:
                    if not check_downloaded(cnx, database_type, user_id, item_id, is_youtube):
                        download_youtube_video_task.delay(item_id, user_id, database_type)
                        queued += 1
                # Add a small delay between batches
                if i + batch_size < len(item_ids):
                    time.sleep(2)
        else:
            # Get episode IDs (may be dicts with id/title, or plain ids)
            episodes = get_episode_ids_for_podcast(cnx, database_type, podcast_id)
            print(f"Queueing {len(episodes)} podcast episodes for download")
            # Process episodes in batches
            for i in range(0, len(episodes), batch_size):
                for episode in episodes[i:i + batch_size]:
                    # Handle both possible formats (dict or simple ID)
                    if isinstance(episode, dict) and "id" in episode:
                        episode_id = episode["id"]
                    else:
                        episode_id = episode
                    if not check_downloaded(cnx, database_type, user_id, episode_id, is_youtube):
                        # Pass just the ID; the task looks up the title itself.
                        download_podcast_task.delay(episode_id, user_id, database_type)
                        queued += 1
                # Add a small delay between batches
                if i + batch_size < len(episodes):
                    time.sleep(2)
        return f"Queued {queued} items for download"
    finally:
        if cnx:
            close_direct_db_connection(cnx)
# Helper task to clean up old download records
@celery_app.task
def cleanup_old_downloads():
    """
    Periodic task to clean up old download records from Valkey.

    Currently a no-op placeholder: expiry is handled by the Valkey/Redis
    TTLs set when tasks are registered and completed. A full implementation
    would SCAN for stale ``task:*`` keys.

    Fix: removed the unused function-local import of valkey_client.
    """
    print("Running download cleanup task")
# Example task for refreshing podcast feeds
@celery_app.task(bind=True, max_retries=2)
def refresh_feed_task(self, user_id: int, database_type: str):
    """
    Celery task to refresh podcast feeds for a user.

    Registers a "feed_refresh" task, reports incremental progress, and marks
    the task complete. On any error the task is marked FAILED and retried
    (up to max_retries) after 30 seconds.

    NOTE(review): the refresh loop below is simulated placeholder work
    (time.sleep) — replace with the real feed-refresh implementation.

    Fix: removed an inner ``try/except Exception as e: raise e`` that only
    re-raised the exception unchanged (behavior is identical without it).
    """
    task_id = self.request.id
    cnx = None
    try:
        # Register task so the UI can track it
        task_manager.register_task(
            task_id=task_id,
            task_type="feed_refresh",
            user_id=user_id,
            details={"description": "Refreshing podcast feeds"}
        )
        # Direct connection; currently unused by the simulated loop but kept
        # so the real implementation has one available.
        cnx = get_direct_db_connection()
        task_manager.update_task(task_id, 10, "PROGRESS", {"status_text": "Fetching podcast list"})
        # Simulate feed refresh process with per-podcast progress updates.
        total_podcasts = 10  # Example count
        for i in range(total_podcasts):
            # Update progress for each podcast
            progress = (i + 1) / total_podcasts * 100
            task_manager.update_task(
                task_id,
                progress,
                "PROGRESS",
                {"status_text": f"Refreshing podcast {i+1}/{total_podcasts}"}
            )
            # Simulated work - replace with actual refresh logic
            time.sleep(0.5)
        # Complete the task
        task_manager.complete_task(task_id, True, {"refreshed_count": total_podcasts})
        return True
    except Exception as exc:
        print(f"Error refreshing feeds for user {user_id}: {str(exc)}")
        task_manager.complete_task(task_id, False, {"error": str(exc)})
        self.retry(exc=exc, countdown=30)
    finally:
        # Always close the connection
        if cnx:
            close_direct_db_connection(cnx)
# Simple debug task
@celery_app.task
def debug_task(x, y):
    """Sanity-check task: add two values, print the sum, and return it."""
    result = x + y
    message = f"CELERY DEBUG TASK EXECUTED: {x} + {y} = {result}"
    print(message)
    return result