added cargo files

This commit is contained in:
2026-03-03 10:57:43 -05:00
parent 478a90e01b
commit 169df46bc2
813 changed files with 227273 additions and 9 deletions

5098
PinePods-0.8.2/rust-api/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,88 @@
# Cargo manifest for the PinePods internal REST API service.
[package]
name = "pinepods-api"
version = "0.1.0"
edition = "2021"
# Minimum supported Rust toolchain for this crate.
rust-version = "1.89"
[dependencies]
# Web Framework (axum with macros, multipart uploads, and WebSocket support)
axum = { version = "0.8.6", features = ["macros", "multipart", "ws"] }
tokio = { version = "1.48.0", features = ["full"] }
tower = { version = "0.5.2", features = ["util", "timeout", "load-shed", "limit"] }
tower-http = { version = "0.6.6", features = ["fs", "trace", "cors", "compression-gzip"] }
# Serialization
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
# Database - sqlx compiled with both Postgres and MySQL/MariaDB drivers over rustls
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "postgres", "mysql", "uuid", "chrono", "json", "bigdecimal"] }
bigdecimal = "0.4.9"
# Redis/Valkey (async client on the tokio runtime)
redis = { version = "0.32.7", features = ["aio", "tokio-comp"] }
# HTTP Client (rustls TLS; streaming bodies and cookie store enabled)
reqwest = { version = "0.12.24", features = ["json", "rustls-tls", "stream", "cookies"] }
# Configuration and Environment (.env file loading via dotenvy)
config = "0.15.18"
dotenvy = "0.15.7"
# Logging
tracing = "0.1.41"
tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
# Utilities
uuid = { version = "1.18.1", features = ["v4", "serde"] }
chrono = { version = "0.4.42", features = ["serde"] }
chrono-tz = "0.10.0"
anyhow = "1.0.100"
thiserror = "2.0.17"
async-trait = "0.1.89"
base64 = "0.22.1"
lazy_static = "1.5.0"
urlencoding = "2.1.3"
# Authentication and Crypto
argon2 = "0.6.0-rc.1"
jsonwebtoken = { version = "10.1.0", features = ["aws_lc_rs"] }
rand = "0.9.2"
# MFA/TOTP Support (otpauth:// URI generation plus QR rendering)
totp-rs = { version = "5.7.0", features = ["otpauth"] }
qrcode = "0.14.1"
image = "0.25.8"
# Encryption for sync credentials
fernet = "0.2.2"
# RSS/Feed Processing
feed-rs = "2.3.1"
url = "2.5.7"
regex = "1.12.2"
# Audio metadata tagging (ID3 read/write for downloaded episodes)
id3 = "1.16.3"
mp3-metadata = "0.4.0"
quick-xml = "0.38.3"
# Email (SMTP over rustls; default features disabled to avoid native-tls)
lettre = { version = "0.11.18", default-features = false, features = ["tokio1-rustls-tls", "smtp-transport", "builder"] }
# CORS and Security
hyper = "1.7.0"
# Background Tasks and Task Management
tokio-cron-scheduler = "0.15.1"
tokio-stream = "0.1.17"
futures = "0.3.31"
# WebSocket Support (already in axum features)
# File handling
tokio-util = { version = "0.7.16", features = ["io"] }
mime_guess = "2.0.5"
[dev-dependencies]
tower-test = "0.4"

View File

@@ -0,0 +1,405 @@
use serde::{Deserialize, Serialize};
use std::env;
use crate::error::{AppError, AppResult};
/// Top-level application configuration, assembled from environment
/// variables by [`Config::new`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    pub database: DatabaseConfig,
    pub redis: RedisConfig,
    pub server: ServerConfig,
    pub security: SecurityConfig,
    pub email: EmailConfig,
    pub oidc: OIDCConfig,
    pub api: ApiConfig,
}
/// URLs of the external PinePods services (required at startup:
/// SEARCH_API_URL / PEOPLE_API_URL).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiConfig {
    // Podcast search backend, e.g. https://search.pinepods.online/api/search
    pub search_api_url: String,
    // People/host lookup backend, e.g. https://people.pinepods.online
    pub people_api_url: String,
}
/// SQL database connection settings (from the DB_* environment variables).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    // "postgresql" selects the Postgres URL scheme in `database_url()`;
    // any other value falls through to the mysql:// scheme.
    pub db_type: String,
    pub host: String,
    pub port: u16,
    pub username: String,
    pub password: String,
    pub name: String,
    // Pool bounds; fixed at 32/1 in `Config::new`, not env-configurable.
    pub max_connections: u32,
    pub min_connections: u32,
}
/// Valkey/Redis connection settings, parsed either from a single
/// VALKEY_URL/REDIS_URL or from individual VALKEY_*/REDIS_* variables.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedisConfig {
    pub host: String,
    pub port: u16,
    // Fixed at 32 in `Config::new`.
    pub max_connections: u32,
    pub password: Option<String>,
    pub username: Option<String>,
    // Numbered logical database (the "/N" path segment of a redis:// URL).
    pub database: Option<u8>,
}
/// HTTP listener settings. Hard-coded in `Config::new` to 0.0.0.0:8032
/// (internal API port), not environment-configurable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
    pub port: u16,
    pub host: String,
}
/// Authentication-related constants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SecurityConfig {
    // Name of the HTTP header carrying the API key ("pinepods_api").
    pub api_key_header: String,
    // NOTE(review): currently hard-coded to a default value in
    // `Config::new` rather than read from the environment — confirm
    // whether JWTs are actually signed with this in production.
    pub jwt_secret: String,
    pub password_salt_rounds: u32,
}
/// Optional SMTP settings (SMTP_* / FROM_EMAIL environment variables);
/// all fields are None when email is not configured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailConfig {
    pub smtp_server: Option<String>,
    pub smtp_port: Option<u16>,
    pub smtp_username: Option<String>,
    pub smtp_password: Option<String>,
    pub from_email: Option<String>,
}
/// OpenID Connect single-sign-on settings (OIDC_* environment variables).
/// All fields are optional; `is_configured()`/`validate()` decide whether
/// the combination is usable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCConfig {
    // When true, password login is disabled and OIDC must be fully set up.
    pub disable_standard_login: bool,
    pub provider_name: Option<String>,
    pub client_id: Option<String>,
    pub client_secret: Option<String>,
    pub authorization_url: Option<String>,
    pub token_url: Option<String>,
    pub user_info_url: Option<String>,
    // Login-button cosmetics; defaults applied in `Config::new` only when
    // the essential OIDC variables are present.
    pub button_text: Option<String>,
    pub scope: Option<String>,
    pub button_color: Option<String>,
    pub button_text_color: Option<String>,
    pub icon_svg: Option<String>,
    // Claim-name mappings from the provider's userinfo payload.
    pub name_claim: Option<String>,
    pub email_claim: Option<String>,
    pub username_claim: Option<String>,
    pub roles_claim: Option<String>,
    pub user_role: Option<String>,
    pub admin_role: Option<String>,
}
impl OIDCConfig {
    /// True when every field required to drive the OIDC login flow is
    /// present and non-blank. Claim mappings and the icon are optional
    /// and intentionally not checked here.
    pub fn is_configured(&self) -> bool {
        let required = [
            &self.provider_name,
            &self.client_id,
            &self.client_secret,
            &self.authorization_url,
            &self.token_url,
            &self.user_info_url,
            &self.button_text,
            &self.scope,
            &self.button_color,
            &self.button_text_color,
        ];
        required
            .iter()
            .all(|field| field.as_ref().map_or(false, |s| !s.trim().is_empty()))
    }

    /// Reject half-finished OIDC setups: if *any* required variable is
    /// set, *all* of them must be, and standard login may only be
    /// disabled once OIDC is fully configured.
    pub fn validate(&self) -> Result<(), String> {
        let required_fields = [
            (&self.provider_name, "OIDC_PROVIDER_NAME"),
            (&self.client_id, "OIDC_CLIENT_ID"),
            (&self.client_secret, "OIDC_CLIENT_SECRET"),
            (&self.authorization_url, "OIDC_AUTHORIZATION_URL"),
            (&self.token_url, "OIDC_TOKEN_URL"),
            (&self.user_info_url, "OIDC_USER_INFO_URL"),
            (&self.button_text, "OIDC_BUTTON_TEXT"),
            (&self.scope, "OIDC_SCOPE"),
            (&self.button_color, "OIDC_BUTTON_COLOR"),
            (&self.button_text_color, "OIDC_BUTTON_TEXT_COLOR"),
        ];

        // Single pass: collect what is missing and note whether anything
        // at all was provided.
        let mut missing_fields: Vec<&str> = Vec::new();
        let mut any_oidc_set = false;
        for (field, name) in &required_fields {
            if field.is_some() {
                any_oidc_set = true;
            } else {
                missing_fields.push(name);
            }
        }

        if any_oidc_set && !missing_fields.is_empty() {
            return Err(format!(
                "Incomplete OIDC configuration. When setting up OIDC, all required environment variables must be provided. Missing: {}",
                missing_fields.join(", ")
            ));
        }
        if self.disable_standard_login && !self.is_configured() {
            return Err("OIDC_DISABLE_STANDARD_LOGIN is set to true, but OIDC is not properly configured. All OIDC environment variables must be set when disabling standard login.".to_string());
        }
        Ok(())
    }
}
impl Config {
pub fn new() -> AppResult<Self> {
// Load environment variables
dotenvy::dotenv().ok();
// Validate required database environment variables
let db_required_vars = [
("DB_TYPE", "Database type (e.g., postgresql, mariadb)"),
("DB_HOST", "Database host (e.g., localhost, db)"),
("DB_PORT", "Database port (e.g., 5432 for PostgreSQL, 3306 for MariaDB)"),
("DB_USER", "Database username"),
("DB_PASSWORD", "Database password"),
("DB_NAME", "Database name"),
];
let mut missing_db_vars = Vec::new();
for (var_name, description) in &db_required_vars {
if env::var(var_name).is_err() {
missing_db_vars.push(format!(" {} - {}", var_name, description));
}
}
if !missing_db_vars.is_empty() {
return Err(AppError::Config(format!(
"Missing required database environment variables:\n{}\n\nPlease set these variables in your docker-compose.yml or environment.",
missing_db_vars.join("\n")
)));
}
// Validate required API URLs
let api_required_vars = [
("SEARCH_API_URL", "Search API URL (e.g., https://search.pinepods.online/api/search)"),
("PEOPLE_API_URL", "People API URL (e.g., https://people.pinepods.online)"),
];
let mut missing_api_vars = Vec::new();
for (var_name, description) in &api_required_vars {
if env::var(var_name).is_err() {
missing_api_vars.push(format!(" {} - {}", var_name, description));
}
}
if !missing_api_vars.is_empty() {
return Err(AppError::Config(format!(
"Missing required API environment variables:\n{}\n\nPlease set these variables in your docker-compose.yml or environment.",
missing_api_vars.join("\n")
)));
}
// Validate Valkey/Redis configuration - either URL or individual variables (support both VALKEY_* and REDIS_* naming)
let has_valkey_url = env::var("VALKEY_URL").is_ok();
let has_redis_url = env::var("REDIS_URL").is_ok();
let has_valkey_vars = env::var("VALKEY_HOST").is_ok() && env::var("VALKEY_PORT").is_ok();
let has_redis_vars = env::var("REDIS_HOST").is_ok() && env::var("REDIS_PORT").is_ok();
if !has_valkey_url && !has_redis_url && !has_valkey_vars && !has_redis_vars {
return Err(AppError::Config(format!(
"Missing required Valkey/Redis configuration. Please provide either:\n Option 1: VALKEY_URL or REDIS_URL - Complete connection URL\n Option 2: VALKEY_HOST/VALKEY_PORT or REDIS_HOST/REDIS_PORT - Individual connection parameters\n\nExample URL: VALKEY_URL=redis://localhost:6379\nExample individual: VALKEY_HOST=localhost, VALKEY_PORT=6379"
)));
}
let database = DatabaseConfig {
db_type: env::var("DB_TYPE").unwrap(),
host: env::var("DB_HOST").unwrap(),
port: {
let port_str = env::var("DB_PORT").unwrap();
port_str.trim().parse()
.map_err(|e| AppError::Config(format!("Invalid DB_PORT '{}': Must be a valid port number (e.g., 5432 for PostgreSQL, 3306 for MariaDB)", port_str)))?
},
username: env::var("DB_USER").unwrap(),
password: env::var("DB_PASSWORD").unwrap(),
name: env::var("DB_NAME").unwrap(),
max_connections: 32,
min_connections: 1,
};
let redis = if let Some(url) = env::var("VALKEY_URL").ok().or_else(|| env::var("REDIS_URL").ok()) {
// Parse VALKEY_URL or REDIS_URL
match url::Url::parse(&url) {
Ok(parsed_url) => {
let host = parsed_url.host_str().unwrap_or("localhost").to_string();
let port = parsed_url.port().unwrap_or(6379);
let username = if parsed_url.username().is_empty() {
None
} else {
Some(parsed_url.username().to_string())
};
let password = parsed_url.password().map(|p| p.to_string());
let database = if parsed_url.path().len() > 1 {
parsed_url.path().trim_start_matches('/').parse().ok()
} else {
None
};
RedisConfig {
host,
port,
max_connections: 32,
password,
username,
database,
}
}
Err(e) => {
return Err(AppError::Config(format!("Invalid URL format: {}", e)));
}
}
} else {
// Use individual variables - support both VALKEY_* and REDIS_* (VALKEY_* takes precedence)
let host = env::var("VALKEY_HOST").or_else(|_| env::var("REDIS_HOST")).unwrap();
let port_str = env::var("VALKEY_PORT").or_else(|_| env::var("REDIS_PORT")).unwrap();
let port = port_str.trim().parse()
.map_err(|e| AppError::Config(format!("Invalid port '{}': Must be a valid port number (e.g., 6379)", port_str)))?;
let password = env::var("VALKEY_PASSWORD").ok().or_else(|| env::var("REDIS_PASSWORD").ok());
let username = env::var("VALKEY_USERNAME").ok().or_else(|| env::var("REDIS_USERNAME").ok());
let database = env::var("VALKEY_DATABASE").ok()
.or_else(|| env::var("REDIS_DATABASE").ok())
.and_then(|d| d.parse().ok());
RedisConfig {
host,
port,
max_connections: 32,
password,
username,
database,
}
};
let server = ServerConfig {
port: 8032, // Fixed port for internal API
host: "0.0.0.0".to_string(),
};
let security = SecurityConfig {
api_key_header: "pinepods_api".to_string(),
jwt_secret: "pinepods-default-secret".to_string(),
password_salt_rounds: 12,
};
let email = EmailConfig {
smtp_server: env::var("SMTP_SERVER").ok(),
smtp_port: env::var("SMTP_PORT").ok().and_then(|p| p.parse().ok()),
smtp_username: env::var("SMTP_USERNAME").ok(),
smtp_password: env::var("SMTP_PASSWORD").ok(),
from_email: env::var("FROM_EMAIL").ok(),
};
// Check if essential OIDC fields are present and non-empty before setting any defaults
let oidc_essentials_present = env::var("OIDC_PROVIDER_NAME").map_or(false, |s| !s.trim().is_empty()) &&
env::var("OIDC_CLIENT_ID").map_or(false, |s| !s.trim().is_empty()) &&
env::var("OIDC_CLIENT_SECRET").map_or(false, |s| !s.trim().is_empty()) &&
env::var("OIDC_AUTHORIZATION_URL").map_or(false, |s| !s.trim().is_empty()) &&
env::var("OIDC_TOKEN_URL").map_or(false, |s| !s.trim().is_empty()) &&
env::var("OIDC_USER_INFO_URL").map_or(false, |s| !s.trim().is_empty());
let oidc = OIDCConfig {
disable_standard_login: env::var("OIDC_DISABLE_STANDARD_LOGIN")
.unwrap_or_else(|_| "false".to_string())
.parse()
.unwrap_or(false),
provider_name: env::var("OIDC_PROVIDER_NAME").ok(),
client_id: env::var("OIDC_CLIENT_ID").ok(),
client_secret: env::var("OIDC_CLIENT_SECRET").ok(),
authorization_url: env::var("OIDC_AUTHORIZATION_URL").ok(),
token_url: env::var("OIDC_TOKEN_URL").ok(),
user_info_url: env::var("OIDC_USER_INFO_URL").ok(),
button_text: if oidc_essentials_present {
env::var("OIDC_BUTTON_TEXT").ok().or_else(|| Some("Login with OIDC".to_string()))
} else {
env::var("OIDC_BUTTON_TEXT").ok()
},
scope: if oidc_essentials_present {
env::var("OIDC_SCOPE").ok().or_else(|| Some("openid email profile".to_string()))
} else {
env::var("OIDC_SCOPE").ok()
},
button_color: if oidc_essentials_present {
env::var("OIDC_BUTTON_COLOR").ok().or_else(|| Some("#000000".to_string()))
} else {
env::var("OIDC_BUTTON_COLOR").ok()
},
button_text_color: if oidc_essentials_present {
env::var("OIDC_BUTTON_TEXT_COLOR").ok().or_else(|| Some("#FFFFFF".to_string()))
} else {
env::var("OIDC_BUTTON_TEXT_COLOR").ok()
},
icon_svg: env::var("OIDC_ICON_SVG").ok(),
name_claim: env::var("OIDC_NAME_CLAIM").ok(),
email_claim: env::var("OIDC_EMAIL_CLAIM").ok(),
username_claim: env::var("OIDC_USERNAME_CLAIM").ok(),
roles_claim: env::var("OIDC_ROLES_CLAIM").ok(),
user_role: env::var("OIDC_USER_ROLE").ok(),
admin_role: env::var("OIDC_ADMIN_ROLE").ok(),
};
let api = ApiConfig {
search_api_url: env::var("SEARCH_API_URL").unwrap(),
people_api_url: env::var("PEOPLE_API_URL").unwrap(),
};
// Validate OIDC configuration
if let Err(validation_error) = oidc.validate() {
return Err(AppError::Config(validation_error));
}
Ok(Config {
database,
redis,
server,
security,
email,
oidc,
api,
})
}
pub fn database_url(&self) -> String {
// URL encode username and password to handle special characters
let encoded_username = urlencoding::encode(&self.database.username);
let encoded_password = urlencoding::encode(&self.database.password);
let url = match self.database.db_type.as_str() {
"postgresql" => format!(
"postgresql://{}:{}@{}:{}/{}",
encoded_username,
encoded_password,
self.database.host,
self.database.port,
self.database.name
),
_ => format!(
"mysql://{}:{}@{}:{}/{}",
encoded_username,
encoded_password,
self.database.host,
self.database.port,
self.database.name
),
};
url
}
pub fn redis_url(&self) -> String {
let mut url = String::from("redis://");
// Add authentication if provided
if let (Some(username), Some(password)) = (&self.redis.username, &self.redis.password) {
url.push_str(&format!("{}:{}@",
urlencoding::encode(username),
urlencoding::encode(password)
));
} else if let Some(password) = &self.redis.password {
url.push_str(&format!(":{}@", urlencoding::encode(password)));
}
// Add host and port
url.push_str(&format!("{}:{}", self.redis.host, self.redis.port));
// Add database if specified
if let Some(database) = self.redis.database {
url.push_str(&format!("/{}", database));
}
url
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,140 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use thiserror::Error;
/// Convenience alias used by handlers throughout the API.
pub type AppResult<T> = Result<T, AppError>;

/// Unified application error. Infrastructure errors (database, cache,
/// HTTP, IO, JSON, scheduler) convert automatically via `#[from]`;
/// domain errors carry a human-readable message. Each variant maps to an
/// HTTP status in the `IntoResponse` impl below.
#[derive(Error, Debug)]
pub enum AppError {
    #[error("Database error: {0}")]
    Database(#[from] sqlx::Error),
    #[error("Redis error: {0}")]
    Redis(#[from] redis::RedisError),
    #[error("HTTP client error: {0}")]
    Http(#[from] reqwest::Error),
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
    #[error("Configuration error: {0}")]
    Config(String),
    // 401: the caller could not be authenticated.
    #[error("Authentication error: {0}")]
    Auth(String),
    // 403: authenticated, but not allowed to do this.
    #[error("Authorization error: {0}")]
    Authorization(String),
    #[error("Validation error: {0}")]
    Validation(String),
    #[error("Not found: {0}")]
    NotFound(String),
    #[error("Conflict: {0}")]
    Conflict(String),
    #[error("Bad request: {0}")]
    BadRequest(String),
    #[error("Internal server error: {0}")]
    Internal(String),
    #[error("Scheduler error: {0}")]
    Scheduler(#[from] tokio_cron_scheduler::JobSchedulerError),
    #[error("Service unavailable: {0}")]
    ServiceUnavailable(String),
    #[error("Feed parsing error: {0}")]
    FeedParsing(String),
    #[error("Email sending error: {0}")]
    Email(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match &self {
AppError::Database(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Database error"),
AppError::Redis(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Cache error"),
AppError::Http(_) => (StatusCode::BAD_GATEWAY, "External service error"),
AppError::Io(_) => (StatusCode::INTERNAL_SERVER_ERROR, "IO error"),
AppError::Serialization(_) => (StatusCode::BAD_REQUEST, "Serialization error"),
AppError::Config(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Configuration error"),
AppError::Auth(_) => (StatusCode::UNAUTHORIZED, "Authentication failed"),
AppError::Authorization(_) => (StatusCode::FORBIDDEN, "Authorization failed"),
AppError::Validation(_) => (StatusCode::BAD_REQUEST, "Validation error"),
AppError::NotFound(_) => (StatusCode::NOT_FOUND, "Resource not found"),
AppError::Conflict(_) => (StatusCode::CONFLICT, "Resource conflict"),
AppError::BadRequest(_) => (StatusCode::BAD_REQUEST, "Bad request"),
AppError::Internal(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
AppError::ServiceUnavailable(_) => (StatusCode::SERVICE_UNAVAILABLE, "Service unavailable"),
AppError::FeedParsing(_) => (StatusCode::BAD_REQUEST, "Feed parsing error"),
AppError::Email(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Email error"),
AppError::Scheduler(_) => (StatusCode::INTERNAL_SERVER_ERROR, "Scheduler error"),
};
let body = Json(json!({
"error": error_message,
"message": self.to_string(),
"status_code": status.as_u16(),
}));
// Log the error for debugging (in production, you might want to use structured logging)
tracing::error!("API Error: {} - {}", status.as_u16(), self);
(status, body).into_response()
}
}
// Fold any boxed error into an Internal (HTTP 500) variant, keeping only
// its display text.
impl From<Box<dyn std::error::Error + Send + Sync>> for AppError {
    fn from(err: Box<dyn std::error::Error + Send + Sync>) -> Self {
        AppError::Internal(err.to_string())
    }
}
// Ergonomic constructors so call sites can write
// `AppError::not_found("...")` instead of spelling out the variant.
impl AppError {
    /// 401 — caller could not be authenticated.
    pub fn unauthorized(msg: impl Into<String>) -> Self {
        AppError::Auth(msg.into())
    }
    /// 403 — authenticated but not permitted.
    pub fn forbidden(msg: impl Into<String>) -> Self {
        AppError::Authorization(msg.into())
    }
    /// 404 — resource does not exist (or is not visible to the caller).
    pub fn not_found(msg: impl Into<String>) -> Self {
        AppError::NotFound(msg.into())
    }
    /// 400 — malformed or unacceptable request.
    pub fn bad_request(msg: impl Into<String>) -> Self {
        AppError::BadRequest(msg.into())
    }
    /// 500 — unexpected server-side failure.
    pub fn internal(msg: impl Into<String>) -> Self {
        AppError::Internal(msg.into())
    }
    /// 400 — input failed validation.
    pub fn validation(msg: impl Into<String>) -> Self {
        AppError::Validation(msg.into())
    }
    /// External-service failure; currently mapped to Internal (500).
    pub fn external_error(msg: impl Into<String>) -> Self {
        AppError::Internal(msg.into())
    }
    /// Database failure expressed as a message; mapped to Internal (500).
    pub fn database_error(msg: impl Into<String>) -> Self {
        AppError::Internal(msg.into())
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,470 @@
use axum::{extract::State, http::HeaderMap, response::Json};
use axum::response::{Response, IntoResponse};
use axum::http::{StatusCode, header};
use sqlx::Row;
use crate::{
error::{AppError, AppResult},
handlers::{extract_api_key, validate_api_key},
models::{BulkEpisodeActionRequest, BulkEpisodeActionResponse},
AppState,
};
// Bulk episode action handlers for efficient mass operations
/// Mark a batch of episodes as completed for the calling user.
/// The caller may only act on their own library.
pub async fn bulk_mark_episodes_completed(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<BulkEpisodeActionRequest>,
) -> AppResult<Json<BulkEpisodeActionResponse>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != request.user_id {
        return Err(AppError::forbidden("You can only mark episodes as completed for yourself!"));
    }

    let is_youtube = request.is_youtube.unwrap_or(false);
    let (processed_count, failed_count) = state
        .db_pool
        .bulk_mark_episodes_completed(request.episode_ids, request.user_id, is_youtube)
        .await?;

    // Mention failures only when there were any.
    let message = if failed_count > 0 {
        format!("Marked {} episodes as completed, {} failed", processed_count, failed_count)
    } else {
        format!("Successfully marked {} episodes as completed", processed_count)
    };
    Ok(Json(BulkEpisodeActionResponse {
        message,
        processed_count,
        failed_count: (failed_count > 0).then_some(failed_count),
    }))
}
/// Save a batch of episodes for the calling user. Already-saved episodes
/// count toward the failure tally reported by the database layer.
pub async fn bulk_save_episodes(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<BulkEpisodeActionRequest>,
) -> AppResult<Json<BulkEpisodeActionResponse>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != request.user_id {
        return Err(AppError::forbidden("You can only save episodes for yourself!"));
    }

    let is_youtube = request.is_youtube.unwrap_or(false);
    let (processed_count, failed_count) = state
        .db_pool
        .bulk_save_episodes(request.episode_ids, request.user_id, is_youtube)
        .await?;

    let message = if failed_count > 0 {
        format!("Saved {} episodes, {} failed or already saved", processed_count, failed_count)
    } else {
        format!("Successfully saved {} episodes", processed_count)
    };
    Ok(Json(BulkEpisodeActionResponse {
        message,
        processed_count,
        failed_count: (failed_count > 0).then_some(failed_count),
    }))
}
/// Add a batch of episodes to the calling user's play queue.
pub async fn bulk_queue_episodes(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<BulkEpisodeActionRequest>,
) -> AppResult<Json<BulkEpisodeActionResponse>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != request.user_id {
        return Err(AppError::forbidden("You can only queue episodes for yourself!"));
    }

    let is_youtube = request.is_youtube.unwrap_or(false);
    let (processed_count, failed_count) = state
        .db_pool
        .bulk_queue_episodes(request.episode_ids, request.user_id, is_youtube)
        .await?;

    let message = if failed_count > 0 {
        format!("Queued {} episodes, {} failed or already queued", processed_count, failed_count)
    } else {
        format!("Successfully queued {} episodes", processed_count)
    };
    Ok(Json(BulkEpisodeActionResponse {
        message,
        processed_count,
        failed_count: (failed_count > 0).then_some(failed_count),
    }))
}
/// Spawn background download tasks for a batch of episodes.
/// Episodes already on disk are skipped.
pub async fn bulk_download_episodes(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<BulkEpisodeActionRequest>,
) -> AppResult<Json<BulkEpisodeActionResponse>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != request.user_id {
        return Err(AppError::forbidden("You can only download episodes for yourself!"));
    }

    let is_youtube = request.is_youtube.unwrap_or(false);
    let mut processed_count = 0;
    let mut failed_count = 0;

    // NOTE(review): episodes that are already downloaded are skipped and
    // counted in neither tally, even though the failure message mentions
    // "already downloaded" — confirm whether that is intended.
    for episode_id in request.episode_ids {
        let already_downloaded = state
            .db_pool
            .check_downloaded(request.user_id, episode_id, is_youtube)
            .await?;
        if already_downloaded {
            continue;
        }
        // Dispatch the appropriate background task for the media type.
        let spawn_result = if is_youtube {
            state.task_spawner.spawn_download_youtube_video(episode_id, request.user_id).await
        } else {
            state.task_spawner.spawn_download_podcast_episode(episode_id, request.user_id).await
        };
        if spawn_result.is_ok() {
            processed_count += 1;
        } else {
            failed_count += 1;
        }
    }

    let message = if failed_count > 0 {
        format!("Queued {} episodes for download, {} failed or already downloaded", processed_count, failed_count)
    } else {
        format!("Successfully queued {} episodes for download", processed_count)
    };
    Ok(Json(BulkEpisodeActionResponse {
        message,
        processed_count,
        failed_count: (failed_count > 0).then_some(failed_count),
    }))
}
/// Delete a batch of the calling user's downloaded episodes.
pub async fn bulk_delete_downloaded_episodes(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<BulkEpisodeActionRequest>,
) -> AppResult<Json<BulkEpisodeActionResponse>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != request.user_id {
        return Err(AppError::forbidden("You can only delete your own downloaded episodes!"));
    }

    let is_youtube = request.is_youtube.unwrap_or(false);
    let (processed_count, failed_count) = state
        .db_pool
        .bulk_delete_downloaded_episodes(request.episode_ids, request.user_id, is_youtube)
        .await?;

    let message = if failed_count > 0 {
        format!("Deleted {} downloaded episodes, {} failed or were not found", processed_count, failed_count)
    } else {
        format!("Successfully deleted {} downloaded episodes", processed_count)
    };
    Ok(Json(BulkEpisodeActionResponse {
        message,
        processed_count,
        failed_count: (failed_count > 0).then_some(failed_count),
    }))
}
/// Create a public share link for an episode. A random UUID doubles as
/// the share code; the link is valid for 60 days.
pub async fn share_episode(
    State(state): State<AppState>,
    axum::extract::Path(episode_id): axum::extract::Path<i32>,
    headers: HeaderMap,
) -> AppResult<Json<serde_json::Value>> {
    // Authenticate and resolve the caller's user id.
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;
    let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;

    let share_code = uuid::Uuid::new_v4().to_string();
    let expiration_date = chrono::Utc::now() + chrono::Duration::days(60);

    // Persist the mapping; the DB layer reports success as a bool.
    let inserted = state
        .db_pool
        .add_shared_episode(episode_id, user_id, &share_code, expiration_date)
        .await?;
    if inserted {
        Ok(Json(serde_json::json!({ "url_key": share_code })))
    } else {
        Err(AppError::internal("Failed to share episode"))
    }
}
/// Resolve a share code to episode metadata. No authentication: shared
/// links are public, and lookup intentionally bypasses per-user checks.
pub async fn get_episode_by_url_key(
    State(state): State<AppState>,
    axum::extract::Path(url_key): axum::extract::Path<String>,
) -> AppResult<Json<serde_json::Value>> {
    // Unknown or expired codes produce a 404.
    let episode_id = state
        .db_pool
        .get_episode_id_by_share_code(&url_key)
        .await?
        .ok_or_else(|| AppError::not_found("Invalid or expired URL key"))?;

    let episode_data = state
        .db_pool
        .get_shared_episode_metadata(episode_id)
        .await?;
    Ok(Json(serde_json::json!({ "episode": episode_data })))
}
// Download episode file with metadata
pub async fn download_episode_file(
State(state): State<AppState>,
axum::extract::Path(episode_id): axum::extract::Path<i32>,
headers: HeaderMap,
axum::extract::Query(params): axum::extract::Query<std::collections::HashMap<String, String>>,
) -> AppResult<impl IntoResponse> {
// Try to get API key from header first, then from query parameter
let api_key = if let Ok(key) = extract_api_key(&headers) {
key
} else if let Some(key) = params.get("api_key") {
key.clone()
} else {
return Err(AppError::unauthorized("API key is required"));
};
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
// Get episode metadata
let episode_info = match &state.db_pool {
crate::database::DatabasePool::Postgres(pool) => {
let row = sqlx::query(r#"
SELECT e."episodeurl", e."episodetitle", p."podcastname",
e."episodepubdate", p."author", e."episodeartwork", p."artworkurl",
e."episodedescription"
FROM "Episodes" e
JOIN "Podcasts" p ON e."podcastid" = p."podcastid"
WHERE e."episodeid" = $1
"#)
.bind(episode_id)
.fetch_one(pool)
.await?;
(
row.try_get::<String, _>("episodeurl")?,
row.try_get::<String, _>("episodetitle")?,
row.try_get::<String, _>("podcastname")?,
row.try_get::<Option<chrono::NaiveDateTime>, _>("episodepubdate")?,
row.try_get::<Option<String>, _>("author")?,
row.try_get::<Option<String>, _>("episodeartwork")?,
row.try_get::<Option<String>, _>("artworkurl")?,
row.try_get::<Option<String>, _>("episodedescription")?
)
}
crate::database::DatabasePool::MySQL(pool) => {
let row = sqlx::query("
SELECT e.EpisodeURL, e.EpisodeTitle, p.PodcastName,
e.EpisodePubDate, p.Author, e.EpisodeArtwork, p.ArtworkURL,
e.EpisodeDescription
FROM Episodes e
JOIN Podcasts p ON e.PodcastID = p.PodcastID
WHERE e.EpisodeID = ?
")
.bind(episode_id)
.fetch_one(pool)
.await?;
(
row.try_get::<String, _>("EpisodeURL")?,
row.try_get::<String, _>("EpisodeTitle")?,
row.try_get::<String, _>("PodcastName")?,
row.try_get::<Option<chrono::NaiveDateTime>, _>("EpisodePubDate")?,
row.try_get::<Option<String>, _>("Author")?,
row.try_get::<Option<String>, _>("EpisodeArtwork")?,
row.try_get::<Option<String>, _>("ArtworkURL")?,
row.try_get::<Option<String>, _>("EpisodeDescription")?
)
}
};
let (episode_url, episode_title, podcast_name, pub_date, author, episode_artwork, artwork_url, _description) = episode_info;
// Download the episode file
let client = reqwest::Client::new();
let response = client.get(&episode_url)
.send()
.await
.map_err(|e| AppError::internal(&format!("Failed to download episode: {}", e)))?;
if !response.status().is_success() {
return Err(AppError::internal(&format!("Server returned error: {}", response.status())));
}
let audio_bytes = response.bytes()
.await
.map_err(|e| AppError::internal(&format!("Failed to download audio content: {}", e)))?;
// Create a temporary file for metadata processing
let temp_dir = std::env::temp_dir();
let temp_filename = format!("episode_{}_{}_{}.mp3", episode_id, user_id, chrono::Utc::now().timestamp());
let temp_path = temp_dir.join(&temp_filename);
// Write audio content to temp file
std::fs::write(&temp_path, &audio_bytes)
.map_err(|e| AppError::internal(&format!("Failed to write temp file: {}", e)))?;
// Add metadata using the same function as server downloads
if let Err(e) = add_podcast_metadata(
&temp_path,
&episode_title,
author.as_deref().unwrap_or("Unknown"),
&podcast_name,
pub_date.as_ref(),
episode_artwork.as_deref().or(artwork_url.as_deref())
).await {
tracing::warn!("Failed to add metadata to downloaded episode: {}", e);
}
// Read the file with metadata back
let final_bytes = std::fs::read(&temp_path)
.map_err(|e| AppError::internal(&format!("Failed to read processed file: {}", e)))?;
// Clean up temp file
let _ = std::fs::remove_file(&temp_path);
// Create safe filename for download
let safe_episode_title = episode_title.chars()
.map(|c| if c.is_alphanumeric() || c == ' ' || c == '-' || c == '_' { c } else { '_' })
.collect::<String>()
.trim()
.to_string();
let safe_podcast_name = podcast_name.chars()
.map(|c| if c.is_alphanumeric() || c == ' ' || c == '-' || c == '_' { c } else { '_' })
.collect::<String>()
.trim()
.to_string();
let pub_date_str = if let Some(date) = pub_date {
date.format("%Y-%m-%d").to_string()
} else {
chrono::Utc::now().format("%Y-%m-%d").to_string()
};
let filename = format!("{}_{}_-_{}.mp3", pub_date_str, safe_podcast_name, safe_episode_title);
// Return the file with appropriate headers
let response = Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "audio/mpeg")
.header(header::CONTENT_DISPOSITION, format!("attachment; filename=\"{}\"", filename))
.header(header::CONTENT_LENGTH, final_bytes.len())
.body(axum::body::Body::from(final_bytes))
.map_err(|e| AppError::internal(&format!("Failed to create response: {}", e)))?;
Ok(response)
}
/// Stamp ID3v2.4 metadata (title, artist, album, recording date, genre,
/// and optional cover art) onto the MP3 at `file_path`.
/// Duplicated from tasks.rs so client downloads get the same tags as
/// server-side downloads.
///
/// NOTE(review): `write_to_path` is a blocking file write inside an async
/// fn — consider `spawn_blocking`; left unchanged here.
async fn add_podcast_metadata(
    file_path: &std::path::Path,
    title: &str,
    artist: &str,
    album: &str,
    date: Option<&chrono::NaiveDateTime>,
    artwork_url: Option<&str>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    use id3::TagLike; // Trait providing the tag setter methods.
    use chrono::Datelike; // year()/month()/day() accessors.

    let mut tag = id3::Tag::new();
    tag.set_title(title);
    tag.set_artist(artist);
    tag.set_album(album);

    // Record the publication date when we have one.
    if let Some(d) = date {
        tag.set_date_recorded(id3::Timestamp {
            year: d.year(),
            month: Some(d.month() as u8),
            day: Some(d.day() as u8),
            hour: None,
            minute: None,
            second: None,
        });
    }
    tag.set_genre("Podcast");

    // Cover art is best-effort: fetch failures simply leave it off.
    if let Some(url) = artwork_url {
        if let Ok(artwork_data) = download_artwork(url).await {
            // Sniff the magic bytes; anything unrecognized is treated as JPEG.
            let mime_type = if artwork_data.starts_with(&[0x89, 0x50, 0x4E, 0x47]) {
                "image/png"
            } else {
                "image/jpeg"
            };
            tag.add_frame(id3::frame::Picture {
                mime_type: mime_type.to_string(),
                picture_type: id3::frame::PictureType::CoverFront,
                description: "Cover".to_string(),
                data: artwork_data,
            });
        }
    }

    tag.write_to_path(file_path, id3::Version::Id3v24)?;
    Ok(())
}
// Helper function to download artwork (copied from tasks.rs).
//
// Fetches the image at `url` and returns its raw bytes. Downloads larger than
// MAX_ARTWORK_BYTES (5 MB) are rejected; when the server advertises a
// Content-Length we reject *before* buffering the body, instead of downloading
// the whole payload first as the previous version did.
async fn download_artwork(url: &str) -> Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>> {
    // Keep embedded artwork to a sane size so tagged files don't balloon.
    const MAX_ARTWORK_BYTES: u64 = 5 * 1024 * 1024;

    let client = reqwest::Client::new();
    let response = client
        .get(url)
        .header("User-Agent", "PinePods/1.0")
        .send()
        .await?;

    if !response.status().is_success() {
        return Err(format!("Failed to download artwork: HTTP {}", response.status()).into());
    }

    // Fast path: refuse oversized images before reading the body when the
    // server tells us how big it is.
    if let Some(len) = response.content_length() {
        if len > MAX_ARTWORK_BYTES {
            return Err("Artwork too large".into());
        }
    }

    let bytes = response.bytes().await?;
    // Re-check after download: Content-Length may be absent or wrong.
    if bytes.len() as u64 > MAX_ARTWORK_BYTES {
        return Err("Artwork too large".into());
    }
    Ok(bytes.to_vec())
}

View File

@@ -0,0 +1,110 @@
use axum::{
extract::{Path, Query, State, Request},
response::Response,
};
use serde::Deserialize;
use crate::{
error::AppError,
AppState,
};
// Query parameters accepted by the user RSS feed endpoint.
#[derive(Deserialize)]
pub struct FeedQuery {
    // API key authorizing the feed request (RSS key or regular user key).
    pub api_key: String,
    // Maximum number of episodes to include; the handler defaults to 1000.
    pub limit: Option<i32>,
    // Restrict the feed to a single podcast when set.
    pub podcast_id: Option<i32>,
    // Source filter; appears as `type` in the query string.
    #[serde(rename = "type")]
    pub source_type: Option<String>,
}
// Get RSS feed for user - matches Python get_user_feed function exactly.
//
// Authorizes via a scoped RSS key when possible, otherwise falls back to
// resolving the api key as a plain user key, then renders the RSS document.
pub async fn get_user_feed(
    State(state): State<AppState>,
    Path(user_id): Path<i32>,
    Query(query): Query<FeedQuery>,
    request: Request<axum::body::Body>,
) -> Result<Response<String>, AppError> {
    let api_key = &query.api_key;
    let limit = query.limit.unwrap_or(1000);
    let source_type = query.source_type.as_deref();

    // Base URL used for absolute links inside the generated feed.
    let domain = extract_domain_from_request(&request);

    // The database layer expects an optional list of podcast ids.
    let podcast_id_list = query.podcast_id.map(|id| vec![id]);

    // Prefer a scoped RSS key; fall back to treating the key as a user key.
    let rss_key = match state.db_pool.get_rss_key_if_valid(api_key, podcast_id_list.as_ref()).await? {
        Some(key) => key,
        None => {
            let key_id = state.db_pool.get_user_id_from_api_key(api_key).await?;
            if key_id == 0 {
                return Err(AppError::forbidden("Invalid API key"));
            }
            // Backwards-compatibility key: podcast id -1 means "all podcasts".
            RssKeyInfo {
                podcast_ids: vec![-1],
                user_id: key_id,
                key: api_key.to_string(),
            }
        }
    };

    let feed_content = state.db_pool.generate_podcast_rss(
        rss_key,
        limit,
        source_type,
        &domain,
        podcast_id_list.as_ref(),
    ).await?;

    Ok(Response::builder()
        .header("content-type", "application/rss+xml")
        .body(feed_content)
        .map_err(|e| AppError::internal(&format!("Failed to create response: {}", e)))?)
}
// RSS key metadata used to scope feed generation.
// A `podcast_ids` of `[-1]` is the backwards-compatibility marker meaning
// "all podcasts" (see the fallback path in `get_user_feed`).
#[derive(Debug, Clone)]
pub struct RssKeyInfo {
    // Podcast ids this key may access; [-1] means unrestricted.
    pub podcast_ids: Vec<i32>,
    // Owning user of the key.
    pub user_id: i32,
    // The raw key string.
    pub key: String,
}
// Resolve the externally visible base URL for this request.
//
// Priority: SERVER_URL env var (user-configured scheme/host/port; HOSTNAME
// can't be used because Docker overwrites it with the container id, so the
// startup script copies the original value into SERVER_URL), then the Host
// header combined with X-Forwarded-Proto, then a localhost fallback.
fn extract_domain_from_request(request: &Request<axum::body::Body>) -> String {
    match std::env::var("SERVER_URL") {
        Ok(server_url) => {
            tracing::info!("Using SERVER_URL env var: {}", server_url);
            server_url
        }
        Err(_) => {
            // Reconstruct from headers: Host plus the forwarded scheme
            // (defaulting to http when no proxy header is present).
            let from_host = request
                .headers()
                .get("host")
                .and_then(|h| h.to_str().ok())
                .map(|host_str| {
                    let scheme = request
                        .headers()
                        .get("x-forwarded-proto")
                        .and_then(|h| h.to_str().ok())
                        .unwrap_or("http");
                    format!("{}://{}", scheme, host_str)
                });
            match from_host {
                Some(domain) => {
                    tracing::info!("Using Host header: {}", domain);
                    domain
                }
                None => {
                    tracing::info!("Using fallback domain");
                    "http://localhost:8041".to_string()
                }
            }
        }
    }
}

View File

@@ -0,0 +1,39 @@
use axum::{extract::State, response::Json};
use chrono::Utc;
use crate::{
error::AppResult,
models::{HealthResponse, PinepodsCheckResponse},
AppState,
};
/// PinePods instance check endpoint - matches Python API exactly
/// GET /api/pinepods_check
///
/// Always returns the same payload so clients can detect a PinePods server.
pub async fn pinepods_check() -> Json<PinepodsCheckResponse> {
    let payload = PinepodsCheckResponse {
        status_code: 200,
        pinepods_instance: true,
    };
    Json(payload)
}
/// Health check endpoint with database and Redis status
/// GET /api/health
///
/// Probes both backing services; a probe error counts as unhealthy. The
/// overall status is "healthy" only when every dependency is.
pub async fn health_check(State(state): State<AppState>) -> AppResult<Json<HealthResponse>> {
    let database_healthy = state.db_pool.health_check().await.unwrap_or(false);
    let redis_healthy = state.redis_client.health_check().await.unwrap_or(false);

    let status = match (database_healthy, redis_healthy) {
        (true, true) => "healthy",
        _ => "unhealthy",
    };

    Ok(Json(HealthResponse {
        status: status.to_string(),
        database: database_healthy,
        redis: redis_healthy,
        timestamp: Utc::now(),
    }))
}

View File

@@ -0,0 +1,112 @@
pub mod auth;
pub mod health;
pub mod podcasts;
pub mod episodes;
pub mod playlists;
pub mod websocket;
// pub mod async_tasks_examples; // File was deleted
pub mod refresh;
pub mod proxy;
pub mod settings;
pub mod sync;
pub mod youtube;
pub mod tasks;
pub mod feed;
// Common handler utilities
use axum::{
extract::Query,
http::{HeaderMap, StatusCode},
};
use crate::{
error::{AppError, AppResult},
models::PaginationParams,
AppState,
};
// Extract API key from headers (matches Python API behavior).
//
// Accepts the same header spellings as the Python API, checked in priority
// order; a present-but-non-UTF8 value is treated as missing.
pub fn extract_api_key(headers: &HeaderMap) -> AppResult<String> {
    const CANDIDATES: [&str; 3] = ["Api-Key", "api-key", "X-API-Key"];
    CANDIDATES
        .iter()
        .find_map(|name| headers.get(*name))
        .and_then(|value| value.to_str().ok())
        .map(str::to_string)
        .ok_or_else(|| AppError::unauthorized("Missing API key"))
}
// Validate API key against database/cache.
//
// Checks the Redis cache first; on a miss (or cache error) consults the
// database and best-effort caches the verdict for five minutes.
pub async fn validate_api_key(state: &AppState, api_key: &str) -> AppResult<bool> {
    // Cache hit: return the cached verdict immediately.
    if let Ok(Some(cached)) = state.redis_client.get_cached_api_key_validation(api_key).await {
        return Ok(cached);
    }

    // Authoritative answer comes from the database.
    let verdict = state.db_pool.verify_api_key(api_key).await?;

    // Caching is best-effort; a failure only costs a future DB lookup.
    if let Err(e) = state.redis_client.cache_api_key_validation(api_key, verdict, 300).await {
        tracing::warn!("Failed to cache API key validation: {}", e);
    }
    Ok(verdict)
}
// Check if user has permission (either owns the resource or has web key/admin access).
pub async fn check_user_access(state: &AppState, api_key: &str, target_user_id: i32) -> AppResult<bool> {
    let requester = state.db_pool.get_user_id_from_api_key(api_key).await?;
    // Self-access is always allowed; user id 1 is the elevated web key.
    let allowed = requester == target_user_id || requester == 1;
    Ok(allowed)
}
// Check if user has elevated access (web key - user ID 1).
pub async fn check_web_key_access(state: &AppState, api_key: &str) -> AppResult<bool> {
    // The web/system key always resolves to user id 1.
    state.db_pool
        .get_user_id_from_api_key(api_key)
        .await
        .map(|uid| uid == 1)
}
// Check if user has admin privileges
// Resolves the API key to its owning user and consults the DB admin flag.
pub async fn check_admin_access(state: &AppState, api_key: &str) -> AppResult<bool> {
    let requesting_user_id = state.db_pool.get_user_id_from_api_key(api_key).await?;
    state.db_pool.user_admin_check(requesting_user_id).await
}
// Check if user has permission (either owns the resource, has web key access, or is admin).
pub async fn check_user_or_admin_access(state: &AppState, api_key: &str, target_user_id: i32) -> AppResult<bool> {
    let requester = state.db_pool.get_user_id_from_api_key(api_key).await?;

    // Own data or the elevated web key (user id 1) is always allowed.
    if requester == target_user_id || requester == 1 {
        return Ok(true);
    }

    // Otherwise fall back to an explicit admin check.
    state.db_pool.user_admin_check(requester).await
}
// Extract and validate pagination parameters.
//
// Defaults: page 1, 50 items per page. `page` is forced to be >= 1 and
// `per_page` is clamped into [1, 100] so a client can never request an
// unbounded page size. (`clamp` replaces the previous `.min(100).max(1)`
// chain — identical result, clearer intent.)
pub fn extract_pagination(Query(params): Query<PaginationParams>) -> (i32, i32) {
    let page = params.page.unwrap_or(1).max(1);
    let per_page = params.per_page.unwrap_or(50).clamp(1, 100);
    (page, per_page)
}
// Calculate offset for SQL queries.
// Translates a 1-based page number into a zero-based row offset.
pub fn calculate_offset(page: i32, per_page: i32) -> i32 {
    let zero_based_page = page - 1;
    zero_based_page * per_page
}
// Common response helpers
// Small conveniences for handlers that only need a status plus a static body.
// 200 OK.
pub fn success_response() -> (StatusCode, &'static str) {
    (StatusCode::OK, "success")
}
// 201 Created.
pub fn created_response() -> (StatusCode, &'static str) {
    (StatusCode::CREATED, "created")
}
// 204 No Content (no body).
pub fn no_content_response() -> StatusCode {
    StatusCode::NO_CONTENT
}

View File

@@ -0,0 +1,61 @@
use axum::{extract::State, http::HeaderMap, response::Json};
use crate::{
database,
error::{AppError, AppResult},
handlers::{extract_api_key, validate_api_key},
models::{CreatePlaylistRequest, CreatePlaylistResponse, DeletePlaylistRequest, DeletePlaylistResponse},
AppState,
};
// Create a playlist for a user.
//
// The caller must own the target user id or hold the elevated web key.
pub async fn create_playlist(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(playlist_data): Json<CreatePlaylistRequest>,
) -> AppResult<Json<CreatePlaylistResponse>> {
    // Authenticate the caller.
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }

    // Only the owner (or the web key) may create playlists for a user.
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let has_web_key = state.db_pool.is_web_key(&api_key).await?;
    if caller_id != playlist_data.user_id && !has_web_key {
        return Err(AppError::forbidden("You can only create playlists for yourself!"));
    }

    let playlist_id = database::create_playlist(&state.db_pool, &state.config, &playlist_data).await?;
    Ok(Json(CreatePlaylistResponse {
        detail: "Playlist created successfully".to_string(),
        playlist_id,
    }))
}
// Delete a user's playlist.
//
// The caller must own the target user id or hold the elevated web key.
pub async fn delete_playlist(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(playlist_data): Json<DeletePlaylistRequest>,
) -> AppResult<Json<DeletePlaylistResponse>> {
    // Authenticate the caller.
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }

    // Only the owner (or the web key) may delete a user's playlists.
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let has_web_key = state.db_pool.is_web_key(&api_key).await?;
    if caller_id != playlist_data.user_id && !has_web_key {
        return Err(AppError::forbidden("You can only delete your own playlists!"));
    }

    database::delete_playlist(&state.db_pool, &state.config, &playlist_data).await?;
    Ok(Json(DeletePlaylistResponse {
        detail: "Playlist deleted successfully".to_string(),
    }))
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,78 @@
use axum::{
extract::Query,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use serde::Deserialize;
// Query parameters for the image proxy endpoint.
#[derive(Deserialize)]
pub struct ImageProxyQuery {
    // Absolute http(s) URL of the image to fetch on the client's behalf.
    pub url: String,
}
// Image proxy endpoint - matches Python proxy_image endpoint
pub async fn proxy_image(
Query(query): Query<ImageProxyQuery>,
) -> Result<Response, StatusCode> {
tracing::info!("Image proxy request received for URL: {}", query.url);
if !is_valid_image_url(&query.url) {
tracing::error!("Invalid image URL: {}", query.url);
return Err(StatusCode::BAD_REQUEST);
}
let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::limited(10))
.timeout(std::time::Duration::from_secs(10))
.build()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
tracing::info!("Fetching image from: {}", query.url);
let response = client
.get(&query.url)
.send()
.await
.map_err(|_| StatusCode::BAD_GATEWAY)?;
tracing::info!("Image fetch response status: {}", response.status());
if !response.status().is_success() {
return Err(StatusCode::BAD_GATEWAY);
}
let content_type = response
.headers()
.get("content-type")
.and_then(|ct| ct.to_str().ok())
.unwrap_or("")
.to_string();
tracing::info!("Content type: {}", content_type);
if !content_type.starts_with("image/") && content_type != "application/octet-stream" {
tracing::error!("Invalid content type: {}", content_type);
return Err(StatusCode::BAD_REQUEST);
}
let bytes = response.bytes().await.map_err(|_| StatusCode::BAD_GATEWAY)?;
let mut headers = HeaderMap::new();
headers.insert("content-type", content_type.parse().unwrap());
headers.insert("cache-control", "public, max-age=86400".parse().unwrap());
headers.insert("access-control-allow-origin", "*".parse().unwrap());
headers.insert("x-content-type-options", "nosniff".parse().unwrap());
tracing::info!("Returning image response");
Ok((headers, bytes).into_response())
}
// Accept only absolute http(s) URLs; anything else (other schemes, relative
// strings, malformed input) is rejected.
fn is_valid_image_url(url: &str) -> bool {
    url::Url::parse(url)
        .map(|parsed| parsed.scheme() == "http" || parsed.scheme() == "https")
        .unwrap_or(false)
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,511 @@
use axum::{
extract::{Path, Query, State},
http::HeaderMap,
response::Json,
};
use serde::{Deserialize, Serialize};
use crate::{
error::AppError,
handlers::{extract_api_key, validate_api_key},
AppState,
};
// Request body for toggling internal gPodder sync on or off.
#[derive(Debug, Deserialize)]
pub struct UpdateGpodderSyncRequest {
    // true = enable internal sync, false = disable it.
    pub enabled: bool,
}
// Request body for removing a user's sync configuration.
// NOTE(review): not referenced in this chunk — presumably consumed by a
// handler elsewhere in the file; verify before removing.
#[derive(Debug, Deserialize)]
pub struct RemoveSyncRequest {
    pub user_id: i32,
}
// Set default gPodder device - accepts device name for frontend compatibility.
//
// Bug fix: `validate_api_key` returns Ok(false) for unknown keys; the result
// was previously discarded, letting invalid keys through. It is now checked,
// matching the handlers in playlists.rs/tasks.rs.
pub async fn gpodder_set_default(
    State(state): State<AppState>,
    Path(device_name): Path<String>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }
    let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let success = state.db_pool.gpodder_set_default_device_by_name(user_id, &device_name).await?;
    if success {
        Ok(Json(serde_json::json!({
            "success": true,
            "message": "Default device set successfully",
            "data": null
        })))
    } else {
        Err(AppError::internal("Failed to set default device"))
    }
}
// Get gPodder devices for user - matches Python get_devices function exactly.
//
// Bug fix: the `validate_api_key` verdict was previously discarded; invalid
// keys are now rejected explicitly.
pub async fn gpodder_get_user_devices(
    State(state): State<AppState>,
    Path(user_id): Path<i32>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }
    // Check authorization - web key or own user
    let user_id_from_api_key = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let is_web_key = state.db_pool.is_web_key(&api_key).await?;
    if user_id != user_id_from_api_key && !is_web_key {
        return Err(AppError::forbidden("You can only view your own devices!"));
    }
    let devices = state.db_pool.gpodder_get_user_devices(user_id).await?;
    Ok(Json(serde_json::json!(devices)))
}
// Get all gPodder devices - matches Python get_all_devices function exactly
pub async fn gpodder_get_all_devices(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
let devices = state.db_pool.gpodder_get_user_devices(user_id).await?;
Ok(Json(serde_json::json!(devices)))
}
// Force sync gPodder - performs initial full sync without timestamps (like setup)
pub async fn gpodder_force_sync(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
// Get user's sync settings to determine which sync method to use
let sync_settings = state.db_pool.get_user_sync_settings(user_id).await?;
if sync_settings.is_none() {
return Ok(Json(serde_json::json!({
"success": false,
"message": "No sync configured for this user",
"data": null
})));
}
let settings = sync_settings.unwrap();
let device_name = state.db_pool.get_or_create_default_device(user_id).await?;
// Perform initial full sync (without timestamps) based on sync type
let sync_result = match settings.sync_type.as_str() {
"gpodder" => {
// Internal gPodder API - call initial full sync
state.db_pool.call_gpodder_initial_full_sync(user_id, "http://localhost:8042", &settings.username, &settings.token, &device_name).await
}
"nextcloud" => {
// Nextcloud initial sync
state.db_pool.call_nextcloud_initial_full_sync(user_id, &settings.url, &settings.username, &settings.token).await
}
"external" => {
// External gPodder server - decrypt token first then call initial full sync
let decrypted_token = state.db_pool.decrypt_password(&settings.token).await.unwrap_or_default();
state.db_pool.call_gpodder_initial_full_sync(user_id, &settings.url, &settings.username, &decrypted_token, &device_name).await
}
"both" => {
// Both internal and external - call initial sync for both
let internal_result = state.db_pool.call_gpodder_initial_full_sync(user_id, "http://localhost:8042", &settings.username, &settings.token, &device_name).await;
let decrypted_token = state.db_pool.decrypt_password(&settings.token).await.unwrap_or_default();
let external_result = state.db_pool.call_gpodder_initial_full_sync(user_id, &settings.url, &settings.username, &decrypted_token, &device_name).await;
match (internal_result, external_result) {
(Ok(internal_success), Ok(external_success)) => Ok(internal_success || external_success),
(Ok(internal_success), Err(external_err)) => {
tracing::warn!("External sync failed: {}, but internal sync succeeded: {}", external_err, internal_success);
Ok(internal_success)
}
(Err(internal_err), Ok(external_success)) => {
tracing::warn!("Internal sync failed: {}, but external sync succeeded: {}", internal_err, external_success);
Ok(external_success)
}
(Err(internal_err), Err(external_err)) => {
tracing::error!("Both internal and external sync failed: internal={}, external={}", internal_err, external_err);
Err(internal_err)
}
}
}
_ => Ok(false)
};
let (success, error_message) = match sync_result {
Ok(result) => (result, None),
Err(e) => {
tracing::error!("Sync failed with error: {}", e);
(false, Some(e.to_string()))
}
};
if success {
Ok(Json(serde_json::json!({
"success": true,
"message": "Initial sync completed successfully - all data refreshed",
"data": null
})))
} else {
let message = error_message.unwrap_or_else(|| "Initial sync failed - please check your sync configuration".to_string());
Ok(Json(serde_json::json!({
"success": false,
"message": format!("Initial sync failed: {}", message),
"data": null
})))
}
}
// Regular gPodder sync - performs standard incremental sync with timestamps (like tasks.rs)
pub async fn gpodder_sync(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
// Use the same sync process as the scheduler (tasks.rs) which uses proper API calls with timestamps
let sync_result = state.db_pool.refresh_gpodder_subscription_background(user_id).await?;
if sync_result {
Ok(Json(serde_json::json!({
"success": true,
"message": "Sync completed successfully",
"data": null
})))
} else {
Ok(Json(serde_json::json!({
"success": false,
"message": "Sync failed or no changes detected - check your sync configuration",
"data": null
})))
}
}
// Get gPodder status - matches Python get_gpodder_status function exactly
pub async fn gpodder_status(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
let status = state.db_pool.gpodder_get_status(user_id).await?;
Ok(Json(serde_json::json!({
"sync_type": status.sync_type,
"gpodder_enabled": status.sync_type == "gpodder" || status.sync_type == "both" || status.sync_type == "external",
"external_enabled": status.sync_type == "external" || status.sync_type == "both",
"external_url": status.gpodder_url,
"api_url": "http://localhost:8042"
})))
}
// Toggle gPodder sync - matches Python toggle_gpodder_sync function exactly.
//
// Enabling also kicks off a background subscription refresh (mirrors the
// Python background_tasks.add_task behaviour); the response reflects the
// post-toggle sync state plus the created device info when enabling.
//
// Bug fix: the `validate_api_key` verdict was previously discarded; invalid
// keys are now rejected explicitly.
pub async fn gpodder_toggle(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<UpdateGpodderSyncRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }
    let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    // Get current user status to match Python logic
    let user_status = state.db_pool.gpodder_get_status(user_id).await?;
    let current_sync_type = &user_status.sync_type;
    let mut device_info: Option<serde_json::Value> = None;
    if request.enabled {
        // Enable gpodder sync - call function that matches Python set_gpodder_internal_sync
        if let Ok(result) = state.db_pool.set_gpodder_internal_sync(user_id).await {
            device_info = Some(result);
        } else {
            return Err(AppError::internal("Failed to enable gpodder sync"));
        }
        // Add background task for subscription refresh (matches Python background_tasks.add_task)
        let db_pool = state.db_pool.clone();
        let _task_id = state.task_spawner.spawn_progress_task(
            "gpodder_subscription_refresh".to_string(),
            user_id,
            move |reporter| async move {
                reporter.update_progress(10.0, Some("Starting GPodder subscription refresh...".to_string())).await?;
                let success = db_pool.refresh_gpodder_subscription_background(user_id).await
                    .map_err(|e| AppError::internal(&format!("GPodder sync failed: {}", e)))?;
                if success {
                    reporter.update_progress(100.0, Some("GPodder subscription refresh completed successfully".to_string())).await?;
                    Ok(serde_json::json!({"status": "GPodder subscription refresh completed successfully"}))
                } else {
                    reporter.update_progress(100.0, Some("GPodder subscription refresh completed with no changes".to_string())).await?;
                    Ok(serde_json::json!({"status": "No sync performed"}))
                }
            },
        ).await?;
    } else {
        // Disable gpodder sync - call function that matches Python disable_gpodder_internal_sync
        if !state.db_pool.disable_gpodder_internal_sync(user_id).await? {
            return Err(AppError::internal("Failed to disable gpodder sync"));
        }
    }
    // Get updated status after changes
    let updated_status = state.db_pool.gpodder_get_status(user_id).await?;
    let new_sync_type = &updated_status.sync_type;
    let mut response = serde_json::json!({
        "sync_type": new_sync_type,
        "gpodder_enabled": new_sync_type == "gpodder" || new_sync_type == "both",
        "external_enabled": new_sync_type == "external" || new_sync_type == "both",
        "external_url": if new_sync_type == "external" || new_sync_type == "both" {
            updated_status.gpodder_url
        } else {
            None::<String>
        },
        "api_url": if new_sync_type == "gpodder" || new_sync_type == "both" {
            Some("http://localhost:8042")
        } else {
            None
        }
    });
    // Add device information if available (matches Python logic)
    if let Some(device_data) = device_info {
        if request.enabled {
            if let Some(device_name) = device_data.get("device_name") {
                response["device_name"] = device_name.clone();
            }
            if let Some(device_id) = device_data.get("device_id") {
                response["device_id"] = device_id.clone();
            }
        }
    }
    Ok(Json(response))
}
// gPodder test connection - matches Python test connection functionality.
//
// Verifies the supplied credentials against the gPodder server's
// /api/2/auth/{user}/login.json endpoint and reports the outcome in the JSON
// payload (never as an HTTP error for auth failures).
//
// Bug fix: the `validate_api_key` verdict was previously discarded; invalid
// keys are now rejected explicitly.
//
// SECURITY NOTE(review): gpodder_password arrives as a query parameter, so it
// can end up in access logs and proxies. Kept for Python-API compatibility;
// consider moving credentials into a request body.
pub async fn gpodder_test_connection(
    State(state): State<AppState>,
    Query(params): Query<std::collections::HashMap<String, String>>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }
    let user_id = params.get("user_id")
        .ok_or_else(|| AppError::bad_request("Missing user_id parameter"))?
        .parse::<i32>()
        .map_err(|_| AppError::bad_request("Invalid user_id format"))?;
    let gpodder_url = params.get("gpodder_url")
        .ok_or_else(|| AppError::bad_request("Missing gpodder_url parameter"))?;
    let gpodder_username = params.get("gpodder_username")
        .ok_or_else(|| AppError::bad_request("Missing gpodder_username parameter"))?;
    let gpodder_password = params.get("gpodder_password")
        .ok_or_else(|| AppError::bad_request("Missing gpodder_password parameter"))?;
    // Check authorization
    let user_id_from_api_key = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let is_web_key = state.db_pool.is_web_key(&api_key).await?;
    if user_id != user_id_from_api_key && !is_web_key {
        return Err(AppError::forbidden("You can only test connections for yourself!"));
    }
    // Direct HTTP call to match Python implementation exactly
    let client = reqwest::Client::new();
    let auth_url = format!("{}/api/2/auth/{}/login.json",
        gpodder_url.trim_end_matches('/'),
        gpodder_username);
    // Any transport error is treated as a failed verification.
    let verified = match client
        .post(&auth_url)
        .basic_auth(gpodder_username, Some(gpodder_password))
        .send()
        .await
    {
        Ok(response) => response.status().is_success(),
        Err(_) => false,
    };
    if verified {
        Ok(Json(serde_json::json!({
            "success": true,
            "message": "Successfully connected to GPodder server and verified access.",
            "data": {
                "auth_type": "session",
                "has_devices": true
            }
        })))
    } else {
        Ok(Json(serde_json::json!({
            "success": false,
            "message": "Failed to connect to GPodder server",
            "data": null
        })))
    }
}
// Get default gPodder device - matches Python get_default_device function exactly
pub async fn gpodder_get_default_device(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
let default_device = state.db_pool.gpodder_get_default_device(user_id).await?;
Ok(Json(serde_json::json!(default_device)))
}
// Create gPodder device - matches Python create_device function exactly
// Request body for registering a new gPodder device.
#[derive(serde::Deserialize)]
pub struct CreateDeviceRequest {
    pub device_name: String,
    // Device category (e.g. as used by gPodder clients); passed through to the response.
    pub device_type: String,
    // Optional human-readable label; defaults to device_name when absent.
    pub device_caption: Option<String>,
}
// Create gPodder device - registers the device on the user's gPodder server
// and returns it in GPodder API standard format.
//
// Bug fixes / consistency: the `validate_api_key` verdict was previously
// discarded (invalid keys are now rejected), and error construction now uses
// the `AppError::bad_request` / `AppError::internal` helpers like the rest of
// this file instead of the raw enum variants.
pub async fn gpodder_create_device(
    State(state): State<AppState>,
    headers: HeaderMap,
    Json(request): Json<CreateDeviceRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::unauthorized("Your API key is either invalid or does not have correct permission"));
    }
    let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    // Get user's GPodder sync settings
    let settings = state.db_pool.get_user_sync_settings(user_id).await?
        .ok_or_else(|| AppError::bad_request("User not found or GPodder sync not configured"))?;
    // Device creation only makes sense when some form of gPodder sync is active.
    if !matches!(settings.sync_type.as_str(), "gpodder" | "both" | "external") {
        return Err(AppError::bad_request("GPodder sync is not enabled for this user"));
    }
    // Create device via GPodder API (uses proper auth for internal/external)
    let device_id = state.db_pool.create_device_via_gpodder_api(
        &settings.url,
        &settings.username,
        &settings.token,
        &request.device_name
    ).await.map_err(|e| AppError::internal(&format!("Failed to create device via GPodder API: {}", e)))?;
    // Return GPodder API standard format
    Ok(Json(serde_json::json!({
        "id": device_id, // GPodder device ID (string)
        "name": request.device_name,
        "type": request.device_type,
        "caption": request.device_caption.unwrap_or_else(|| request.device_name.clone()),
        "last_sync": Option::<String>::None,
        "is_active": true,
        "is_remote": true,
        "is_default": false
    })))
}
// GPodder Statistics - real server-side stats from GPodder API
// Aggregate view of the user's GPodder server state, assembled by
// `get_gpodder_server_statistics` in the database layer.
#[derive(Serialize)]
pub struct GpodderStatistics {
    pub server_url: String,
    pub sync_type: String,
    pub sync_enabled: bool,
    // Devices registered on the server.
    pub server_devices: Vec<ServerDevice>,
    pub total_devices: i32,
    // Subscriptions as the server reports them.
    pub server_subscriptions: Vec<ServerSubscription>,
    pub total_subscriptions: i32,
    // Most recent episode actions from the server's action log.
    pub recent_episode_actions: Vec<ServerEpisodeAction>,
    pub total_episode_actions: i32,
    pub connection_status: String,
    pub last_sync_timestamp: Option<String>,
    // Per-endpoint connectivity probe results.
    pub api_endpoints_tested: Vec<EndpointTest>,
}
// A device as reported by the GPodder server.
#[derive(Serialize, Clone)]
pub struct ServerDevice {
    pub id: String,
    pub caption: String,
    pub device_type: String,
    // Number of subscriptions attached to this device.
    pub subscriptions: i32,
}
// A podcast subscription as reported by the GPodder server.
#[derive(Serialize, Clone)]
pub struct ServerSubscription {
    pub url: String,
    pub title: Option<String>,
    pub description: Option<String>,
}
// A single episode action from the server's action log.
#[derive(Serialize, Clone)]
pub struct ServerEpisodeAction {
    pub podcast: String,
    pub episode: String,
    pub action: String,
    pub timestamp: String,
    // Playback position when the action carries one — presumably seconds;
    // TODO confirm against the GPodder episode-action API.
    pub position: Option<i32>,
    pub device: Option<String>,
}
// Result of probing one GPodder API endpoint.
#[derive(Serialize)]
pub struct EndpointTest {
    pub endpoint: String,
    pub status: String, // "success", "failed", "not_tested"
    pub response_time_ms: Option<i64>,
    pub error: Option<String>,
}
pub async fn gpodder_get_statistics(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<GpodderStatistics>, AppError> {
let api_key = extract_api_key(&headers)?;
validate_api_key(&state, &api_key).await?;
let user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
// Check if GPodder is enabled for this user
let gpodder_status = state.db_pool.gpodder_get_status(user_id).await?;
if gpodder_status.sync_type == "None" {
return Ok(Json(GpodderStatistics {
server_url: "No sync configured".to_string(),
sync_type: "None".to_string(),
sync_enabled: false,
server_devices: vec![],
total_devices: 0,
server_subscriptions: vec![],
total_subscriptions: 0,
recent_episode_actions: vec![],
total_episode_actions: 0,
connection_status: "Not configured".to_string(),
last_sync_timestamp: None,
api_endpoints_tested: vec![],
}));
}
// Get real statistics from GPodder server
let statistics = state.db_pool.get_gpodder_server_statistics(user_id).await?;
Ok(Json(statistics))
}

View File

@@ -0,0 +1,397 @@
use axum::{
extract::State,
http::HeaderMap,
response::Json,
};
use serde::Deserialize;
use serde_json;
use crate::{
error::{AppError, AppResult},
handlers::{extract_api_key, validate_api_key},
AppState,
};
// Request body for system task endpoints that authenticate with the
// background_tasks API key in the body rather than a header.
#[derive(Deserialize)]
pub struct InitRequest {
    pub api_key: String,
}
// Startup tasks endpoint - matches Python startup_tasks function exactly.
//
// Seeds data that must exist before the app is usable. Only the
// background_tasks system key (user id 1) may trigger this.
pub async fn startup_tasks(
    State(state): State<AppState>,
    Json(request): Json<InitRequest>,
) -> Result<Json<serde_json::Value>, AppError> {
    // The key must be valid AND belong to user id 1.
    if !validate_api_key(&state, &request.api_key).await? {
        return Err(AppError::forbidden("Invalid or unauthorized API key"));
    }
    if state.db_pool.get_user_id_from_api_key(&request.api_key).await? != 1 {
        return Err(AppError::forbidden("Invalid or unauthorized API key"));
    }

    // Execute the startup tasks.
    state.db_pool.add_news_feed_if_not_added().await?;
    // Create default playlists for any users that might be missing them.
    state.db_pool.create_missing_default_playlists().await?;

    Ok(Json(serde_json::json!({"status": "Startup tasks completed successfully."})))
}
// Cleanup tasks endpoint - matches Python cleanup_tasks function exactly.
//
// Admin-only (web key, user id 1). The heavy lifting runs in a background
// progress task; the handler returns its task id immediately.
pub async fn cleanup_tasks(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;

    // The key must be valid AND belong to user id 1.
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::forbidden("Invalid API key"));
    }
    if state.db_pool.get_user_id_from_api_key(&api_key).await? != 1 {
        return Err(AppError::forbidden("Admin access required"));
    }

    // Fire-and-forget: progress is reported through the task spawner.
    let db_pool = state.db_pool.clone();
    let task_id = state
        .task_spawner
        .spawn_progress_task(
            "cleanup_tasks".to_string(),
            0, // System user
            move |reporter| async move {
                reporter.update_progress(50.0, Some("Running cleanup tasks...".to_string())).await?;
                db_pool
                    .cleanup_old_episodes()
                    .await
                    .map_err(|e| AppError::internal(&format!("Cleanup failed: {}", e)))?;
                reporter.update_progress(100.0, Some("Cleanup completed successfully".to_string())).await?;
                Ok(serde_json::json!({"status": "Cleanup tasks completed successfully"}))
            },
        )
        .await?;

    Ok(Json(serde_json::json!({
        "detail": "Cleanup tasks initiated.",
        "task_id": task_id
    })))
}
// Update playlists endpoint - matches Python update_playlists function exactly
/// Kicks off a refresh of every playlist as a background task and returns
/// its task_id immediately. Restricted to the system key (UserID 1).
pub async fn update_playlists(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::forbidden("Invalid API key"));
    }
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != 1 {
        return Err(AppError::forbidden("Admin access required"));
    }
    // Playlist regeneration runs in the background; progress streams to clients.
    let db_pool = state.db_pool.clone();
    let task_id = state.task_spawner.spawn_progress_task(
        "update_playlists".to_string(),
        0, // System user owns scheduler-style tasks
        move |reporter| async move {
            reporter.update_progress(50.0, Some("Updating all playlists...".to_string())).await?;
            db_pool
                .update_all_playlists()
                .await
                .map_err(|e| AppError::internal(&format!("Playlist update failed: {}", e)))?;
            reporter.update_progress(100.0, Some("Playlist update completed successfully".to_string())).await?;
            Ok(serde_json::json!({"status": "Playlist update completed successfully"}))
        },
    ).await?;
    Ok(Json(serde_json::json!({
        "detail": "Playlist update initiated.",
        "task_id": task_id
    })))
}
// Refresh hosts endpoint - matches Python refresh_all_hosts function exactly
/// Kicks off a background refresh of every followed person/host, then a full
/// podcast refresh, and returns a `task_id` immediately. Restricted to the
/// system key (UserID 1). Per-host failures are counted and logged but do not
/// abort the run.
pub async fn refresh_hosts(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    // Verify it's the system API key (background_tasks user with UserID 1)
    let is_valid = validate_api_key(&state, &api_key).await?;
    if !is_valid {
        return Err(AppError::forbidden("Invalid API key"));
    }
    let api_user_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if api_user_id != 1 {
        return Err(AppError::forbidden("This endpoint requires system API key"));
    }
    // Run host refresh in background
    let db_pool = state.db_pool.clone();
    let task_id = state.task_spawner.spawn_progress_task(
        "refresh_hosts".to_string(),
        0, // System user
        move |reporter| async move {
            reporter.update_progress(10.0, Some("Getting all people/hosts...".to_string())).await?;
            let all_people = db_pool.get_all_people_for_refresh().await
                .map_err(|e| AppError::internal(&format!("Failed to get people: {}", e)))?;
            tracing::info!("Found {} people/hosts to refresh", all_people.len());
            let mut successful_refreshes = 0;
            let mut failed_refreshes = 0;
            for (index, (person_id, person_name, user_id)) in all_people.iter().enumerate() {
                // Progress scales linearly across hosts within the 10%..90% band.
                let progress = 10.0 + (80.0 * (index as f64) / (all_people.len() as f64));
                reporter.update_progress(progress, Some(format!("Refreshing host: {} ({}/{})", person_name, index + 1, all_people.len()))).await?;
                tracing::info!("Starting refresh for host: {} (ID: {}, User: {})", person_name, person_id, user_id);
                // Per-host errors are tallied, not propagated — one bad host
                // must not sink the whole refresh.
                match process_person_refresh(&db_pool, *person_id, person_name, *user_id).await {
                    Ok(_) => {
                        successful_refreshes += 1;
                        tracing::info!("Successfully refreshed host: {}", person_name);
                    }
                    Err(e) => {
                        failed_refreshes += 1;
                        tracing::error!("Failed to refresh host {}: {}", person_name, e);
                    }
                }
            }
            // After processing all people, trigger the regular podcast refresh
            tracing::info!("Person subscription processed, initiating server refresh...");
            match trigger_podcast_refresh(&db_pool).await {
                Ok(_) => {
                    tracing::info!("Server refresh completed successfully");
                }
                Err(e) => {
                    // Best-effort: the host-refresh result is still reported below.
                    tracing::error!("Error during server refresh: {}", e);
                }
            }
            tracing::info!("Host refresh completed: {}/{} successful, {} failed",
                successful_refreshes, all_people.len(), failed_refreshes);
            reporter.update_progress(100.0, Some(format!(
                "Host refresh completed: {}/{} successful",
                successful_refreshes, all_people.len()
            ))).await?;
            Ok(serde_json::json!({
                "success": true,
                "hosts_refreshed": successful_refreshes,
                "hosts_failed": failed_refreshes,
                "total_hosts": all_people.len()
            }))
        },
    ).await?;
    Ok(Json(serde_json::json!({
        "detail": "Host refresh initiated.",
        "task_id": task_id
    })))
}
// Helper function to process individual person refresh - matches Python process_person_subscription
/// Refreshes one followed person's content via the database layer, logging
/// the outcome either way and propagating any error to the caller.
async fn process_person_refresh(
    db_pool: &crate::database::DatabasePool,
    person_id: i32,
    person_name: &str,
    user_id: i32,
) -> AppResult<()> {
    tracing::info!("Processing person subscription for: {} (ID: {}, User: {})", person_name, person_id, user_id);
    // Delegate to the database layer; errors are logged and bubbled up.
    if let Err(e) = db_pool
        .process_person_subscription(user_id, person_id, person_name.to_string())
        .await
    {
        tracing::error!("Error processing person subscription for {}: {}", person_name, e);
        return Err(e);
    }
    tracing::info!("Successfully processed person subscription for {}", person_name);
    Ok(())
}
// Helper function to trigger podcast refresh after person processing - matches Python refresh_pods_task
async fn trigger_podcast_refresh(db_pool: &crate::database::DatabasePool) -> AppResult<()> {
// Get all users with podcasts and refresh them
let all_users = db_pool.get_all_users_with_podcasts().await?;
for user_id in all_users {
match refresh_user_podcasts(db_pool, user_id).await {
Ok((podcast_count, episode_count)) => {
tracing::info!("Successfully refreshed user {}: {} podcasts, {} new episodes",
user_id, podcast_count, episode_count);
}
Err(e) => {
tracing::error!("Failed to refresh user {}: {}", user_id, e);
}
}
}
Ok(())
}
// Helper function to refresh podcasts for a single user
async fn refresh_user_podcasts(db_pool: &crate::database::DatabasePool, user_id: i32) -> AppResult<(i32, i32)> {
let podcasts = db_pool.get_user_podcasts_for_refresh(user_id).await?;
let mut successful_podcasts = 0;
let mut total_new_episodes = 0;
for podcast in podcasts {
match refresh_single_podcast(db_pool, &podcast).await {
Ok(new_episode_count) => {
successful_podcasts += 1;
total_new_episodes += new_episode_count;
tracing::info!("Refreshed podcast '{}': {} new episodes", podcast.name, new_episode_count);
}
Err(e) => {
tracing::error!("Failed to refresh podcast '{}': {}", podcast.name, e);
}
}
}
Ok((successful_podcasts, total_new_episodes))
}
// Auto-complete episodes based on user settings - nightly task
/// Marks episodes complete for every user with auto-complete enabled and
/// reports how many were completed. Restricted to the system key (UserID 1).
pub async fn auto_complete_episodes(
    headers: HeaderMap,
    State(state): State<AppState>,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    if !validate_api_key(&state, &api_key).await? {
        return Err(AppError::forbidden("Invalid or unauthorized API key"));
    }
    // Check if the provided API key is from the background_tasks user (UserID 1)
    let caller_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    if caller_id != 1 {
        return Err(AppError::forbidden("Invalid or unauthorized API key"));
    }
    // Per-user failures count as zero completions; the sweep continues.
    let mut total_completed = 0;
    for user in state.db_pool.get_users_with_auto_complete_enabled().await? {
        total_completed += state
            .db_pool
            .auto_complete_user_episodes(user.user_id, user.auto_complete_seconds)
            .await
            .unwrap_or(0);
    }
    Ok(Json(serde_json::json!({
        "status": "Auto-complete task completed successfully",
        "episodes_completed": total_completed
    })))
}
// Helper function to refresh a single podcast
/// Placeholder: logs the podcast being refreshed and always reports 0 new
/// episodes. The real feed-fetch logic is not implemented yet (see comment
/// below), which is why `refresh_user_podcasts` currently always totals 0.
async fn refresh_single_podcast(
    _db_pool: &crate::database::DatabasePool,
    podcast: &crate::handlers::refresh::PodcastForRefresh,
) -> AppResult<i32> {
    tracing::info!("Refreshing podcast: {} (ID: {})", podcast.name, podcast.id);
    // This would normally refresh the podcast feed and return new episode count
    // For now return 0 as placeholder since we need the podcast refresh system to be implemented
    Ok(0)
}
// Internal functions for scheduler (no HTTP context needed)
/// Scheduler entry point for episode cleanup: same work as the
/// `cleanup_tasks` HTTP handler but runs synchronously with no auth or
/// task-spawner wrapping, and propagates database errors to the scheduler.
pub async fn cleanup_tasks_internal(state: &AppState) -> AppResult<()> {
    tracing::info!("Starting internal cleanup tasks (scheduler)");
    state.db_pool.cleanup_old_episodes().await?;
    tracing::info!("Cleanup tasks completed successfully");
    Ok(())
}
/// Scheduler entry point for playlist regeneration: same work as the
/// `update_playlists` HTTP handler but runs synchronously with no auth or
/// task-spawner wrapping, and propagates database errors to the scheduler.
pub async fn update_playlists_internal(state: &AppState) -> AppResult<()> {
    tracing::info!("Starting internal playlist update (scheduler)");
    state.db_pool.update_all_playlists().await?;
    tracing::info!("Playlist update completed successfully");
    Ok(())
}
/// Scheduler entry point for the host refresh: mirrors the background closure
/// in `refresh_hosts` but without auth or progress reporting. Per-host and
/// server-refresh failures are logged and never abort the sweep.
pub async fn refresh_hosts_internal(state: &AppState) -> AppResult<()> {
    tracing::info!("Starting internal host refresh (scheduler)");
    let all_people = state.db_pool.get_all_people_for_refresh().await?;
    tracing::info!("Found {} people/hosts to refresh", all_people.len());
    let mut successful_refreshes = 0;
    let mut failed_refreshes = 0;
    for &(person_id, ref person_name, user_id) in &all_people {
        tracing::info!("Starting refresh for host: {} (ID: {}, User: {})", person_name, person_id, user_id);
        // Tally outcomes; a failing host never stops the loop.
        if let Err(e) = process_person_refresh(&state.db_pool, person_id, person_name, user_id).await {
            failed_refreshes += 1;
            tracing::error!("Failed to refresh host {}: {}", person_name, e);
        } else {
            successful_refreshes += 1;
            tracing::info!("Successfully refreshed host: {}", person_name);
        }
    }
    // After processing all people, trigger the regular podcast refresh
    tracing::info!("Person subscription processed, initiating server refresh...");
    if let Err(e) = trigger_podcast_refresh(&state.db_pool).await {
        tracing::error!("Error during server refresh: {}", e);
    } else {
        tracing::info!("Server refresh completed successfully");
    }
    tracing::info!("Host refresh completed: {}/{} successful, {} failed",
        successful_refreshes, all_people.len(), failed_refreshes);
    Ok(())
}
/// Scheduler entry point for auto-complete: same sweep as the
/// `auto_complete_episodes` HTTP handler, without auth. Per-user failures
/// count as zero completions and the sweep continues.
pub async fn auto_complete_episodes_internal(state: &AppState) -> AppResult<()> {
    tracing::info!("Starting internal auto-complete episodes (scheduler)");
    let mut total_completed = 0;
    for user in state.db_pool.get_users_with_auto_complete_enabled().await? {
        total_completed += state
            .db_pool
            .auto_complete_user_episodes(user.user_id, user.auto_complete_seconds)
            .await
            .unwrap_or(0);
    }
    tracing::info!("Auto-complete task completed: {} episodes completed", total_completed);
    Ok(())
}

View File

@@ -0,0 +1,221 @@
use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
Path, Query, State,
},
response::Response,
};
use futures::{sink::SinkExt, stream::StreamExt};
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{broadcast, RwLock};
use crate::{
services::task_manager::{TaskUpdate, WebSocketMessage},
AppState,
};
// Per-user registry of broadcast senders, shared across handler tasks.
type UserConnections = Arc<RwLock<HashMap<i32, Vec<broadcast::Sender<TaskUpdate>>>>>;

/// Tracks the live WebSocket senders for each user so task updates can be
/// fanned out to every open connection that user has.
pub struct WebSocketManager {
    connections: UserConnections,
}

impl WebSocketManager {
    /// Creates an empty manager with no registered connections.
    pub fn new() -> Self {
        Self {
            connections: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers a sender for `user_id`, creating the user's entry on first use.
    pub async fn add_connection(&self, user_id: i32, sender: broadcast::Sender<TaskUpdate>) {
        self.connections
            .write()
            .await
            .entry(user_id)
            .or_default()
            .push(sender);
    }

    /// Unregisters a sender for `user_id` (matched by channel identity) and
    /// drops the user's entry entirely once no senders remain.
    pub async fn remove_connection(&self, user_id: i32, sender: &broadcast::Sender<TaskUpdate>) {
        let mut connections = self.connections.write().await;
        if let Some(senders) = connections.get_mut(&user_id) {
            senders.retain(|s| !s.same_channel(sender));
            if senders.is_empty() {
                connections.remove(&user_id);
            }
        }
    }

    /// Sends `update` to every sender registered for `user_id`. Send errors
    /// (no live receivers) are deliberately ignored.
    pub async fn broadcast_to_user(&self, user_id: i32, update: TaskUpdate) {
        let connections = self.connections.read().await;
        if let Some(senders) = connections.get(&user_id) {
            for sender in senders {
                let _ = sender.send(update.clone());
            }
        }
    }
}
use serde::Deserialize;
/// Query-string parameters for the task-progress WebSocket route.
#[derive(Deserialize)]
pub struct WebSocketQuery {
    // Validated in `task_progress_websocket` before the connection is upgraded.
    api_key: String,
}
/// WebSocket upgrade handler for per-user task-progress streams.
///
/// Auth happens *before* the upgrade: the `api_key` query parameter must be
/// valid and must belong either to `user_id` itself or to the system user
/// (ID 1, used by background tasks). Any failure returns a plain HTTP 403
/// response instead of upgrading the connection.
pub async fn task_progress_websocket(
    ws: WebSocketUpgrade,
    Path(user_id): Path<i32>,
    Query(query): Query<WebSocketQuery>,
    State(state): State<AppState>,
) -> Response {
    // Validate API key before upgrading websocket
    match state.db_pool.verify_api_key(&query.api_key).await {
        Ok(true) => {
            // Verify the API key belongs to this user (or system user for background tasks)
            match state.db_pool.get_user_id_from_api_key(&query.api_key).await {
                Ok(key_user_id) => {
                    // Allow access if API key matches the user or if it's the system user (ID 1)
                    if key_user_id == user_id || key_user_id == 1 {
                        ws.on_upgrade(move |socket| handle_task_progress_socket(socket, user_id, state))
                    } else {
                        // Valid key, wrong user: refuse the upgrade.
                        tracing::warn!("WebSocket auth failed: API key user {} tried to access user {} tasks", key_user_id, user_id);
                        axum::response::Response::builder()
                            .status(403)
                            .body("Unauthorized - API key does not belong to requested user".into())
                            .unwrap()
                    }
                }
                Err(e) => {
                    tracing::error!("WebSocket auth error getting user ID from API key: {}", e);
                    axum::response::Response::builder()
                        .status(403)
                        .body("Invalid API key".into())
                        .unwrap()
                }
            }
        }
        // Invalid keys and verification errors get the same response so the
        // caller can't distinguish "bad key" from "lookup failed".
        Ok(false) | Err(_) => {
            tracing::warn!("WebSocket auth failed: Invalid API key");
            axum::response::Response::builder()
                .status(403)
                .body("Invalid API key".into())
                .unwrap()
        }
    }
}
/// Drives one task-progress WebSocket session for `user_id`.
///
/// Pipeline: global task-manager updates -> (filter by user) -> local
/// broadcast channel -> JSON frames to the client. Three tasks run
/// concurrently (forwarder, writer, reader) and the first to finish tears the
/// session down; the sender is then deregistered from the manager.
async fn handle_task_progress_socket(socket: WebSocket, user_id: i32, state: AppState) {
    let (mut sender, mut receiver) = socket.split();
    // Local fan-in channel; capacity 100 — slow clients may drop updates (Lagged).
    let (tx, mut rx) = broadcast::channel::<TaskUpdate>(100);
    // Register with the manager so broadcast_to_user() can reach this socket too.
    state.websocket_manager.add_connection(user_id, tx.clone()).await;
    // Subscribe to task manager updates
    let mut task_receiver = state.task_manager.subscribe_to_progress();
    // Forwarder: copy only this user's updates from the global stream into the
    // local channel; ends when the global stream closes.
    let tx_clone = tx.clone();
    let forward_task = tokio::spawn(async move {
        while let Ok(update) = task_receiver.recv().await {
            if update.user_id == user_id {
                let _ = tx_clone.send(update);
            }
        }
    });
    // Send initial task list to newly connected client (the "initial" event).
    let initial_tasks = state.task_manager.get_user_tasks(user_id).await.unwrap_or_default();
    let initial_message = WebSocketMessage {
        event: "initial".to_string(),
        task: None,
        tasks: Some(initial_tasks),
    };
    let initial_json = match serde_json::to_string(&initial_message) {
        Ok(json) => json,
        // Fall back to an empty JSON object on serialization failure.
        Err(_) => "{}".to_string(),
    };
    let _ = sender.send(Message::Text(initial_json.into())).await;
    // Writer: drain the local channel and push each update as an "update" event.
    let websocket_task = tokio::spawn(async move {
        while let Ok(update) = rx.recv().await {
            // Wrap the update in the WebSocket event format
            let ws_message = WebSocketMessage {
                event: "update".to_string(),
                task: Some(update),
                tasks: None,
            };
            let message = match serde_json::to_string(&ws_message) {
                Ok(json) => Message::Text(json.into()),
                // Skip updates that fail to serialize rather than closing.
                Err(_) => continue,
            };
            // A failed send means the client is gone; stop writing.
            if sender.send(message).await.is_err() {
                break;
            }
        }
    });
    // Reader: consume client frames; Close or a read error ends the session.
    let ping_task = tokio::spawn(async move {
        while let Some(msg) = receiver.next().await {
            match msg {
                Ok(Message::Text(text)) => {
                    // Handle ping/pong or other control messages
                    if text == "ping" {
                        // Connection is alive, no action needed
                    }
                }
                Ok(Message::Close(_)) => break,
                Err(_) => break,
                _ => {}
            }
        }
    });
    // Wait for any task to complete — whichever side ends first wins.
    tokio::select! {
        _ = forward_task => {},
        _ = websocket_task => {},
        _ = ping_task => {},
    }
    // Clean up connection
    state.websocket_manager.remove_connection(user_id, &tx).await;
}
/// Returns every tracked task for `user_id` as JSON.
// NOTE(review): no API-key check is visible here, unlike the WebSocket route —
// confirm auth is enforced by routing/middleware before this handler runs.
pub async fn get_user_tasks(
    Path(user_id): Path<i32>,
    State(state): State<AppState>,
) -> Result<axum::Json<Vec<crate::services::task_manager::TaskInfo>>, crate::error::AppError> {
    Ok(axum::Json(state.task_manager.get_user_tasks(user_id).await?))
}
/// Looks up a single task by its id and returns it as JSON.
// NOTE(review): no API-key check is visible here — confirm auth is enforced
// by routing/middleware before this handler runs.
pub async fn get_task_status(
    Path(task_id): Path<String>,
    State(state): State<AppState>,
) -> Result<axum::Json<crate::services::task_manager::TaskInfo>, crate::error::AppError> {
    Ok(axum::Json(state.task_manager.get_task(&task_id).await?))
}
/// Returns the Pending/Running tasks for the user given by the `user_id`
/// query parameter; a missing or unparseable `user_id` yields an empty list.
pub async fn get_active_tasks(
    Query(params): Query<std::collections::HashMap<String, String>>,
    State(state): State<AppState>,
) -> Result<axum::Json<Vec<crate::services::task_manager::TaskInfo>>, crate::error::AppError> {
    use crate::services::task_manager::TaskStatus;
    // Without a valid user_id, answer with an empty list instead of erroring.
    let Some(user_id) = params.get("user_id").and_then(|id| id.parse::<i32>().ok()) else {
        return Ok(axum::Json(vec![]));
    };
    let mut tasks = state.task_manager.get_user_tasks(user_id).await?;
    // Keep only tasks that have not finished yet.
    tasks.retain(|task| matches!(task.status, TaskStatus::Pending | TaskStatus::Running));
    Ok(axum::Json(tasks))
}

View File

@@ -0,0 +1,619 @@
use axum::{
extract::{Query, State},
http::HeaderMap,
response::Json,
};
use serde::{Deserialize, Serialize};
use tokio::process::Command;
use std::collections::{HashMap, HashSet};
use crate::{
error::AppError,
handlers::{extract_api_key, validate_api_key},
AppState,
};
// Query struct for YouTube channel search
#[derive(Deserialize)]
pub struct YouTubeSearchQuery {
    // Free-text term forwarded to yt-dlp's ytsearch.
    pub query: String,
    // Max channels to return; the handler defaults this to 5.
    pub max_results: Option<i32>,
    // Caller must match the API key owner unless using a web key.
    pub user_id: i32,
}
// YouTube channel struct for search results - matches Python response exactly
#[derive(Serialize, Debug)]
pub struct YouTubeChannel {
    pub channel_id: String,
    pub name: String,
    // Truncated to 500 characters by the search handler.
    pub description: String,
    // Always serialized as null (kept for response-shape parity with Python).
    pub subscriber_count: Option<i64>,
    pub url: String,
    // Always serialized as null (kept for response-shape parity with Python).
    pub video_count: Option<i64>,
    pub thumbnail_url: String,
    // At most 3 videos, collected from the same search results.
    pub recent_videos: Vec<YouTubeVideo>,
}
// YouTube video struct for recent videos in channel - matches Python response exactly
#[derive(Serialize, Debug, Clone)]
pub struct YouTubeVideo {
    pub id: String,
    pub title: String,
    pub duration: Option<f64>, // Note: Python uses float, not i64
    pub url: String,
}
// Request struct for YouTube channel subscription
// NOTE(review): the subscribe handler reads its params from the query string
// (YouTubeSubscribeQuery); confirm whether this body struct is still used.
#[derive(Deserialize)]
pub struct YouTubeSubscribeRequest {
    pub channel_id: String,
    pub user_id: i32,
    pub feed_cutoff: Option<i32>,
}
// Query struct for YouTube subscription endpoint
#[derive(Deserialize)]
pub struct YouTubeSubscribeQuery {
    pub channel_id: String,
    pub user_id: i32,
    // Days of history to keep; the handler defaults this to 30.
    pub feed_cutoff: Option<i32>,
}
// Query struct for check YouTube channel endpoint
#[derive(Deserialize)]
pub struct CheckYouTubeChannelQuery {
    pub user_id: i32,
    pub channel_name: String,
    pub channel_url: String,
}
// Search YouTube channels - matches Python search_youtube_channels function exactly
/// Searches YouTube for channels via the local `yt-dlp` binary and returns up
/// to `max_results` channels, each with up to 3 recent videos gathered from
/// the same search output.
///
/// FIX: the result of `validate_api_key` was previously discarded (`...await?;`),
/// so an invalid-but-resolvable key was never rejected. Every sibling handler
/// checks the returned bool; this one now does too.
pub async fn search_youtube_channels(
    State(state): State<AppState>,
    Query(query): Query<YouTubeSearchQuery>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    let is_valid = validate_api_key(&state, &api_key).await?;
    if !is_valid {
        return Err(AppError::forbidden("Invalid API key"));
    }
    // Check authorization - web key or user can only search for themselves
    let key_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let is_web_key = state.db_pool.is_web_key(&api_key).await?;
    if key_id != query.user_id && !is_web_key {
        return Err(AppError::forbidden("You can only search with your own account."));
    }
    let max_results = query.max_results.unwrap_or(5);
    // Over-fetch (4x) so that after grouping videos by channel we still have
    // enough distinct channels to fill max_results.
    let search_url = format!("ytsearch{}:{}", max_results * 4, query.query);
    println!("Searching YouTube with query: {}", query.query);
    // Use yt-dlp binary to search
    let output = Command::new("yt-dlp")
        .args(&[
            "--quiet",
            "--no-warnings",
            "--flat-playlist",
            "--skip-download",
            "--dump-json",
            &search_url
        ])
        .output()
        .await
        .map_err(|e| AppError::external_error(&format!("Failed to execute yt-dlp: {}", e)))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(AppError::external_error(&format!("yt-dlp search failed: {}", stderr)));
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    // Parse each line as a separate JSON object (yt-dlp outputs one JSON per line for search results)
    let mut entries = Vec::new();
    for line in stdout.lines() {
        if let Ok(entry) = serde_json::from_str::<serde_json::Value>(line) {
            entries.push(entry);
        }
    }
    if entries.is_empty() {
        return Ok(Json(serde_json::json!({"results": []})));
    }
    let mut processed_results = Vec::new();
    let mut seen_channels = HashSet::new();
    let mut channel_videos: HashMap<String, Vec<YouTubeVideo>> = HashMap::new();
    // Pass 1: group videos under their channel (max 3 per channel) — matches Python.
    for entry in &entries {
        if let Some(channel_id) = entry.get("channel_id").and_then(|v| v.as_str())
            .or_else(|| entry.get("uploader_id").and_then(|v| v.as_str())) {
            // First collect the video regardless of whether we've seen the channel
            if !channel_videos.contains_key(channel_id) {
                channel_videos.insert(channel_id.to_string(), Vec::new());
            }
            if let Some(videos) = channel_videos.get_mut(channel_id) {
                if videos.len() < 3 { // Limit to 3 videos like Python
                    if let Some(video_id) = entry.get("id").and_then(|v| v.as_str()) {
                        let video = YouTubeVideo {
                            id: video_id.to_string(),
                            title: entry.get("title").and_then(|v| v.as_str()).unwrap_or("").to_string(),
                            duration: entry.get("duration").and_then(|v| v.as_f64()),
                            url: format!("https://www.youtube.com/watch?v={}", video_id),
                        };
                        videos.push(video);
                        println!("Added video to channel {}, now has {} videos", channel_id, videos.len());
                    }
                }
            }
        }
    }
    // Pass 2: build one result per distinct channel, up to max_results — matches Python.
    for entry in &entries {
        if let Some(channel_id) = entry.get("channel_id").and_then(|v| v.as_str())
            .or_else(|| entry.get("uploader_id").and_then(|v| v.as_str())) {
            // Check if we've already processed this channel
            if seen_channels.contains(channel_id) {
                continue;
            }
            seen_channels.insert(channel_id.to_string());
            // Get minimal channel info
            let channel_url = format!("https://www.youtube.com/channel/{}", channel_id);
            // Get thumbnail from search result - much faster than individual channel lookups
            let thumbnail_url = entry.get("channel_thumbnail").and_then(|v| v.as_str())
                .or_else(|| entry.get("thumbnail").and_then(|v| v.as_str()))
                .unwrap_or("").to_string();
            let channel_name = entry.get("channel").and_then(|v| v.as_str())
                .or_else(|| entry.get("uploader").and_then(|v| v.as_str()))
                .unwrap_or("").to_string();
            println!("Creating channel {} with {} videos", channel_id,
                channel_videos.get(channel_id).map(|v| v.len()).unwrap_or(0));
            let channel = YouTubeChannel {
                channel_id: channel_id.to_string(),
                name: channel_name,
                // Description capped at 500 characters, matching Python.
                description: entry.get("description").and_then(|v| v.as_str())
                    .unwrap_or("").chars().take(500).collect::<String>(),
                subscriber_count: None, // Always null like Python
                url: channel_url,
                video_count: None, // Always null like Python
                thumbnail_url,
                recent_videos: channel_videos.get(channel_id).cloned().unwrap_or_default(),
            };
            if processed_results.len() < max_results as usize {
                processed_results.push(channel);
            } else {
                break;
            }
        }
    }
    println!("Found {} channels", processed_results.len());
    Ok(Json(serde_json::json!({"results": processed_results})))
}
// Subscribe to YouTube channel - matches Python subscribe_to_youtube_channel function exactly
/// Subscribes a user to a YouTube channel. Returns early if the channel is
/// already subscribed; otherwise records it and spawns a detached background
/// task to fetch/process its recent videos.
///
/// FIX: the result of `validate_api_key` was previously discarded (`...await?;`),
/// so an invalid-but-resolvable key was never rejected. Every sibling handler
/// checks the returned bool; this one now does too.
pub async fn subscribe_to_youtube_channel(
    State(state): State<AppState>,
    Query(query): Query<YouTubeSubscribeQuery>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    let is_valid = validate_api_key(&state, &api_key).await?;
    if !is_valid {
        return Err(AppError::forbidden("Invalid API key"));
    }
    // Check authorization - web key or user can only subscribe for themselves
    let key_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let is_web_key = state.db_pool.is_web_key(&api_key).await?;
    if key_id != query.user_id && !is_web_key {
        return Err(AppError::forbidden("You can only subscribe for yourself!"));
    }
    let feed_cutoff = query.feed_cutoff.unwrap_or(30);
    println!("Starting subscription for channel {}", query.channel_id);
    // Check if channel already exists — subscribing twice is a no-op success.
    let existing_id = state.db_pool.check_existing_channel_subscription(
        &query.channel_id,
        query.user_id,
    ).await?;
    if let Some(podcast_id) = existing_id {
        println!("Channel {} already subscribed", query.channel_id);
        return Ok(Json(serde_json::json!({
            "success": true,
            "podcast_id": podcast_id,
            "message": "Already subscribed to this channel"
        })));
    }
    println!("Getting channel info");
    let channel_info = get_youtube_channel_info(&query.channel_id).await?;
    println!("Adding channel to database");
    let podcast_id = state.db_pool.add_youtube_channel(
        &channel_info,
        query.user_id,
        feed_cutoff,
    ).await?;
    // Spawn background task to process YouTube videos; the HTTP response does
    // not wait for video processing, so failures here are only logged.
    let state_clone = state.clone();
    let channel_id_clone = query.channel_id.clone();
    tokio::spawn(async move {
        if let Err(e) = process_youtube_channel(podcast_id, &channel_id_clone, feed_cutoff, &state_clone).await {
            println!("Error processing YouTube channel {}: {}", channel_id_clone, e);
        }
    });
    Ok(Json(serde_json::json!({
        "success": true,
        "podcast_id": podcast_id,
        "message": "Successfully subscribed to YouTube channel"
    })))
}
// Helper function to get YouTube channel info using Backend service
/// Fetches channel metadata (name, description, thumbnail) for `channel_id`
/// from the configured backend service and flattens it into a string map.
/// Missing JSON fields become empty strings; the description is capped at
/// 500 characters.
pub async fn get_youtube_channel_info(channel_id: &str) -> Result<HashMap<String, String>, AppError> {
    println!("Getting channel info for {} from Backend service", channel_id);
    // The backend endpoint is derived from SEARCH_API_URL by swapping its path.
    let search_api_url = std::env::var("SEARCH_API_URL")
        .map_err(|_| AppError::external_error("SEARCH_API_URL environment variable not set"))?;
    // Replace /api/search with /api/youtube/channel for the channel details endpoint
    let backend_url =
        search_api_url.replace("/api/search", &format!("/api/youtube/channel?id={}", channel_id));
    let response = reqwest::Client::new()
        .get(&backend_url)
        .send()
        .await
        .map_err(|e| AppError::external_error(&format!("Failed to call Backend service: {}", e)))?;
    if !response.status().is_success() {
        return Err(AppError::external_error(&format!("Backend service error: {}", response.status())));
    }
    let channel_data: serde_json::Value = response
        .json()
        .await
        .map_err(|e| AppError::external_error(&format!("Failed to parse Backend response: {}", e)))?;
    // Small helper: read a string field from the payload, defaulting to "".
    let text_at = |key: &str| -> String {
        channel_data.get(key).and_then(|v| v.as_str()).unwrap_or("").to_string()
    };
    let mut channel_info = HashMap::new();
    channel_info.insert("channel_id".to_string(), channel_id.to_string());
    channel_info.insert("name".to_string(), text_at("name"));
    // Cap the description at 500 characters, matching the search handler.
    channel_info.insert(
        "description".to_string(),
        text_at("description").chars().take(500).collect::<String>(),
    );
    channel_info.insert("thumbnail_url".to_string(), text_at("thumbnailUrl"));
    println!("Successfully extracted channel info for: {}", channel_info.get("name").unwrap_or(&"Unknown".to_string()));
    Ok(channel_info)
}
// Helper function to get MP3 duration from file
/// Reads the duration (whole seconds, sub-second part truncated by the cast)
/// of an MP3 file on disk. Returns `None` and logs when the file can't be
/// read or parsed.
// NOTE(review): `mp3_metadata` does not appear in this crate's Cargo.toml
// dependency list (only `id3` is declared for audio tagging) — confirm the
// dependency actually exists or this won't compile.
pub fn get_mp3_duration(file_path: &str) -> Option<i32> {
    match mp3_metadata::read_from_file(file_path) {
        Ok(metadata) => Some(metadata.duration.as_secs() as i32),
        Err(e) => {
            println!("Failed to read MP3 metadata from {}: {}", file_path, e);
            None
        }
    }
}
// Helper function to parse YouTube duration format (PT4M13S) to seconds
/// Converts an ISO-8601-style duration of the form `PT#H#M#S` (any component
/// optional) into total seconds. Returns `None` when the string does not
/// start with `PT`; unrecognized unit letters contribute nothing.
pub fn parse_youtube_duration(duration_str: &str) -> Option<i64> {
    // Anything without the "PT" prefix is not a time-only duration we handle.
    let rest = duration_str.strip_prefix("PT")?;
    let mut total_seconds = 0i64;
    let mut digits = String::new();
    for ch in rest.chars() {
        if ch.is_ascii_digit() {
            digits.push(ch);
            continue;
        }
        // A non-digit terminates the current number; apply its unit weight.
        if let Ok(value) = digits.parse::<i64>() {
            total_seconds += value * match ch {
                'H' => 3600,
                'M' => 60,
                'S' => 1,
                _ => 0, // unknown unit letters are ignored, as before
            };
        }
        digits.clear();
    }
    Some(total_seconds)
}
// Process YouTube channel videos using Backend service
pub async fn process_youtube_channel(
podcast_id: i32,
channel_id: &str,
feed_cutoff: i32,
state: &AppState,
) -> Result<(), AppError> {
println!("{}", "=".repeat(50));
println!("Starting YouTube channel processing with Backend service");
println!("Podcast ID: {}", podcast_id);
println!("Channel ID: {}", channel_id);
println!("{}", "=".repeat(50));
let cutoff_date = chrono::Utc::now() - chrono::Duration::days(feed_cutoff as i64);
println!("Cutoff date set to: {}", cutoff_date);
// Clean up old videos
println!("Cleaning up videos older than cutoff date...");
state.db_pool.remove_old_youtube_videos(podcast_id, cutoff_date).await?;
// Get Backend URL from environment variable
let search_api_url = std::env::var("SEARCH_API_URL")
.map_err(|_| AppError::external_error("SEARCH_API_URL environment variable not set"))?;
// Replace /api/search with /api/youtube/channel for the channel details endpoint
let backend_url = search_api_url.replace("/api/search", &format!("/api/youtube/channel?id={}", channel_id));
println!("Fetching channel data from Backend service: {}", backend_url);
// Get video list using Backend service
let client = reqwest::Client::new();
let response = client.get(&backend_url)
.send()
.await
.map_err(|e| AppError::external_error(&format!("Failed to call Backend service: {}", e)))?;
if !response.status().is_success() {
return Err(AppError::external_error(&format!("Backend service error: {}", response.status())));
}
let channel_data: serde_json::Value = response.json()
.await
.map_err(|e| AppError::external_error(&format!("Failed to parse Backend response: {}", e)))?;
let empty_vec = vec![];
let recent_videos_data = channel_data.get("recentVideos")
.and_then(|v| v.as_array())
.unwrap_or(&empty_vec);
println!("Found {} total videos from Backend service", recent_videos_data.len());
let mut recent_videos = Vec::new();
// Process each video from Backend service response
for video_entry in recent_videos_data {
let video_id = video_entry.get("id").and_then(|v| v.as_str()).unwrap_or("");
if video_id.is_empty() {
println!("Skipping video with missing ID");
continue;
}
println!("Processing video ID: {}", video_id);
// Parse the publishedAt date from Backend service
let published_str = video_entry.get("publishedAt").and_then(|v| v.as_str()).unwrap_or("");
let published = chrono::DateTime::parse_from_rfc3339(published_str)
.map(|dt| dt.with_timezone(&chrono::Utc))
.unwrap_or_else(|_| {
println!("Failed to parse date {}, using current time", published_str);
chrono::Utc::now()
});
println!("Video publish date: {}", published);
if published <= cutoff_date {
println!("Video {} from {} is too old, stopping processing", video_id, published);
break;
}
// Debug: print what we got from Backend for this video
println!("Backend video data for {}: {:?}", video_id, video_entry);
let duration_str = video_entry.get("duration").and_then(|v| v.as_str()).unwrap_or("");
println!("Duration string from Backend: '{}'", duration_str);
let parsed_duration = if !duration_str.is_empty() {
parse_youtube_duration(duration_str).unwrap_or(0)
} else {
0
};
println!("Parsed duration: {}", parsed_duration);
let video_data = serde_json::json!({
"id": video_id,
"title": video_entry.get("title").and_then(|v| v.as_str()).unwrap_or(""),
"description": video_entry.get("description").and_then(|v| v.as_str()).unwrap_or(""),
"url": format!("https://www.youtube.com/watch?v={}", video_id),
"thumbnail": video_entry.get("thumbnail").and_then(|v| v.as_str()).unwrap_or(""),
"publish_date": published.to_rfc3339(),
"duration": duration_str // Store as string for proper parsing in database
});
println!("Successfully added video {} to processing queue", video_id);
recent_videos.push(video_data);
}
println!("Processing complete - Found {} recent videos", recent_videos.len());
if !recent_videos.is_empty() {
println!("Starting database updates");
// Get existing videos
let existing_videos = state.db_pool.get_existing_youtube_videos(podcast_id).await?;
// Filter out videos that already exist
let mut new_videos = Vec::new();
for video in &recent_videos {
let video_url = format!("https://www.youtube.com/watch?v={}",
video.get("id").and_then(|v| v.as_str()).unwrap_or(""));
if !existing_videos.contains(&video_url) {
new_videos.push(video.clone());
} else {
println!("Video already exists, skipping: {}",
video.get("title").and_then(|v| v.as_str()).unwrap_or(""));
}
}
if !new_videos.is_empty() {
state.db_pool.add_youtube_videos(podcast_id, &new_videos).await?;
println!("Successfully added {} new videos", new_videos.len());
} else {
println!("No new videos to add");
}
// Download audio for recent videos
println!("Starting audio downloads");
let mut successful_downloads = 0;
let mut failed_downloads = 0;
for video in &recent_videos {
let video_id = video.get("id").and_then(|v| v.as_str()).unwrap_or("");
let title = video.get("title").and_then(|v| v.as_str()).unwrap_or("");
let output_path = format!("/opt/pinepods/downloads/youtube/{}.mp3", video_id);
let output_path_double = format!("{}.mp3", output_path);
println!("Processing download for video: {}", video_id);
println!("Title: {}", title);
println!("Target path: {}", output_path);
// Check if file already exists
if tokio::fs::metadata(&output_path).await.is_ok() ||
tokio::fs::metadata(&output_path_double).await.is_ok() {
println!("Audio file already exists, skipping download");
continue;
}
println!("Starting download...");
match download_youtube_audio(video_id, &output_path).await {
Ok(_) => {
println!("Download completed successfully");
successful_downloads += 1;
// Get duration from the downloaded MP3 file and update database
if let Some(duration) = get_mp3_duration(&output_path) {
if let Err(e) = state.db_pool.update_youtube_video_duration(video_id, duration).await {
println!("Failed to update duration for video {}: {}", video_id, e);
} else {
println!("Updated duration for video {} to {} seconds", video_id, duration);
}
} else {
println!("Could not read duration from MP3 file: {}", output_path);
}
}
Err(e) => {
failed_downloads += 1;
let error_msg = e.to_string();
if error_msg.to_lowercase().contains("members-only") {
println!("Skipping video {} - Members-only content: {}", video_id, title);
} else if error_msg.to_lowercase().contains("private") {
println!("Skipping video {} - Private video: {}", video_id, title);
} else if error_msg.to_lowercase().contains("unavailable") {
println!("Skipping video {} - Unavailable video: {}", video_id, title);
} else {
println!("Failed to download video {}: {}", video_id, title);
println!("Error: {}", error_msg);
}
}
}
}
println!("Download summary: {} successful, {} failed", successful_downloads, failed_downloads);
} else {
println!("No new videos to process");
}
// Update episode count
state.db_pool.update_episode_count(podcast_id).await?;
println!("{}", "=".repeat(50));
println!("Channel processing complete");
println!("{}", "=".repeat(50));
Ok(())
}
// Download YouTube audio using yt-dlp binary
// Download the audio track of a YouTube video as an MP3 by shelling out to
// the `yt-dlp` binary.
//
// `output_path` may be given with or without a trailing ".mp3": yt-dlp's
// audio post-processor appends the extension itself, so any existing ".mp3"
// suffix is trimmed first to avoid producing "<name>.mp3.mp3".
//
// Returns an external error if the process cannot be spawned or exits
// non-zero (stderr is forwarded in the error message).
pub async fn download_youtube_audio(video_id: &str, output_path: &str) -> Result<(), AppError> {
    // Trim a trailing ".mp3" so yt-dlp does not double the extension.
    let base_path = output_path.strip_suffix(".mp3").unwrap_or(output_path);
    let video_url = format!("https://www.youtube.com/watch?v={}", video_id);
    let result = Command::new("yt-dlp")
        .arg("--format").arg("bestaudio/best")
        .arg("--extract-audio")
        .arg("--audio-format").arg("mp3")
        .arg("--output").arg(base_path)
        .arg("--ignore-errors")
        // Avoid hanging forever on a stalled connection.
        .arg("--socket-timeout").arg("30")
        .arg(&video_url)
        .output()
        .await
        .map_err(|e| AppError::external_error(&format!("Failed to execute yt-dlp: {}", e)))?;
    if result.status.success() {
        Ok(())
    } else {
        let stderr = String::from_utf8_lossy(&result.stderr);
        Err(AppError::external_error(&format!("yt-dlp download failed: {}", stderr)))
    }
}
// Check whether a YouTube channel already exists for a user.
// Mirrors the Python `api_check_youtube_channel` endpoint exactly.
//
// Authorization: the API key must belong to the requested user, unless it is
// the privileged web key, which may query on behalf of any user.
pub async fn check_youtube_channel(
    State(state): State<AppState>,
    Query(query): Query<CheckYouTubeChannelQuery>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, AppError> {
    let api_key = extract_api_key(&headers)?;
    validate_api_key(&state, &api_key).await?;

    // Resolve who owns the key and whether it is the web key.
    let requester_id = state.db_pool.get_user_id_from_api_key(&api_key).await?;
    let web_key = state.db_pool.is_web_key(&api_key).await?;
    if !(web_key || requester_id == query.user_id) {
        return Err(AppError::forbidden("You can only check channels for yourself!"));
    }

    let exists = state
        .db_pool
        .check_youtube_channel(query.user_id, &query.channel_name, &query.channel_url)
        .await?;
    Ok(Json(serde_json::json!({ "exists": exists })))
}

View File

@@ -0,0 +1,460 @@
use axum::{
routing::{delete, get, post, put},
Router,
};
use std::net::SocketAddr;
use tokio::signal;
use tower::ServiceBuilder;
use tower_http::{
trace::TraceLayer,
compression::CompressionLayer,
};
use tracing::{info, warn, error};
mod config;
mod database;
mod error;
mod handlers;
mod models;
mod redis_client;
mod redis_manager;
mod services;
use config::Config;
use database::DatabasePool;
use error::AppResult;
use redis_client::RedisClient;
use services::{scheduler::BackgroundScheduler, task_manager::TaskManager, tasks::TaskSpawner};
use handlers::websocket::WebSocketManager;
use redis_manager::{ImportProgressManager, NotificationManager};
use std::sync::Arc;
/// Shared application state handed to every request handler via axum's
/// `State` extractor. All members are handles (pools, clients, `Arc`s), so
/// `Clone` duplicates references, not the underlying resources.
#[derive(Clone)]
pub struct AppState {
    /// SQL connection pool (Postgres or MySQL per the sqlx features in Cargo.toml).
    pub db_pool: DatabasePool,
    /// Redis/Valkey client used for caching and progress/notification state.
    pub redis_client: RedisClient,
    /// Loaded server configuration.
    pub config: Config,
    /// Tracks long-running background tasks.
    pub task_manager: Arc<TaskManager>,
    /// Spawns background work items (downloads, refreshes, etc.).
    pub task_spawner: Arc<TaskSpawner>,
    /// Fan-out hub for WebSocket task/progress updates.
    pub websocket_manager: Arc<WebSocketManager>,
    /// Redis-backed OPML import progress tracking.
    pub import_progress_manager: Arc<ImportProgressManager>,
    /// Redis-backed user notification dispatch.
    pub notification_manager: Arc<NotificationManager>,
}
/// Process entry point. Startup is strictly ordered: tracing → config →
/// database pool → Redis → task infrastructure → router → scheduler →
/// HTTP server with graceful shutdown. Each step depends on the previous
/// one, so the sequence must not be reordered.
#[tokio::main]
async fn main() -> AppResult<()> {
    // Initialize tracing with explicit level if RUST_LOG is not set
    let env_filter = if std::env::var("RUST_LOG").is_ok() {
        tracing_subscriber::EnvFilter::from_default_env()
    } else {
        // Default to "info" so the service is not silent out of the box.
        tracing_subscriber::EnvFilter::new("info")
    };
    tracing_subscriber::fmt()
        .with_env_filter(env_filter)
        .init();
    println!("🚀 Starting PinePods Rust API...");
    info!("Starting PinePods Rust API");
    // Load configuration
    let config = Config::new()?;
    info!("Configuration loaded");
    info!("Database config: host={}, port={}, user={}, db={}, type={}",
        config.database.host, config.database.port, config.database.username,
        config.database.name, config.database.db_type);
    // Initialize database pool
    let db_pool = DatabasePool::new(&config).await?;
    info!("Database pool initialized");
    // Initialize Redis client
    let redis_client = RedisClient::new(&config).await?;
    info!("Redis/Valkey client initialized");
    // Initialize task management
    // TaskSpawner needs the TaskManager, so both share the Redis client handle.
    let task_manager = Arc::new(TaskManager::new(redis_client.clone()));
    let task_spawner = Arc::new(TaskSpawner::new(task_manager.clone(), db_pool.clone()));
    let websocket_manager = Arc::new(WebSocketManager::new());
    let import_progress_manager = Arc::new(ImportProgressManager::new(redis_client.clone()));
    let notification_manager = Arc::new(NotificationManager::new(redis_client.clone()));
    info!("Task management system initialized");
    // Create shared application state
    let app_state = AppState {
        db_pool,
        redis_client,
        config: config.clone(),
        task_manager,
        task_spawner,
        websocket_manager,
        import_progress_manager,
        notification_manager,
    };
    // Build the application with routes
    let app = create_app(app_state.clone());
    // Initialize and start background scheduler
    info!("🕒 Initializing background task scheduler...");
    let scheduler = BackgroundScheduler::new().await?;
    let scheduler_state = Arc::new(app_state.clone());
    // Start the scheduler with background tasks
    scheduler.start(scheduler_state.clone()).await?;
    // Run initial startup tasks immediately
    // Detached task: startup work proceeds concurrently with serving requests,
    // and a failure here is logged rather than aborting the server.
    tokio::spawn({
        let startup_state = scheduler_state.clone();
        async move {
            if let Err(e) = BackgroundScheduler::run_startup_tasks(startup_state).await {
                error!("❌ Startup tasks failed: {}", e);
            }
        }
    });
    // Determine the address to bind to
    // Bind on all interfaces; the port comes from configuration.
    let addr = SocketAddr::from(([0, 0, 0, 0], config.server.port));
    println!("🌐 PinePods Rust API listening on http://{}", addr);
    println!("📊 Health check available at: http://{}/api/health", addr);
    println!("🔍 API check available at: http://{}/api/pinepods_check", addr);
    info!("Server listening on {}", addr);
    // Start the server
    let listener = tokio::net::TcpListener::bind(addr).await?;
    println!("✅ PinePods Rust API server started successfully!");
    // Serve until Ctrl+C / SIGTERM, then drain in-flight requests.
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}
/// Assemble the full application router: health-check endpoints, the nested
/// API sub-routers, and the shared middleware stack (tracing, gzip
/// compression, and a large request-body limit), all bound to `state`.
fn create_app(state: AppState) -> Router {
    // HTTP tracing at INFO level for both span creation and responses.
    let trace = TraceLayer::new_for_http()
        .make_span_with(tower_http::trace::DefaultMakeSpan::new().level(tracing::Level::INFO))
        .on_response(tower_http::trace::DefaultOnResponse::new().level(tracing::Level::INFO));

    // Middleware applies to every route registered before `.layer(...)`.
    let middleware = ServiceBuilder::new()
        .layer(trace)
        .layer(CompressionLayer::new())
        // 2GB body limit so massive backup files can be uploaded.
        .layer(axum::extract::DefaultBodyLimit::max(2 * 1024 * 1024 * 1024));

    Router::new()
        // Health check endpoints
        .route("/api/pinepods_check", get(handlers::health::pinepods_check))
        .route("/api/health", get(handlers::health::health_check))
        // Nested API sub-routers
        .nest("/api/data", create_data_routes())
        .nest("/api/init", create_init_routes())
        .nest("/api/podcasts", create_podcast_routes())
        .nest("/api/episodes", create_episode_routes())
        .nest("/api/playlists", create_playlist_routes())
        .nest("/api/tasks", create_task_routes())
        .nest("/api/async", create_async_routes())
        .nest("/api/proxy", create_proxy_routes())
        .nest("/api/gpodder", create_gpodder_routes())
        .nest("/api/feed", create_feed_routes())
        .nest("/api/auth", create_auth_routes())
        .nest("/ws", create_websocket_routes())
        .layer(middleware)
        .with_state(state)
}
/// Routes mounted under `/api/data` — the primary API surface. Paths and
/// handler pairings intentionally mirror the original Python backend, so
/// route names must not be changed without coordinating with clients.
fn create_data_routes() -> Router<AppState> {
    Router::new()
        // Authentication, login, and first-run setup
        .route("/get_key", get(handlers::auth::get_key))
        .route("/verify_mfa_and_get_key", post(handlers::auth::verify_mfa_and_get_key))
        .route("/verify_key", get(handlers::auth::verify_api_key_endpoint))
        .route("/get_user", get(handlers::auth::get_user))
        .route("/user_details_id/{user_id}", get(handlers::auth::get_user_details_by_id))
        .route("/self_service_status", get(handlers::auth::get_self_service_status))
        .route("/public_oidc_providers", get(handlers::auth::get_public_oidc_providers))
        .route("/create_first", post(handlers::auth::create_first_admin))
        .route("/config", get(handlers::auth::get_config))
        .route("/first_login_done/{user_id}", get(handlers::auth::first_login_done))
        // Per-user display and time preferences
        .route("/get_theme/{user_id}", get(handlers::auth::get_theme))
        .route("/setup_time_info", post(handlers::auth::setup_time_info))
        .route("/update_timezone", put(handlers::auth::update_timezone))
        .route("/update_date_format", put(handlers::auth::update_date_format))
        .route("/update_time_format", put(handlers::auth::update_time_format))
        .route("/get_auto_complete_seconds/{user_id}", get(handlers::auth::get_auto_complete_seconds))
        .route("/update_auto_complete_seconds", put(handlers::auth::update_auto_complete_seconds))
        .route("/user_admin_check/{user_id}", get(handlers::auth::user_admin_check))
        // OPML import
        .route("/import_opml", post(handlers::auth::import_opml))
        .route("/import_progress/{user_id}", get(handlers::auth::import_progress))
        // Episodes, listening history, and playback
        .route("/return_episodes/{user_id}", get(handlers::podcasts::return_episodes))
        .route("/user_history/{user_id}", get(handlers::podcasts::user_history))
        .route("/increment_listen_time/{user_id}", put(handlers::podcasts::increment_listen_time))
        .route("/get_playback_speed", post(handlers::podcasts::get_playback_speed))
        // Podcast subscription management (including merge/unmerge)
        .route("/add_podcast", post(handlers::podcasts::add_podcast))
        .route("/update_podcast_info", put(handlers::podcasts::update_podcast_info))
        .route("/{podcast_id}/merge", post(handlers::podcasts::merge_podcasts))
        .route("/{podcast_id}/unmerge/{target_podcast_id}", post(handlers::podcasts::unmerge_podcast))
        .route("/{podcast_id}/merged", get(handlers::podcasts::get_merged_podcasts))
        .route("/remove_podcast", post(handlers::podcasts::remove_podcast))
        .route("/remove_podcast_id", post(handlers::podcasts::remove_podcast_id))
        .route("/remove_podcast_name", post(handlers::podcasts::remove_podcast_by_name))
        .route("/return_pods/{user_id}", get(handlers::podcasts::return_pods))
        .route("/return_pods_extra/{user_id}", get(handlers::podcasts::return_pods_extra))
        .route("/get_time_info", get(handlers::podcasts::get_time_info))
        .route("/check_podcast", get(handlers::podcasts::check_podcast))
        .route("/check_episode_in_db/{user_id}", get(handlers::podcasts::check_episode_in_db))
        // Queue and saved episodes
        .route("/queue_pod", post(handlers::podcasts::queue_episode))
        .route("/remove_queued_pod", post(handlers::podcasts::remove_queued_episode))
        .route("/get_queued_episodes", get(handlers::podcasts::get_queued_episodes))
        .route("/reorder_queue", post(handlers::podcasts::reorder_queue))
        .route("/save_episode", post(handlers::podcasts::save_episode))
        .route("/remove_saved_episode", post(handlers::podcasts::remove_saved_episode))
        .route("/saved_episode_list/{user_id}", get(handlers::podcasts::get_saved_episodes))
        .route("/record_podcast_history", post(handlers::podcasts::add_history))
        // Downloads
        .route("/get_podcast_id", get(handlers::podcasts::get_podcast_id))
        .route("/download_episode_list", get(handlers::podcasts::download_episode_list))
        .route("/download_podcast", post(handlers::podcasts::download_podcast))
        .route("/delete_episode", post(handlers::podcasts::delete_episode))
        .route("/download_all_podcast", post(handlers::podcasts::download_all_podcast))
        .route("/download_status/{user_id}", get(handlers::podcasts::download_status))
        // Episode metadata and Podcasting 2.0 data
        .route("/podcast_episodes", get(handlers::podcasts::podcast_episodes))
        .route("/get_podcast_id_from_ep_name", get(handlers::podcasts::get_podcast_id_from_ep_name))
        .route("/get_episode_id_ep_name", get(handlers::podcasts::get_episode_id_ep_name))
        .route("/get_episode_metadata", post(handlers::podcasts::get_episode_metadata))
        .route("/fetch_podcasting_2_data", get(handlers::podcasts::fetch_podcasting_2_data))
        .route("/get_auto_download_status", post(handlers::podcasts::get_auto_download_status))
        .route("/get_feed_cutoff_days", get(handlers::podcasts::get_feed_cutoff_days))
        .route("/get_play_episode_details", post(handlers::podcasts::get_play_episode_details))
        .route("/fetch_podcasting_2_pod_data", get(handlers::podcasts::fetch_podcasting_2_pod_data))
        .route("/mark_episode_completed", post(handlers::podcasts::mark_episode_completed))
        .route("/update_episode_duration", post(handlers::podcasts::update_episode_duration))
        // Bulk episode operations
        .route("/bulk_mark_episodes_completed", post(handlers::episodes::bulk_mark_episodes_completed))
        .route("/bulk_save_episodes", post(handlers::episodes::bulk_save_episodes))
        .route("/bulk_queue_episodes", post(handlers::episodes::bulk_queue_episodes))
        .route("/bulk_download_episodes", post(handlers::episodes::bulk_download_episodes))
        .route("/bulk_delete_downloaded_episodes", post(handlers::episodes::bulk_delete_downloaded_episodes))
        // Episode sharing
        .route("/share_episode/{episode_id}", post(handlers::episodes::share_episode))
        .route("/episode_by_url/{url_key}", get(handlers::episodes::get_episode_by_url_key))
        // Stats, search, and home screen
        .route("/increment_played/{user_id}", put(handlers::podcasts::increment_played))
        .route("/record_listen_duration", post(handlers::podcasts::record_listen_duration))
        .route("/get_podcast_id_from_ep_id", get(handlers::podcasts::get_podcast_id_from_ep_id))
        .route("/get_stats", get(handlers::podcasts::get_stats))
        .route("/get_pinepods_version", get(handlers::podcasts::get_pinepods_version))
        .route("/search_data", post(handlers::podcasts::search_data))
        .route("/fetch_transcript", post(handlers::podcasts::fetch_transcript))
        .route("/home_overview", get(handlers::podcasts::home_overview))
        // Playlists
        .route("/get_playlists", get(handlers::podcasts::get_playlists))
        .route("/get_playlist_episodes", get(handlers::podcasts::get_playlist_episodes))
        .route("/create_playlist", post(handlers::playlists::create_playlist))
        .route("/delete_playlist", delete(handlers::playlists::delete_playlist))
        // Podcast details and feeds
        .route("/get_podcast_details", get(handlers::podcasts::get_podcast_details))
        .route("/get_podcast_details_dynamic", get(handlers::podcasts::get_podcast_details_dynamic))
        .route("/podpeople/host_podcasts", get(handlers::podcasts::get_host_podcasts))
        .route("/update_feed_cutoff_days", post(handlers::podcasts::update_feed_cutoff_days))
        .route("/fetch_podcast_feed", get(handlers::podcasts::fetch_podcast_feed))
        .route("/youtube_episodes", get(handlers::podcasts::youtube_episodes))
        .route("/remove_youtube_channel", post(handlers::podcasts::remove_youtube_channel))
        .route("/stream/{episode_id}", get(handlers::podcasts::stream_episode))
        .route("/get_rss_key", get(handlers::podcasts::get_rss_key))
        .route("/mark_episode_uncompleted", post(handlers::podcasts::mark_episode_uncompleted))
        // User account administration
        .route("/user/set_theme", put(handlers::settings::set_theme))
        .route("/get_user_info", get(handlers::settings::get_user_info))
        .route("/my_user_info/{user_id}", get(handlers::settings::get_my_user_info))
        .route("/add_user", post(handlers::settings::add_user))
        .route("/add_login_user", post(handlers::settings::add_login_user))
        .route("/set_fullname/{user_id}", put(handlers::settings::set_fullname))
        .route("/set_password/{user_id}", put(handlers::settings::set_password))
        .route("/user/delete/{user_id}", delete(handlers::settings::delete_user))
        .route("/user/set_email", put(handlers::settings::set_email))
        .route("/user/set_username", put(handlers::settings::set_username))
        .route("/user/set_isadmin", put(handlers::settings::set_isadmin))
        .route("/user/final_admin/{user_id}", get(handlers::settings::final_admin))
        // Server-wide feature toggles
        .route("/enable_disable_guest", post(handlers::settings::enable_disable_guest))
        .route("/enable_disable_downloads", post(handlers::settings::enable_disable_downloads))
        .route("/enable_disable_self_service", post(handlers::settings::enable_disable_self_service))
        .route("/guest_status", get(handlers::settings::guest_status))
        .route("/rss_feed_status", get(handlers::settings::rss_feed_status))
        .route("/toggle_rss_feeds", post(handlers::settings::toggle_rss_feeds))
        .route("/download_status", get(handlers::settings::download_status))
        .route("/admin_self_service_status", get(handlers::settings::self_service_status))
        // Email and password reset
        .route("/save_email_settings", post(handlers::settings::save_email_settings))
        .route("/get_email_settings", get(handlers::settings::get_email_settings))
        .route("/send_test_email", post(handlers::settings::send_test_email))
        .route("/send_email", post(handlers::settings::send_email))
        .route("/reset_password_create_code", post(handlers::auth::reset_password_create_code))
        .route("/verify_and_reset_password", post(handlers::auth::verify_and_reset_password))
        // API keys and backup/restore
        .route("/get_api_info/{user_id}", get(handlers::settings::get_api_info))
        .route("/create_api_key", post(handlers::settings::create_api_key))
        .route("/delete_api_key", delete(handlers::settings::delete_api_key))
        .route("/backup_user", post(handlers::settings::backup_user))
        .route("/backup_server", post(handlers::settings::backup_server))
        .route("/restore_server", post(handlers::settings::restore_server))
        // MFA/TOTP
        .route("/generate_mfa_secret/{user_id}", get(handlers::settings::generate_mfa_secret))
        .route("/verify_temp_mfa", post(handlers::settings::verify_temp_mfa))
        .route("/check_mfa_enabled/{user_id}", get(handlers::settings::check_mfa_enabled))
        .route("/save_mfa_secret", post(handlers::settings::save_mfa_secret))
        .route("/delete_mfa", delete(handlers::settings::delete_mfa))
        // Nextcloud / gpodder sync configuration
        .route("/initiate_nextcloud_login", post(handlers::settings::initiate_nextcloud_login))
        .route("/add_nextcloud_server", post(handlers::settings::add_nextcloud_server))
        .route("/verify_gpodder_auth", post(handlers::settings::verify_gpodder_auth))
        .route("/add_gpodder_server", post(handlers::settings::add_gpodder_server))
        .route("/get_gpodder_settings/{user_id}", get(handlers::settings::get_gpodder_settings))
        .route("/check_gpodder_settings/{user_id}", get(handlers::settings::check_gpodder_settings))
        .route("/remove_podcast_sync", delete(handlers::settings::remove_podcast_sync))
        .route("/gpodder/status", get(handlers::sync::gpodder_status))
        .route("/gpodder/toggle", post(handlers::sync::gpodder_toggle))
        // Admin-triggered refresh and maintenance tasks
        .route("/refresh_pods", get(handlers::refresh::refresh_pods_admin))
        .route("/refresh_gpodder_subscriptions", get(handlers::refresh::refresh_gpodder_subscriptions_admin))
        .route("/refresh_nextcloud_subscriptions", get(handlers::refresh::refresh_nextcloud_subscriptions_admin))
        .route("/refresh_hosts", get(handlers::tasks::refresh_hosts))
        .route("/cleanup_tasks", get(handlers::tasks::cleanup_tasks))
        .route("/auto_complete_episodes", get(handlers::tasks::auto_complete_episodes))
        .route("/update_playlists", get(handlers::tasks::update_playlists))
        .route("/add_custom_podcast", post(handlers::settings::add_custom_podcast))
        // Per-user notification and playback preferences
        .route("/user/notification_settings", get(handlers::settings::get_notification_settings))
        .route("/user/notification_settings", put(handlers::settings::update_notification_settings))
        .route("/user/set_playback_speed", post(handlers::settings::set_playback_speed_user))
        .route("/user/set_global_podcast_cover_preference", post(handlers::settings::set_global_podcast_cover_preference))
        .route("/user/get_podcast_cover_preference", get(handlers::settings::get_global_podcast_cover_preference))
        .route("/user/test_notification", post(handlers::settings::test_notification))
        // OIDC provider administration
        .route("/add_oidc_provider", post(handlers::settings::add_oidc_provider))
        .route("/update_oidc_provider/{provider_id}", put(handlers::settings::update_oidc_provider))
        .route("/list_oidc_providers", get(handlers::settings::list_oidc_providers))
        .route("/remove_oidc_provider", post(handlers::settings::remove_oidc_provider))
        .route("/startpage", get(handlers::settings::get_startpage))
        .route("/startpage", post(handlers::settings::update_startpage))
        // Person (host/guest) subscriptions
        .route("/person/subscribe/{user_id}/{person_id}", post(handlers::settings::subscribe_to_person))
        .route("/person/unsubscribe/{user_id}/{person_id}", delete(handlers::settings::unsubscribe_from_person))
        .route("/person/subscriptions/{user_id}", get(handlers::settings::get_person_subscriptions))
        .route("/person/episodes/{user_id}/{person_id}", get(handlers::settings::get_person_episodes))
        // YouTube channel subscriptions
        .route("/search_youtube_channels", get(handlers::youtube::search_youtube_channels))
        .route("/youtube/subscribe", post(handlers::youtube::subscribe_to_youtube_channel))
        .route("/check_youtube_channel", get(handlers::youtube::check_youtube_channel))
        // Per-podcast settings
        .route("/enable_auto_download", post(handlers::settings::enable_auto_download))
        .route("/adjust_skip_times", post(handlers::settings::adjust_skip_times))
        .route("/remove_category", post(handlers::settings::remove_category))
        .route("/add_category", post(handlers::settings::add_category))
        .route("/podcast/set_playback_speed", post(handlers::settings::set_podcast_playback_speed))
        .route("/podcast/set_cover_preference", post(handlers::settings::set_podcast_cover_preference))
        .route("/podcast/clear_cover_preference", post(handlers::settings::clear_podcast_cover_preference))
        .route("/podcast/toggle_notifications", put(handlers::settings::toggle_podcast_notifications))
        .route("/podcast/notification_status", post(handlers::podcasts::get_notification_status))
        .route("/rss_key", get(handlers::settings::get_user_rss_key))
        .route("/verify_mfa", post(handlers::settings::verify_mfa))
        // Scheduled backups and Podcast Index ID maintenance
        .route("/schedule_backup", post(handlers::settings::schedule_backup))
        .route("/get_scheduled_backup", post(handlers::settings::get_scheduled_backup))
        .route("/list_backup_files", post(handlers::settings::list_backup_files))
        .route("/restore_backup_file", post(handlers::settings::restore_from_backup_file))
        .route("/manual_backup_to_directory", post(handlers::settings::manual_backup_to_directory))
        .route("/get_unmatched_podcasts", post(handlers::settings::get_unmatched_podcasts))
        .route("/update_podcast_index_id", post(handlers::settings::update_podcast_index_id))
        .route("/ignore_podcast_index_id", post(handlers::settings::ignore_podcast_index_id))
        .route("/get_ignored_podcasts", post(handlers::settings::get_ignored_podcasts))
        // Language preference endpoints
        .route("/get_user_language", get(handlers::settings::get_user_language))
        .route("/update_user_language", put(handlers::settings::update_user_language))
        .route("/get_available_languages", get(handlers::settings::get_available_languages))
        .route("/get_server_default_language", get(handlers::settings::get_server_default_language))
        // Add more data routes as needed
}
/// Routes mounted under `/api/podcasts`.
fn create_podcast_routes() -> Router<AppState> {
    Router::new().route(
        "/notification_status",
        post(handlers::podcasts::get_notification_status),
    )
}
/// Routes mounted under `/api/episodes`.
fn create_episode_routes() -> Router<AppState> {
    Router::new().route(
        "/{episode_id}/download",
        get(handlers::episodes::download_episode_file),
    )
}
/// Routes mounted under `/api/playlists`. Currently empty — playlist
/// operations live under `/api/data` for Python-API compatibility.
fn create_playlist_routes() -> Router<AppState> {
    Router::new()
    // Add playlist routes as needed
}
/// Routes mounted under `/api/tasks` for inspecting background tasks.
/// Note: "/user/{user_id}" and "/active" are registered before the
/// "/{task_id}" capture so they are not shadowed by it.
fn create_task_routes() -> Router<AppState> {
    Router::new()
        .route("/user/{user_id}", get(handlers::websocket::get_user_tasks))
        .route("/active", get(handlers::websocket::get_active_tasks))
        .route("/{task_id}", get(handlers::websocket::get_task_status))
}
/// Routes mounted under `/api/async`. All handlers are currently disabled;
/// the commented registrations are kept as a roadmap of planned endpoints.
fn create_async_routes() -> Router<AppState> {
    Router::new()
        // .route("/download_episode", post(handlers::tasks::download_episode))
        // .route("/import_opml", post(handlers::tasks::import_opml))
        // .route("/refresh_feeds", post(handlers::tasks::refresh_all_feeds))
        // .route("/episode/{episode_id}/metadata", get(handlers::tasks::quick_metadata_fetch))
}
/// Routes mounted under `/api/proxy` — currently only artwork image proxying.
fn create_proxy_routes() -> Router<AppState> {
    Router::new()
        .route("/image", get(handlers::proxy::proxy_image))
}
/// Routes mounted under `/api/gpodder` for gpodder-protocol sync.
/// "/devices" is registered twice (GET and POST); axum merges method
/// routers that share a path, so both verbs are served.
fn create_gpodder_routes() -> Router<AppState> {
    Router::new()
        .route("/test-connection", get(handlers::sync::gpodder_test_connection))
        .route("/set_default/{device_id}", post(handlers::sync::gpodder_set_default))
        .route("/devices/{user_id}", get(handlers::sync::gpodder_get_user_devices))
        .route("/devices", get(handlers::sync::gpodder_get_all_devices))
        .route("/default_device", get(handlers::sync::gpodder_get_default_device))
        .route("/devices", post(handlers::sync::gpodder_create_device))
        .route("/sync/force", post(handlers::sync::gpodder_force_sync))
        .route("/sync", post(handlers::sync::gpodder_sync))
        .route("/gpodder_statistics", get(handlers::sync::gpodder_get_statistics))
}
/// Routes mounted under `/api/init` — one-shot startup task trigger.
fn create_init_routes() -> Router<AppState> {
    Router::new()
        .route("/startup_tasks", post(handlers::tasks::startup_tasks))
}
/// Routes mounted under `/api/feed` — per-user generated RSS feed.
fn create_feed_routes() -> Router<AppState> {
    Router::new()
        .route("/{user_id}", get(handlers::feed::get_user_feed))
}
/// WebSocket routes mounted under `/ws` — task progress and episode
/// refresh streams.
fn create_websocket_routes() -> Router<AppState> {
    Router::new()
        .route("/api/tasks/{user_id}", get(handlers::websocket::task_progress_websocket))
        .route("/api/data/episodes/{user_id}", get(handlers::refresh::websocket_refresh_episodes))
}
/// Routes mounted under `/api/auth` — OIDC state storage and callback.
fn create_auth_routes() -> Router<AppState> {
    Router::new()
        .route("/store_state", post(handlers::auth::store_oidc_state))
        .route("/callback", get(handlers::auth::oidc_callback))
}
/// Resolve when the process should shut down: on Ctrl+C (any platform) or
/// SIGTERM (Unix only). Passed to `axum::serve(...).with_graceful_shutdown`
/// so in-flight requests can drain before exit.
async fn shutdown_signal() {
    // Ctrl+C / SIGINT works on every platform tokio supports.
    let interrupt = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };
    // SIGTERM is what container orchestrators send on stop; Unix only.
    #[cfg(unix)]
    let terminate = async {
        let mut sigterm = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler");
        sigterm.recv().await;
    };
    // On non-Unix targets there is no SIGTERM; use a future that never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = interrupt => warn!("Received Ctrl+C, shutting down gracefully"),
        _ = terminate => warn!("Received SIGTERM, shutting down gracefully"),
    }
}

View File

@@ -0,0 +1,528 @@
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
// Response models to match Python API
/// Generic response envelope mirroring the Python API's shape: an HTTP-style
/// status code plus optional message and payload. `message` and `data` are
/// omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiResponse<T> {
    pub status_code: u16,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<T>,
}
/// Convenience constructors for [`ApiResponse`].
impl<T> ApiResponse<T> {
    /// 200 response carrying `data` with no message.
    pub fn success(data: T) -> Self {
        Self {
            status_code: 200,
            message: None,
            data: Some(data),
        }
    }
    /// 200 response carrying both `data` and a human-readable message.
    pub fn success_with_message(data: T, message: String) -> Self {
        Self {
            status_code: 200,
            message: Some(message),
            data: Some(data),
        }
    }
    /// Error response with no payload. Note: this returns `ApiResponse<()>`
    /// regardless of `T`, so callers generally need an explicit type
    /// parameter at the call site (e.g. `ApiResponse::<()>::error(...)`).
    pub fn error(status_code: u16, message: String) -> ApiResponse<()> {
        ApiResponse {
            status_code,
            message: Some(message),
            data: None,
        }
    }
}
// PinePods check response
/// Returned by `/api/pinepods_check` so clients can confirm they are
/// talking to a PinePods server.
#[derive(Debug, Serialize, Deserialize)]
pub struct PinepodsCheckResponse {
    pub status_code: u16,
    pub pinepods_instance: bool,
}
// Health check response
/// Returned by `/api/health`: overall status plus per-dependency reachability.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    pub status: String,
    pub database: bool,
    pub redis: bool,
    pub timestamp: DateTime<Utc>,
}
// Authentication models
/// Username/password credentials for login.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
}
/// Login outcome; `user_id`/`api_key` are present only on success.
#[derive(Debug, Serialize, Deserialize)]
pub struct LoginResponse {
    pub status: String,
    pub user_id: Option<i32>,
    pub api_key: Option<String>,
    pub message: Option<String>,
}
/// Minimal status body for API-key validation endpoints.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiKeyValidationResponse {
    pub status: String,
}
// User models
/// A PinePods user account.
#[derive(Debug, Serialize, Deserialize)]
pub struct User {
    pub user_id: i32,
    pub username: String,
    pub email: Option<String>,
    pub is_admin: bool,
    pub created_at: DateTime<Utc>,
}
/// Per-user preferences for theming and episode auto-management.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserSettings {
    pub user_id: i32,
    pub theme: String,
    pub auto_download_episodes: bool,
    pub auto_delete_episodes: bool,
    pub download_location: Option<String>,
}
// Podcast models
/// A subscribed podcast and its feed metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct Podcast {
    pub podcast_id: i32,
    pub podcast_name: String,
    pub feed_url: String,
    pub artwork_url: Option<String>,
    pub author: Option<String>,
    pub description: Option<String>,
    pub website_url: Option<String>,
    pub explicit: bool,
    pub episode_count: i32,
    /// Category list; stored as a single string here (format defined by the DB layer).
    pub categories: Option<String>,
    pub user_id: i32,
    pub auto_download: bool,
    pub date_created: DateTime<Utc>,
}
/// An episode together with the requesting user's playback state.
#[derive(Debug, Serialize, Deserialize)]
pub struct Episode {
    pub episode_id: i32,
    pub podcast_id: i32,
    pub episode_title: String,
    pub episode_description: Option<String>,
    pub episode_url: Option<String>,
    pub episode_artwork: Option<String>,
    pub episode_pub_date: DateTime<Utc>,
    /// Total length in seconds.
    pub episode_duration: i32,
    pub completed: bool,
    /// Seconds of the episode the user has already listened to.
    pub listen_duration: i32,
    pub downloaded: bool,
    pub saved: bool,
}
// Playlist models
/// A user playlist; `is_system_playlist` marks built-in, non-deletable lists.
#[derive(Debug, Serialize, Deserialize)]
pub struct Playlist {
    pub playlist_id: i32,
    pub user_id: i32,
    pub name: String,
    pub description: Option<String>,
    pub is_system_playlist: bool,
    pub episode_count: i32,
    pub created_at: DateTime<Utc>,
    pub last_updated: DateTime<Utc>,
}
// Request models
/// Body for subscribing to a new podcast by feed URL.
#[derive(Debug, Deserialize)]
pub struct CreatePodcastRequest {
    pub feed_url: String,
    pub auto_download: Option<bool>,
}
/// Partial update of a user's per-episode state; omitted fields are untouched.
#[derive(Debug, Deserialize)]
pub struct UpdateEpisodeRequest {
    pub listen_duration: Option<i32>,
    pub completed: Option<bool>,
    pub saved: Option<bool>,
}
/// Body for creating a smart playlist with filtering/sorting rules.
#[derive(Debug, Deserialize)]
pub struct CreatePlaylistRequest {
    pub user_id: i32,
    pub name: String,
    pub description: Option<String>,
    /// Restrict to these podcasts; `None` means all subscriptions.
    pub podcast_ids: Option<Vec<i32>>,
    pub include_unplayed: bool,
    pub include_partially_played: bool,
    pub include_played: bool,
    pub play_progress_min: Option<f64>,
    pub play_progress_max: Option<f64>,
    pub time_filter_hours: Option<i32>,
    pub min_duration: Option<i32>,
    pub max_duration: Option<i32>,
    pub sort_order: String,
    pub group_by_podcast: bool,
    pub max_episodes: Option<i32>,
    pub icon_name: String,
}
/// Confirmation payload with the id of the newly created playlist.
#[derive(Debug, Serialize)]
pub struct CreatePlaylistResponse {
    pub detail: String,
    pub playlist_id: i32,
}
/// Body for deleting a playlist owned by `user_id`.
#[derive(Debug, Deserialize)]
pub struct DeletePlaylistRequest {
    pub user_id: i32,
    pub playlist_id: i32,
}
/// Confirmation payload for playlist deletion.
#[derive(Debug, Serialize)]
pub struct DeletePlaylistResponse {
    pub detail: String,
}
// Search models
/// Search query with optional scope ("podcasts", "episodes", "all") and paging.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchRequest {
    pub query: String,
    pub search_type: Option<String>, // "podcasts", "episodes", "all"
    pub limit: Option<i32>,
    pub offset: Option<i32>,
}
/// Combined search hits plus the total match count (for pagination).
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    pub podcasts: Vec<Podcast>,
    pub episodes: Vec<Episode>,
    pub total_count: i32,
}
// Statistics models
/// Aggregate listening statistics for one user.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserStats {
    pub total_podcasts: i32,
    pub total_episodes: i32,
    /// Cumulative listen time; unit not fixed here — defined by the stats query. TODO confirm.
    pub total_listen_time: i32,
    pub completed_episodes: i32,
    pub saved_episodes: i32,
    pub downloaded_episodes: i32,
}
// Language models
#[derive(Debug, Serialize, Deserialize)]
pub struct AvailableLanguage {
pub code: String,
pub name: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct LanguageUpdateRequest {
pub user_id: i32,
pub language: String,
}
#[derive(Debug, Serialize)]
pub struct UserLanguageResponse {
pub language: String,
}
#[derive(Debug, Serialize)]
pub struct AvailableLanguagesResponse {
pub languages: Vec<AvailableLanguage>,
}
// API-specific podcast models to match Python responses
// NOTE: field names are deliberately lowercase-concatenated (podcastid, not
// podcast_id) so the serialized JSON matches the legacy Python API.
/// One subscribed podcast as returned by the podcast list endpoints.
#[derive(Debug, Serialize, Deserialize)]
pub struct PodcastResponse {
    pub podcastid: i32,
    pub podcastname: String,
    pub artworkurl: Option<String>,
    pub description: Option<String>,
    pub episodecount: Option<i32>,
    pub websiteurl: Option<String>,
    pub feedurl: String,
    pub author: Option<String>,
    pub categories: Option<std::collections::HashMap<String, String>>,
    pub explicit: bool,
    pub podcastindexid: Option<i64>,
}
/// `PodcastResponse` plus per-user playback stats and YouTube flag.
#[derive(Debug, Serialize, Deserialize)]
pub struct PodcastExtraResponse {
    pub podcastid: i32,
    pub podcastname: String,
    pub artworkurl: Option<String>,
    pub description: Option<String>,
    pub episodecount: Option<i32>,
    pub websiteurl: Option<String>,
    pub feedurl: String,
    pub author: Option<String>,
    pub categories: Option<std::collections::HashMap<String, String>>,
    pub explicit: bool,
    pub podcastindexid: Option<i64>,
    pub play_count: i64,
    pub episodes_played: i32,
    pub oldest_episode_date: Option<String>,
    pub is_youtube: bool,
}
/// Envelope for the plain podcast list endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct PodcastListResponse {
    pub pods: Vec<PodcastResponse>,
}
/// Envelope for the extended podcast list endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct PodcastExtraListResponse {
    pub pods: Vec<PodcastExtraResponse>,
}
// Remove podcast request model
/// Request body for unsubscribing, identified by name + feed URL rather than id.
#[derive(Debug, Deserialize)]
pub struct RemovePodcastByNameRequest {
    pub user_id: i32,
    pub podcast_name: String,
    pub podcast_url: String,
}
// Time info response model
/// A user's timezone and time/date display preferences.
#[derive(Debug, Serialize, Deserialize)]
pub struct TimeInfoResponse {
    pub timezone: String,
    // Presumably 12 vs 24 hour clock preference — TODO confirm against handler.
    pub hour_pref: i32,
    pub date_format: Option<String>,
}
// Check podcast response model
/// Whether a podcast already exists for the user.
#[derive(Debug, Serialize, Deserialize)]
pub struct CheckPodcastResponse {
    pub exists: bool,
}
// Check episode in database response model
/// Whether a given episode is already present in the database.
#[derive(Debug, Serialize, Deserialize)]
pub struct EpisodeInDbResponse {
    pub episode_in_db: bool,
}
// Queue-related models
/// Request body for adding/removing an episode to/from the play queue.
#[derive(Debug, Deserialize)]
pub struct QueuePodcastRequest {
    pub episode_id: i32,
    pub user_id: i32,
    pub is_youtube: bool,
}
// Saved episodes models
/// Request body for saving/unsaving an episode.
#[derive(Debug, Deserialize)]
pub struct SavePodcastRequest {
    pub episode_id: i32,
    pub user_id: i32,
    pub is_youtube: bool,
}
/// One saved episode with full display metadata and per-user state flags.
/// Field names are lowercase-concatenated to match the legacy Python API JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedEpisode {
    pub episodetitle: String,
    pub podcastname: String,
    pub episodepubdate: String,
    pub episodedescription: String,
    pub episodeartwork: String,
    pub episodeurl: String,
    pub episodeduration: i32,
    // `None` when the user has never started this episode.
    pub listenduration: Option<i32>,
    pub episodeid: i32,
    pub websiteurl: String,
    pub completed: bool,
    pub saved: bool,
    pub queued: bool,
    pub downloaded: bool,
    pub is_youtube: bool,
    pub podcastid: Option<i32>,
}
/// Envelope for the saved-episodes listing endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedEpisodesResponse {
    pub saved_episodes: Vec<SavedEpisode>,
}
/// Display metadata for a playlist header.
#[derive(Debug, Serialize, Deserialize)]
pub struct PlaylistInfo {
    pub name: String,
    pub description: String,
    pub episode_count: i32,
    pub icon_name: String,
}
/// Episodes of a playlist together with the playlist's own metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct PlaylistEpisodesResponse {
    pub episodes: Vec<SavedEpisode>,
    pub playlist_info: PlaylistInfo,
}
/// Generic detail message returned by save/unsave endpoints.
#[derive(Debug, Serialize)]
pub struct SaveEpisodeResponse {
    pub detail: String,
}
// History models
/// Request body for recording a playback position in the listen history.
#[derive(Debug, Deserialize)]
pub struct HistoryAddRequest {
    pub episode_id: i32,
    // Playback position; presumably seconds — TODO confirm against handler.
    pub episode_pos: f32,
    pub user_id: i32,
    pub is_youtube: bool,
}
/// One listen-history entry with episode metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct HistoryEpisode {
    pub episodetitle: String,
    pub podcastname: String,
    pub episodepubdate: String,
    pub episodedescription: String,
    pub episodeartwork: String,
    pub episodeurl: String,
    pub episodeduration: i32,
    pub listenduration: Option<i32>,
    pub episodeid: i32,
    pub completed: bool,
    pub listendate: Option<String>,
    pub is_youtube: bool,
}
/// Envelope for the user history endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct UserHistoryResponse {
    pub data: Vec<HistoryEpisode>,
}
/// Generic detail message returned by history endpoints.
#[derive(Debug, Serialize)]
pub struct HistoryResponse {
    pub detail: String,
}
/// Generic message envelope returned by queue mutation endpoints.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueueResponse {
    pub data: String,
}
/// One queued episode with its position in the queue and state flags.
/// Field names are lowercase-concatenated to match the legacy Python API JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueuedEpisode {
    pub episodetitle: String,
    pub podcastname: String,
    pub episodepubdate: String,
    pub episodedescription: String,
    pub episodeartwork: String,
    pub episodeurl: String,
    pub queueposition: Option<i32>,
    pub episodeduration: i32,
    pub queuedate: String,
    pub listenduration: Option<i32>,
    pub episodeid: i32,
    pub completed: bool,
    pub saved: bool,
    pub queued: bool,
    pub downloaded: bool,
    pub is_youtube: bool,
}
/// Envelope for the queue listing endpoint.
#[derive(Debug, Serialize, Deserialize)]
pub struct QueuedEpisodesResponse {
    pub data: Vec<QueuedEpisode>,
}
/// New queue ordering: episode ids in their desired order.
#[derive(Debug, Deserialize)]
pub struct ReorderQueueRequest {
    pub episode_ids: Vec<i32>,
}
/// Result message for a queue reorder.
#[derive(Debug, Serialize)]
pub struct ReorderQueueResponse {
    pub message: String,
}
// Bulk episode action models - flexible episode ID lists
/// Request body for applying one action to many episodes at once.
#[derive(Debug, Deserialize)]
pub struct BulkEpisodeActionRequest {
    pub episode_ids: Vec<i32>,
    pub user_id: i32,
    pub is_youtube: Option<bool>,
}
/// Outcome of a bulk action, with per-episode success/failure counts.
#[derive(Debug, Serialize)]
pub struct BulkEpisodeActionResponse {
    pub message: String,
    pub processed_count: i32,
    pub failed_count: Option<i32>,
}
// Background task models
/// Status snapshot of an asynchronous background task.
#[derive(Debug, Serialize, Deserialize)]
pub struct TaskStatus {
    pub task_id: String,
    pub status: String,
    pub progress: Option<f32>,
    pub message: Option<String>,
    pub created_at: DateTime<Utc>,
}
// Import/Export models
/// Request body carrying raw OPML text to import subscriptions from.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpmlImportRequest {
    pub opml_content: String,
    pub auto_download: Option<bool>,
}
/// Progress snapshot reported while an OPML import is running.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImportProgress {
    pub total_feeds: i32,
    pub processed_feeds: i32,
    pub successful_imports: i32,
    pub failed_imports: i32,
    pub current_feed: Option<String>,
}
// Pagination models
/// Optional pagination query parameters; see `Default` for the fallbacks.
#[derive(Debug, Deserialize)]
pub struct PaginationParams {
    pub page: Option<i32>,
    pub per_page: Option<i32>,
}
impl Default for PaginationParams {
    /// Defaults: first page, 50 items per page.
    fn default() -> Self {
        Self {
            page: Some(1),
            per_page: Some(50),
        }
    }
}
/// Generic paginated envelope returned by list endpoints.
#[derive(Debug, Serialize)]
pub struct PaginatedResponse<T> {
    pub data: Vec<T>,
    pub total_count: i32,
    pub page: i32,
    pub per_page: i32,
    pub total_pages: i32,
}
impl<T> PaginatedResponse<T> {
    /// Build a paginated envelope from an already-sliced page of `data`.
    ///
    /// `total_pages` is computed with ceiling division so a final partial
    /// page still counts as a page. A non-positive `per_page` previously
    /// panicked with a divide-by-zero; since `PaginationParams` values come
    /// from the client, that case now yields `total_pages = 0` instead.
    pub fn new(data: Vec<T>, total_count: i32, page: i32, per_page: i32) -> Self {
        // Ceiling division, guarded against per_page <= 0.
        let total_pages = if per_page > 0 {
            (total_count + per_page - 1) / per_page
        } else {
            0
        };
        Self {
            data,
            total_count,
            page,
            per_page,
            total_pages,
        }
    }
}

View File

@@ -0,0 +1,169 @@
use redis::{aio::MultiplexedConnection, AsyncCommands, Client};
use crate::{config::Config, error::AppResult};
/// Thin async wrapper around a Redis/Valkey multiplexed connection.
///
/// Cloning is cheap: all clones share the same underlying multiplexed
/// connection, and each operation clones the handle locally because the
/// redis commands need `&mut` access.
#[derive(Clone)]
pub struct RedisClient {
    connection: MultiplexedConnection,
}
impl RedisClient {
    /// Connect using the URL from `Config` and verify the server responds to PING.
    pub async fn new(config: &Config) -> AppResult<Self> {
        let redis_url = config.redis_url();
        let client = Client::open(redis_url)?;
        let connection = client.get_multiplexed_async_connection().await?;
        // Test the connection
        let mut conn = connection.clone();
        let _: String = redis::cmd("PING").query_async(&mut conn).await?;
        tracing::info!("Successfully connected to Redis/Valkey");
        Ok(RedisClient {
            connection,
        })
    }
    /// Liveness probe: true iff the server answers PING with "PONG".
    pub async fn health_check(&self) -> AppResult<bool> {
        let mut conn = self.connection.clone();
        let result: String = redis::cmd("PING").query_async(&mut conn).await?;
        Ok(result == "PONG")
    }
    /// GET a key; `Ok(None)` when the key does not exist.
    pub async fn get<T>(&self, key: &str) -> AppResult<Option<T>>
    where
        T: redis::FromRedisValue,
    {
        let mut conn = self.connection.clone();
        let result: Option<T> = conn.get(key).await?;
        Ok(result)
    }
    /// SET a key with no expiry.
    pub async fn set<T>(&self, key: &str, value: T) -> AppResult<()>
    where
        T: redis::ToRedisArgs + Send + Sync,
    {
        let mut conn = self.connection.clone();
        let _: () = conn.set(key, value).await?;
        Ok(())
    }
    /// SET a key with a TTL in seconds (SETEX semantics).
    pub async fn set_ex<T>(&self, key: &str, value: T, seconds: u64) -> AppResult<()>
    where
        T: redis::ToRedisArgs + Send + Sync,
    {
        let mut conn = self.connection.clone();
        let _: () = conn.set_ex(key, value, seconds).await?;
        Ok(())
    }
    /// DEL a key; true when a key was actually removed.
    pub async fn delete(&self, key: &str) -> AppResult<bool> {
        let mut conn = self.connection.clone();
        let result: bool = conn.del(key).await?;
        Ok(result)
    }
    /// EXISTS check for a key.
    pub async fn exists(&self, key: &str) -> AppResult<bool> {
        let mut conn = self.connection.clone();
        let result: bool = conn.exists(key).await?;
        Ok(result)
    }
    /// Set/refresh a TTL on an existing key; false if the key does not exist.
    pub async fn expire(&self, key: &str, seconds: u64) -> AppResult<bool> {
        let mut conn = self.connection.clone();
        let result: bool = conn.expire(key, seconds as i64).await?;
        Ok(result)
    }
    /// Atomically increment a counter by 1, returning the new value.
    pub async fn incr(&self, key: &str) -> AppResult<i64> {
        let mut conn = self.connection.clone();
        let result: i64 = conn.incr(key, 1).await?;
        Ok(result)
    }
    /// Atomically decrement a counter by 1, returning the new value.
    pub async fn decr(&self, key: &str) -> AppResult<i64> {
        let mut conn = self.connection.clone();
        let result: i64 = conn.decr(key, 1).await?;
        Ok(result)
    }
    // Session management
    /// Store `session:{id}` -> user id with a TTL.
    pub async fn store_session(&self, session_id: &str, user_id: i32, ttl_seconds: u64) -> AppResult<()> {
        let session_key = format!("session:{}", session_id);
        self.set_ex(&session_key, user_id, ttl_seconds).await
    }
    /// Resolve a session id to its user id, if the session is still live.
    pub async fn get_session(&self, session_id: &str) -> AppResult<Option<i32>> {
        let session_key = format!("session:{}", session_id);
        self.get(&session_key).await
    }
    /// Invalidate a session; true if one existed.
    pub async fn delete_session(&self, session_id: &str) -> AppResult<bool> {
        let session_key = format!("session:{}", session_id);
        self.delete(&session_key).await
    }
    // API key caching
    /// Cache an API-key validation verdict under `api_key:{key}` with a TTL.
    /// NOTE(review): the raw API key is used as part of the cache key.
    pub async fn cache_api_key_validation(&self, api_key: &str, is_valid: bool, ttl_seconds: u64) -> AppResult<()> {
        let cache_key = format!("api_key:{}", api_key);
        self.set_ex(&cache_key, is_valid, ttl_seconds).await
    }
    /// Read a cached API-key verdict; `None` = not cached (must re-validate).
    pub async fn get_cached_api_key_validation(&self, api_key: &str) -> AppResult<Option<bool>> {
        let cache_key = format!("api_key:{}", api_key);
        self.get(&cache_key).await
    }
    // Rate limiting
    /// Fixed-window rate limit: INCR the window counter and set its TTL on
    /// first hit. Returns true while the caller is within `limit`.
    pub async fn check_rate_limit(&self, identifier: &str, limit: u32, window_seconds: u64) -> AppResult<bool> {
        let rate_key = format!("rate_limit:{}", identifier);
        let mut conn = self.connection.clone();
        let current_count: i64 = conn.incr(&rate_key, 1).await?;
        if current_count == 1 {
            let _: () = conn.expire(&rate_key, window_seconds as i64).await?;
        }
        Ok(current_count <= limit as i64)
    }
    // Background task tracking
    /// Record a background task's status string under `task:{id}` with a TTL.
    pub async fn store_task_status(&self, task_id: &str, status: &str, ttl_seconds: u64) -> AppResult<()> {
        let task_key = format!("task:{}", task_id);
        self.set_ex(&task_key, status, ttl_seconds).await
    }
    /// Read a background task's status; `None` once it has expired.
    pub async fn get_task_status(&self, task_id: &str) -> AppResult<Option<String>> {
        let task_key = format!("task:{}", task_id);
        self.get(&task_key).await
    }
    // Podcast refresh tracking
    /// Mark a podcast as mid-refresh. The 5-minute TTL acts as a safety
    /// valve so a crashed refresh cannot leave the flag stuck forever.
    pub async fn set_podcast_refreshing(&self, podcast_id: i32) -> AppResult<()> {
        let refresh_key = format!("refreshing:{}", podcast_id);
        self.set_ex(&refresh_key, true, 300).await // 5 minute timeout
    }
    /// True while a refresh flag exists; Redis errors are treated as "not refreshing".
    pub async fn is_podcast_refreshing(&self, podcast_id: i32) -> AppResult<bool> {
        let refresh_key = format!("refreshing:{}", podcast_id);
        Ok(self.exists(&refresh_key).await.unwrap_or(false))
    }
    /// Clear the refresh flag once a refresh finishes.
    pub async fn clear_podcast_refreshing(&self, podcast_id: i32) -> AppResult<bool> {
        let refresh_key = format!("refreshing:{}", podcast_id);
        self.delete(&refresh_key).await
    }
    // Atomic get and delete operation - critical for OIDC state management
    /// GETDEL: fetch and remove in one server-side step, so a one-time value
    /// (e.g. an OIDC state token) can never be consumed twice.
    pub async fn get_del(&self, key: &str) -> AppResult<Option<String>> {
        let mut conn = self.connection.clone();
        let result: Option<String> = redis::cmd("GETDEL").arg(key).query_async(&mut conn).await?;
        Ok(result)
    }
    // Get a connection for direct Redis operations
    /// Hand out a clone of the multiplexed connection for callers that need
    /// raw command access.
    pub async fn get_connection(&self) -> AppResult<MultiplexedConnection> {
        Ok(self.connection.clone())
    }
}

View File

@@ -0,0 +1,249 @@
use serde_json::Value;
use crate::{error::AppResult, redis_client::RedisClient};
/// Tracks OPML-import progress per user in Redis, mirroring the Python
/// `ImportProgressManager`. Progress lives under `import_progress:{user_id}`
/// as a JSON blob `{current, total, current_podcast}` with a 1-hour TTL.
pub struct ImportProgressManager {
    redis_client: RedisClient,
}
impl ImportProgressManager {
    pub fn new(redis_client: RedisClient) -> Self {
        Self { redis_client }
    }
    // Start import progress tracking - matches Python ImportProgressManager.start_import
    /// Initialize progress at 0 of `total_podcasts`.
    pub async fn start_import(&self, user_id: i32, total_podcasts: i32) -> AppResult<()> {
        let progress_data = serde_json::json!({
            "current": 0,
            "total": total_podcasts,
            "current_podcast": ""
        });
        let key = format!("import_progress:{}", user_id);
        self.redis_client.set_ex(&key, &progress_data.to_string(), 3600).await?;
        Ok(())
    }
    // Update import progress - matches Python ImportProgressManager.update_progress
    /// Read-modify-write of the stored JSON; silently a no-op if the key has
    /// expired or the stored value is not valid JSON. Each update also resets
    /// the 1-hour TTL.
    /// NOTE(review): this read-modify-write is not atomic; concurrent updates
    /// for the same user could lose one write — acceptable for a progress bar.
    pub async fn update_progress(&self, user_id: i32, current: i32, current_podcast: &str) -> AppResult<()> {
        let key = format!("import_progress:{}", user_id);
        // Get current progress
        if let Some(progress_json) = self.redis_client.get::<String>(&key).await? {
            if let Ok(mut progress) = serde_json::from_str::<Value>(&progress_json) {
                progress["current"] = serde_json::Value::Number(serde_json::Number::from(current));
                progress["current_podcast"] = serde_json::Value::String(current_podcast.to_string());
                self.redis_client.set_ex(&key, &progress.to_string(), 3600).await?;
            }
        }
        Ok(())
    }
    // Get import progress - matches Python ImportProgressManager.get_progress
    /// Returns `(current, total, current_podcast)`, or `(0, 0, "")` when no
    /// progress record exists (expired, never started, or unparseable).
    pub async fn get_progress(&self, user_id: i32) -> AppResult<(i32, i32, String)> {
        let key = format!("import_progress:{}", user_id);
        if let Some(progress_json) = self.redis_client.get::<String>(&key).await? {
            if let Ok(progress) = serde_json::from_str::<Value>(&progress_json) {
                let current = progress.get("current").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
                let total = progress.get("total").and_then(|v| v.as_i64()).unwrap_or(0) as i32;
                let current_podcast = progress.get("current_podcast").and_then(|v| v.as_str()).unwrap_or("").to_string();
                return Ok((current, total, current_podcast));
            }
        }
        Ok((0, 0, "".to_string()))
    }
    // Clear import progress - matches Python ImportProgressManager.clear_progress
    /// Drop the progress record once an import completes or is cancelled.
    pub async fn clear_progress(&self, user_id: i32) -> AppResult<()> {
        let key = format!("import_progress:{}", user_id);
        self.redis_client.delete(&key).await?;
        Ok(())
    }
}
// Notification manager for sending test notifications
/// Sends "test notification" messages to the user's configured push
/// platform (ntfy, gotify, or a generic HTTP webhook).
/// NOTE(review): `redis_client` is stored but unused by the methods below —
/// possibly kept for parity with other managers; verify before removing.
pub struct NotificationManager {
    redis_client: RedisClient,
}
impl NotificationManager {
    pub fn new(redis_client: RedisClient) -> Self {
        Self { redis_client }
    }
    // Send test notification - matches Python notification functionality
    /// Dispatch on `platform`; returns Ok(false) (not an error) for unknown
    /// platforms, missing settings, or non-2xx responses.
    pub async fn send_test_notification(&self, user_id: i32, platform: &str, settings: &serde_json::Value) -> AppResult<bool> {
        println!("Sending test notification for user {} on platform {}", user_id, platform);
        match platform {
            "ntfy" => self.send_ntfy_notification(settings).await,
            "gotify" => self.send_gotify_notification(settings).await,
            "http" => self.send_http_notification(settings).await,
            _ => {
                println!("Unsupported notification platform: {}", platform);
                Ok(false)
            }
        }
    }
    /// POST a plain-text test message to `{server}/{topic}` on an ntfy server.
    /// Auth precedence: access token (Bearer) over username/password (Basic).
    async fn send_ntfy_notification(&self, settings: &serde_json::Value) -> AppResult<bool> {
        let topic = settings.get("ntfy_topic").and_then(|v| v.as_str()).unwrap_or("");
        let server_url = settings.get("ntfy_server_url").and_then(|v| v.as_str()).unwrap_or("https://ntfy.sh");
        let username = settings.get("ntfy_username").and_then(|v| v.as_str());
        let password = settings.get("ntfy_password").and_then(|v| v.as_str());
        let access_token = settings.get("ntfy_access_token").and_then(|v| v.as_str());
        if topic.is_empty() {
            return Ok(false);
        }
        let client = reqwest::Client::new();
        let url = format!("{}/{}", server_url, topic);
        let mut request = client
            .post(&url)
            .header("Content-Type", "text/plain")
            .body("Test notification from PinePods");
        // Add authentication if provided
        if let Some(token) = access_token.filter(|t| !t.is_empty()) {
            // Use access token (preferred method)
            request = request.header("Authorization", format!("Bearer {}", token));
        } else if let (Some(user), Some(pass)) = (username.filter(|u| !u.is_empty()), password.filter(|p| !p.is_empty())) {
            // Use username/password basic auth
            request = request.basic_auth(user, Some(pass));
        }
        let response = request.send().await?;
        let status = response.status();
        let is_success = status.is_success();
        if !is_success {
            let response_text = response.text().await.unwrap_or_default();
            println!("Ntfy notification failed with status: {} - Response: {}",
                status, response_text);
        }
        Ok(is_success)
    }
    /// POST a JSON test message to a Gotify server's `/message` endpoint.
    /// NOTE(review): the app token is passed as a URL query parameter, so it
    /// may end up in proxy/server logs; Gotify also accepts it via header.
    async fn send_gotify_notification(&self, settings: &serde_json::Value) -> AppResult<bool> {
        let gotify_url = settings.get("gotify_url").and_then(|v| v.as_str()).unwrap_or("");
        let gotify_token = settings.get("gotify_token").and_then(|v| v.as_str()).unwrap_or("");
        if gotify_url.is_empty() || gotify_token.is_empty() {
            return Ok(false);
        }
        let client = reqwest::Client::new();
        let url = format!("{}/message?token={}", gotify_url, gotify_token);
        let payload = serde_json::json!({
            "title": "PinePods Test",
            "message": "Test notification from PinePods",
            "priority": 5
        });
        let response = client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&payload)
            .send()
            .await?;
        Ok(response.status().is_success())
    }
    /// Generic webhook test: GET appends the message as a query parameter,
    /// anything else sends a JSON body (with a Telegram-specific payload when
    /// the URL targets api.telegram.org). Network errors map to Ok(false).
    async fn send_http_notification(&self, settings: &serde_json::Value) -> AppResult<bool> {
        let http_url = settings.get("http_url").and_then(|v| v.as_str()).unwrap_or("");
        let http_token = settings.get("http_token").and_then(|v| v.as_str()).unwrap_or("");
        let http_method = settings.get("http_method").and_then(|v| v.as_str()).unwrap_or("POST");
        if http_url.is_empty() {
            println!("HTTP URL is empty, cannot send notification");
            return Ok(false);
        }
        let client = reqwest::Client::new();
        // Build the request based on method
        let request_builder = match http_method.to_uppercase().as_str() {
            "GET" => {
                // For GET requests, add message as query parameter
                let url_with_params = if http_url.contains('?') {
                    format!("{}&message={}", http_url, urlencoding::encode("Test notification from PinePods"))
                } else {
                    format!("{}?message={}", http_url, urlencoding::encode("Test notification from PinePods"))
                };
                client.get(&url_with_params)
            },
            // NOTE(review): the "POST" alternative is redundant here — `_`
            // alone covers it; kept for readability of intent.
            "POST" | _ => {
                // For POST requests, send JSON payload
                let payload = if http_url.contains("api.telegram.org") {
                    // Special handling for Telegram Bot API
                    // NOTE(review): Telegram bot tokens themselves contain ':'
                    // ("123456:ABC..."), so nth(1) grabs the token secret, not
                    // a chat_id, unless the user stored "token:chat_id" with a
                    // colon-free token — confirm the expected config format.
                    let chat_id = if let Some(chat_id_str) = http_token.split(':').nth(1) {
                        // Extract chat_id from token if it contains chat_id (format: bot_token:chat_id)
                        chat_id_str
                    } else {
                        // Default chat_id - user needs to configure this properly
                        "YOUR_CHAT_ID"
                    };
                    serde_json::json!({
                        "chat_id": chat_id,
                        "text": "Test notification from PinePods"
                    })
                } else {
                    // Generic JSON payload
                    serde_json::json!({
                        "title": "PinePods Test",
                        "message": "Test notification from PinePods",
                        "text": "Test notification from PinePods"
                    })
                };
                client.post(http_url)
                    .header("Content-Type", "application/json")
                    .json(&payload)
            }
        };
        // Add authorization header if token is provided
        let request_builder = if !http_token.is_empty() {
            if http_url.contains("api.telegram.org") {
                // For Telegram, token goes in URL path, not header
                request_builder
            } else {
                // For other services, add as Bearer token
                request_builder.header("Authorization", format!("Bearer {}", http_token))
            }
        } else {
            request_builder
        };
        match request_builder.send().await {
            Ok(response) => {
                let status = response.status();
                let is_success = status.is_success();
                if !is_success {
                    let response_text = response.text().await.unwrap_or_default();
                    println!("HTTP notification failed with status: {} - Response: {}",
                        status, response_text);
                }
                Ok(is_success)
            },
            Err(e) => {
                println!("HTTP notification request failed: {}", e);
                Ok(false)
            }
        }
    }
}

View File

@@ -0,0 +1,15 @@
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use crate::error::{AppError, AppResult};
/// Verify a plaintext password against a stored Argon2 hash - matches
/// Python's passlib CryptContext with argon2.
///
/// Returns `Ok(true)`/`Ok(false)` for a match/mismatch, and
/// `Err(AppError::Auth)` only when the stored hash string itself cannot be
/// parsed as a PHC-format hash.
pub fn verify_password(password: &str, stored_hash: &str) -> AppResult<bool> {
    let parsed = match PasswordHash::new(stored_hash) {
        Ok(hash) => hash,
        Err(e) => return Err(AppError::Auth(format!("Invalid password hash format: {}", e))),
    };
    // Verification failure (wrong password) is a normal outcome, not an error.
    Ok(Argon2::default()
        .verify_password(password.as_bytes(), &parsed)
        .is_ok())
}

View File

@@ -0,0 +1,7 @@
pub mod auth;
pub mod podcast;
pub mod scheduler;
pub mod task_manager;
pub mod tasks;
// Common service utilities and shared functionality

View File

@@ -0,0 +1,395 @@
use crate::{error::AppResult, AppState, database::DatabasePool};
use crate::handlers::refresh::PodcastForRefresh;
use tracing::{info, warn, error};
use serde_json::Value;
use sqlx::Row;
/// Podcast refresh service - matches Python's refresh_pods_for_user function exactly
///
/// Refreshes one podcast, guarded by a Redis "refreshing" flag so concurrent
/// refreshes of the same podcast are skipped (returning an empty episode
/// list). The flag is cleared even when the refresh fails; its Redis-side
/// TTL covers the crash case.
pub async fn refresh_podcast(state: &AppState, podcast_id: i32) -> AppResult<Vec<Value>> {
    // Check if already refreshing
    if state.redis_client.is_podcast_refreshing(podcast_id).await? {
        return Ok(vec![]);
    }
    // Mark as refreshing
    state.redis_client.set_podcast_refreshing(podcast_id).await?;
    let result = refresh_podcast_internal(&state.db_pool, podcast_id).await;
    // Clear refreshing flag
    state.redis_client.clear_podcast_refreshing(podcast_id).await?;
    result
}
/// Internal refresh logic - matches Python refresh_pods_for_user function
///
/// Loads the podcast row, then either delegates to the (stubbed) YouTube
/// path or ingests new RSS episodes via `add_episodes`. Unknown podcast ids
/// are logged and return an empty list rather than erroring.
async fn refresh_podcast_internal(db_pool: &DatabasePool, podcast_id: i32) -> AppResult<Vec<Value>> {
    info!("Refresh begin for podcast {}", podcast_id);
    // Get podcast details from database
    let podcast_info = get_podcast_for_refresh(db_pool, podcast_id).await?;
    if let Some(podcast) = podcast_info {
        info!("Processing podcast: {}", podcast_id);
        if podcast.is_youtube {
            // Handle YouTube channel refresh
            refresh_youtube_channel(db_pool, podcast_id, &podcast.feed_url, podcast.feed_cutoff_days.unwrap_or(0)).await?;
            Ok(vec![])
        } else {
            // Handle regular RSS podcast refresh
            let episodes = db_pool.add_episodes(
                podcast_id,
                &podcast.feed_url,
                podcast.artwork_url.as_deref().unwrap_or(""),
                podcast.auto_download,
                podcast.username.as_deref(),
                podcast.password.as_deref(),
            ).await?;
            // Convert episodes to JSON format for WebSocket response
            // NOTE(review): `.map(|_| vec![])` discards the result, so this is
            // always an empty list — the WebSocket payload conversion appears
            // unfinished; confirm before relying on the returned episodes.
            let episode_json = episodes.map(|_| vec![]).unwrap_or_default();
            Ok(episode_json)
        }
    } else {
        warn!("Podcast {} not found", podcast_id);
        Ok(vec![])
    }
}
/// Refresh all podcasts - matches Python refresh_pods function exactly
///
/// Iterates every podcast sequentially; a failure on one podcast is logged
/// and counted but never aborts the run.
pub async fn refresh_all_podcasts(state: &AppState) -> AppResult<()> {
    println!("🚀 Starting refresh process for all podcasts");
    // Get all podcasts from database
    let targets = get_all_podcasts_for_refresh(&state.db_pool).await?;
    println!("📊 Found {} podcasts to refresh", targets.len());
    let mut ok_count = 0;
    let mut err_count = 0;
    for target in &targets {
        if let Err(err) = refresh_single_podcast(&state.db_pool, target).await {
            err_count += 1;
            println!("❌ Error refreshing podcast '{}' (ID: {}): {}", target.name, target.id, err);
        } else {
            ok_count += 1;
        }
    }
    println!("🎯 Refresh process completed: {} successful, {} failed", ok_count, err_count);
    Ok(())
}
/// Refresh a single podcast - matches Python refresh logic
///
/// Counts episodes before and after the refresh so the number of newly
/// ingested episodes can be reported. Count failures fall back to 0, which
/// can skew the reported delta but never aborts the refresh.
async fn refresh_single_podcast(db_pool: &DatabasePool, podcast: &PodcastForRefresh) -> AppResult<()> {
    println!("🔄 Starting refresh for podcast '{}' (ID: {})", podcast.name, podcast.id);
    // Count episodes before refresh
    // (type inferred as i64 from the `episodes_after - episodes_before` below)
    let episodes_before = match db_pool {
        crate::database::DatabasePool::Postgres(pool) => {
            sqlx::query_scalar(r#"SELECT COUNT(*) FROM "Episodes" WHERE podcastid = $1"#)
                .bind(podcast.id)
                .fetch_one(pool)
                .await.unwrap_or(0)
        }
        crate::database::DatabasePool::MySQL(pool) => {
            sqlx::query_scalar("SELECT COUNT(*) FROM Episodes WHERE PodcastID = ?")
                .bind(podcast.id)
                .fetch_one(pool)
                .await.unwrap_or(0)
        }
    };
    if podcast.is_youtube {
        // Handle YouTube channel
        refresh_youtube_channel(db_pool, podcast.id, &podcast.feed_url, podcast.feed_cutoff_days.unwrap_or(0)).await?;
    } else {
        // Handle regular RSS podcast
        db_pool.add_episodes(
            podcast.id,
            &podcast.feed_url,
            podcast.artwork_url.as_deref().unwrap_or(""),
            podcast.auto_download,
            podcast.username.as_deref(),
            podcast.password.as_deref(),
        ).await?;
    }
    // Count episodes after refresh
    let episodes_after: i64 = match db_pool {
        crate::database::DatabasePool::Postgres(pool) => {
            sqlx::query_scalar(r#"SELECT COUNT(*) FROM "Episodes" WHERE podcastid = $1"#)
                .bind(podcast.id)
                .fetch_one(pool)
                .await.unwrap_or(0)
        }
        crate::database::DatabasePool::MySQL(pool) => {
            sqlx::query_scalar("SELECT COUNT(*) FROM Episodes WHERE PodcastID = ?")
                .bind(podcast.id)
                .fetch_one(pool)
                .await.unwrap_or(0)
        }
    };
    let new_episodes = episodes_after - episodes_before;
    if new_episodes > 0 {
        println!("✅ Completed refresh for podcast '{}' - added {} new episodes", podcast.name, new_episodes);
    } else {
        println!("✅ Completed refresh for podcast '{}' - no new episodes found", podcast.name);
    }
    Ok(())
}
/// Handle YouTube channel refresh - matches Python YouTube processing
///
/// Currently a stub: derives the channel ID from the stored feed URL and
/// logs that full video processing is not implemented yet.
async fn refresh_youtube_channel(db_pool: &DatabasePool, podcast_id: i32, feed_url: &str, feed_cutoff_days: i32) -> AppResult<()> {
    let channel_id = extract_channel_id(feed_url);
    info!("Processing YouTube channel: {} for podcast {}", channel_id, podcast_id);
    // TODO: Implement YouTube video processing
    // This would match the Python youtube.process_youtube_videos function
    // For now, we'll just log that it's not implemented
    warn!("YouTube channel refresh not yet implemented for channel: {}", channel_id);
    Ok(())
}

/// Pull the channel ID out of a YouTube feed URL: take whatever follows
/// "channel/" (or the whole string when absent), then strip any trailing
/// path segment and query string.
fn extract_channel_id(feed_url: &str) -> &str {
    let tail = feed_url.split("channel/").nth(1).unwrap_or(feed_url);
    let tail = tail.split('/').next().unwrap_or(tail);
    tail.split('?').next().unwrap_or(tail)
}
/// Get podcast details for refresh - matches Python select_podcast query
///
/// Returns `Ok(None)` when no row matches. `name` is intentionally left
/// empty: the refresh path never displays it for single-podcast refreshes.
/// NOTE(review): in the Postgres branch the SELECT uses unquoted mixed-case
/// identifiers (which Postgres folds to lowercase) while `try_get` looks up
/// mixed-case column names — confirm against the actual schema/sqlx behavior
/// that these lookups resolve at runtime.
async fn get_podcast_for_refresh(db_pool: &DatabasePool, podcast_id: i32) -> AppResult<Option<PodcastForRefresh>> {
    match db_pool {
        DatabasePool::Postgres(pool) => {
            let row = sqlx::query(
                r#"SELECT
                    PodcastID, FeedURL, ArtworkURL, AutoDownload, Username, Password,
                    IsYouTubeChannel, UserID, COALESCE(FeedURL, '') as channel_id, FeedCutoffDays
                FROM "Podcasts"
                WHERE PodcastID = $1"#
            )
            .bind(podcast_id)
            .fetch_optional(pool)
            .await?;
            if let Some(row) = row {
                Ok(Some(PodcastForRefresh {
                    id: row.try_get("PodcastID")?,
                    name: "".to_string(), // Not needed for refresh
                    feed_url: row.try_get("FeedURL")?,
                    artwork_url: row.try_get::<Option<String>, _>("ArtworkURL").unwrap_or_default(),
                    auto_download: row.try_get("AutoDownload")?,
                    username: row.try_get("Username").ok(),
                    password: row.try_get("Password").ok(),
                    is_youtube: row.try_get("IsYouTubeChannel")?,
                    user_id: row.try_get("UserID")?,
                    feed_cutoff_days: row.try_get("FeedCutoffDays").ok(),
                }))
            } else {
                Ok(None)
            }
        }
        DatabasePool::MySQL(pool) => {
            let row = sqlx::query(
                "SELECT
                    PodcastID, FeedURL, ArtworkURL, AutoDownload, Username, Password,
                    IsYouTubeChannel, UserID, COALESCE(FeedURL, '') as channel_id, FeedCutoffDays
                FROM Podcasts
                WHERE PodcastID = ?"
            )
            .bind(podcast_id)
            .fetch_optional(pool)
            .await?;
            if let Some(row) = row {
                Ok(Some(PodcastForRefresh {
                    id: row.try_get("PodcastID")?,
                    name: "".to_string(), // Not needed for refresh
                    feed_url: row.try_get("FeedURL")?,
                    artwork_url: row.try_get::<Option<String>, _>("ArtworkURL").unwrap_or_default(),
                    auto_download: row.try_get("AutoDownload")?,
                    username: row.try_get("Username").ok(),
                    password: row.try_get("Password").ok(),
                    is_youtube: row.try_get("IsYouTubeChannel")?,
                    user_id: row.try_get("UserID")?,
                    feed_cutoff_days: row.try_get("FeedCutoffDays").ok(),
                }))
            } else {
                Ok(None)
            }
        }
    }
}
/// Get all podcasts for refresh - matches Python select_podcasts query
///
/// Loads the refresh-relevant columns for every podcast across all users.
/// `name` is left empty here too, so the refresh summary printed per-podcast
/// shows a blank name — TODO: select PodcastName if that output matters.
async fn get_all_podcasts_for_refresh(db_pool: &DatabasePool) -> AppResult<Vec<PodcastForRefresh>> {
    match db_pool {
        DatabasePool::Postgres(pool) => {
            let rows = sqlx::query(
                r#"SELECT
                    PodcastID, FeedURL, ArtworkURL, AutoDownload, Username, Password,
                    IsYouTubeChannel, UserID, COALESCE(FeedURL, '') as channel_id, FeedCutoffDays
                FROM "Podcasts""#
            )
            .fetch_all(pool)
            .await?;
            let mut podcasts = Vec::new();
            for row in rows {
                podcasts.push(PodcastForRefresh {
                    id: row.try_get("PodcastID")?,
                    name: "".to_string(), // Not needed for refresh
                    feed_url: row.try_get("FeedURL")?,
                    artwork_url: row.try_get::<Option<String>, _>("ArtworkURL").unwrap_or_default(),
                    auto_download: row.try_get("AutoDownload")?,
                    username: row.try_get("Username").ok(),
                    password: row.try_get("Password").ok(),
                    is_youtube: row.try_get("IsYouTubeChannel")?,
                    user_id: row.try_get("UserID")?,
                    feed_cutoff_days: row.try_get("FeedCutoffDays").ok(),
                });
            }
            Ok(podcasts)
        }
        DatabasePool::MySQL(pool) => {
            let rows = sqlx::query(
                "SELECT
                    PodcastID, FeedURL, ArtworkURL, AutoDownload, Username, Password,
                    IsYouTubeChannel, UserID, COALESCE(FeedURL, '') as channel_id, FeedCutoffDays
                FROM Podcasts"
            )
            .fetch_all(pool)
            .await?;
            let mut podcasts = Vec::new();
            for row in rows {
                podcasts.push(PodcastForRefresh {
                    id: row.try_get("PodcastID")?,
                    name: "".to_string(), // Not needed for refresh
                    feed_url: row.try_get("FeedURL")?,
                    artwork_url: row.try_get::<Option<String>, _>("ArtworkURL").unwrap_or_default(),
                    auto_download: row.try_get("AutoDownload")?,
                    username: row.try_get("Username").ok(),
                    password: row.try_get("Password").ok(),
                    is_youtube: row.try_get("IsYouTubeChannel")?,
                    user_id: row.try_get("UserID")?,
                    feed_cutoff_days: row.try_get("FeedCutoffDays").ok(),
                });
            }
            Ok(podcasts)
        }
    }
}
/// Remove unavailable episodes - matches Python remove_unavailable_episodes function
///
/// Issues a HEAD request for every episode URL and deletes only episodes
/// whose server answers exactly 404. Network errors and other statuses
/// (403, 410, 405 from hosts that reject HEAD, ...) leave the episode in
/// place, so this is deliberately conservative.
/// NOTE(review): runs one sequential request per episode — slow on large
/// libraries; confirm this is only invoked from background maintenance.
pub async fn remove_unavailable_episodes(db_pool: &DatabasePool) -> AppResult<()> {
    info!("Starting removal of unavailable episodes");
    // Get all episodes from database
    let episodes = get_all_episodes_for_check(db_pool).await?;
    let client = reqwest::Client::new();
    for episode in episodes {
        // Check if episode URL is still valid
        match client.head(&episode.url).send().await {
            Ok(response) => {
                if response.status().as_u16() == 404 {
                    // Remove episode from database
                    info!("Removing unavailable episode: {}", episode.id);
                    remove_episode_from_database(db_pool, episode.id).await?;
                }
            }
            Err(e) => {
                error!("Error checking episode {}: {}", episode.id, e);
            }
        }
    }
    Ok(())
}
/// Get all episodes for availability check
///
/// Loads id, podcast id, title, URL and publication date for every episode.
/// NOTE(review): same identifier-case caveat as the podcast queries — the
/// Postgres SELECT uses unquoted mixed-case names while `try_get` uses
/// mixed-case lookups; confirm against the live schema.
async fn get_all_episodes_for_check(db_pool: &DatabasePool) -> AppResult<Vec<EpisodeForCheck>> {
    match db_pool {
        DatabasePool::Postgres(pool) => {
            let rows = sqlx::query(
                r#"SELECT EpisodeID, PodcastID, EpisodeTitle, EpisodeURL, EpisodePubDate FROM "Episodes""#
            )
            .fetch_all(pool)
            .await?;
            let mut episodes = Vec::new();
            for row in rows {
                episodes.push(EpisodeForCheck {
                    id: row.try_get("EpisodeID")?,
                    podcast_id: row.try_get("PodcastID")?,
                    title: row.try_get("EpisodeTitle")?,
                    url: row.try_get("EpisodeURL")?,
                    pub_date: row.try_get("EpisodePubDate")?,
                });
            }
            Ok(episodes)
        }
        DatabasePool::MySQL(pool) => {
            let rows = sqlx::query(
                "SELECT EpisodeID, PodcastID, EpisodeTitle, EpisodeURL, EpisodePubDate FROM Episodes"
            )
            .fetch_all(pool)
            .await?;
            let mut episodes = Vec::new();
            for row in rows {
                episodes.push(EpisodeForCheck {
                    id: row.try_get("EpisodeID")?,
                    podcast_id: row.try_get("PodcastID")?,
                    title: row.try_get("EpisodeTitle")?,
                    url: row.try_get("EpisodeURL")?,
                    pub_date: row.try_get("EpisodePubDate")?,
                });
            }
            Ok(episodes)
        }
    }
}
/// Remove episode from database
///
/// Hard-deletes the episode row by id. Deleting a nonexistent id is a no-op
/// (the DELETE simply affects zero rows).
async fn remove_episode_from_database(db_pool: &DatabasePool, episode_id: i32) -> AppResult<()> {
    match db_pool {
        DatabasePool::Postgres(pool) => {
            sqlx::query(r#"DELETE FROM "Episodes" WHERE EpisodeID = $1"#)
                .bind(episode_id)
                .execute(pool)
                .await?;
        }
        DatabasePool::MySQL(pool) => {
            sqlx::query("DELETE FROM Episodes WHERE EpisodeID = ?")
                .bind(episode_id)
                .execute(pool)
                .await?;
        }
    }
    Ok(())
}
/// Episode data structure for availability check
///
/// Minimal projection of an episode row used by `remove_unavailable_episodes`.
#[derive(Debug, Clone)]
pub struct EpisodeForCheck {
    pub id: i32,
    pub podcast_id: i32,
    pub title: String,
    pub url: String,
    pub pub_date: sqlx::types::chrono::DateTime<sqlx::types::chrono::Utc>,
}

View File

@@ -0,0 +1,157 @@
use crate::{
error::AppResult,
handlers::{refresh, tasks},
AppState,
};
use std::sync::Arc;
use tokio_cron_scheduler::{Job, JobScheduler};
use tracing::{info, error, warn};
/// Owns the cron scheduler that drives the app's recurring maintenance
/// jobs (podcast refresh, nightly tasks, cleanup). Jobs are registered
/// and started via `start`.
pub struct BackgroundScheduler {
    // tokio-cron-scheduler instance; must stay alive for jobs to keep firing
    scheduler: JobScheduler,
}
impl BackgroundScheduler {
pub async fn new() -> AppResult<Self> {
let scheduler = JobScheduler::new().await?;
Ok(Self { scheduler })
}
pub async fn start(&self, app_state: Arc<AppState>) -> AppResult<()> {
info!("🕒 Starting background task scheduler...");
// Schedule podcast refresh every 30 minutes
let refresh_state = app_state.clone();
let refresh_job = Job::new_async("0 */30 * * * *", move |_uuid, _l| {
let state = refresh_state.clone();
Box::pin(async move {
info!("🔄 Running scheduled podcast refresh");
if let Err(e) = Self::run_refresh_pods(state.clone()).await {
error!("❌ Scheduled podcast refresh failed: {}", e);
} else {
info!("✅ Scheduled podcast refresh completed");
}
})
})?;
// Schedule nightly tasks at midnight
let nightly_state = app_state.clone();
let nightly_job = Job::new_async("0 0 0 * * *", move |_uuid, _l| {
let state = nightly_state.clone();
Box::pin(async move {
info!("🌙 Running scheduled nightly tasks");
if let Err(e) = Self::run_nightly_tasks(state.clone()).await {
error!("❌ Scheduled nightly tasks failed: {}", e);
} else {
info!("✅ Scheduled nightly tasks completed");
}
})
})?;
// Schedule cleanup tasks every 6 hours
let cleanup_state = app_state.clone();
let cleanup_job = Job::new_async("0 0 */6 * * *", move |_uuid, _l| {
let state = cleanup_state.clone();
Box::pin(async move {
info!("🧹 Running scheduled cleanup tasks");
if let Err(e) = Self::run_cleanup_tasks(state.clone()).await {
error!("❌ Scheduled cleanup tasks failed: {}", e);
} else {
info!("✅ Scheduled cleanup tasks completed");
}
})
})?;
// Add jobs to scheduler
self.scheduler.add(refresh_job).await?;
self.scheduler.add(nightly_job).await?;
self.scheduler.add(cleanup_job).await?;
// Start the scheduler
self.scheduler.start().await?;
info!("✅ Background task scheduler started successfully");
Ok(())
}
// Direct function calls instead of HTTP requests
async fn run_refresh_pods(state: Arc<AppState>) -> AppResult<()> {
// Call refresh_pods function directly
match refresh::refresh_pods_admin_internal(&state).await {
Ok(_) => {
info!("✅ Podcast refresh completed");
// Also run gpodder sync
if let Err(e) = refresh::refresh_gpodder_subscriptions_admin_internal(&state).await {
warn!("⚠️ GPodder sync failed during scheduled refresh: {}", e);
}
// Also run nextcloud sync
if let Err(e) = refresh::refresh_nextcloud_subscriptions_admin_internal(&state).await {
warn!("⚠️ Nextcloud sync failed during scheduled refresh: {}", e);
}
// Update playlist episode counts (replaces complex playlist content updates)
if let Err(e) = state.db_pool.update_playlist_episode_counts().await {
warn!("⚠️ Playlist episode count update failed during scheduled refresh: {}", e);
}
}
Err(e) => {
error!("❌ Podcast refresh failed: {}", e);
return Err(e);
}
}
Ok(())
}
async fn run_nightly_tasks(state: Arc<AppState>) -> AppResult<()> {
// Call nightly tasks directly
if let Err(e) = tasks::refresh_hosts_internal(&state).await {
warn!("⚠️ Refresh hosts failed during nightly tasks: {}", e);
}
if let Err(e) = tasks::auto_complete_episodes_internal(&state).await {
warn!("⚠️ Auto complete episodes failed during nightly tasks: {}", e);
}
info!("✅ Nightly tasks completed");
Ok(())
}
async fn run_cleanup_tasks(state: Arc<AppState>) -> AppResult<()> {
// Call cleanup tasks directly
match tasks::cleanup_tasks_internal(&state).await {
Ok(_) => {
info!("✅ Cleanup tasks completed");
}
Err(e) => {
error!("❌ Cleanup tasks failed: {}", e);
return Err(e);
}
}
Ok(())
}
// Run initial startup tasks immediately
pub async fn run_startup_tasks(state: Arc<AppState>) -> AppResult<()> {
info!("🚀 Running initial startup tasks...");
// Initialize OIDC provider from environment variables if configured
if let Err(e) = state.db_pool.init_oidc_from_env(&state.config.oidc).await {
warn!("⚠️ OIDC initialization failed: {}", e);
}
// Create missing default playlists for existing users
if let Err(e) = state.db_pool.create_missing_default_playlists().await {
warn!("⚠️ Creating missing default playlists failed: {}", e);
}
// Run an immediate refresh to ensure data is current on startup
if let Err(e) = Self::run_refresh_pods(state.clone()).await {
warn!("⚠️ Initial startup refresh failed: {}", e);
}
info!("✅ Startup tasks completed");
Ok(())
}
}

View File

@@ -0,0 +1,312 @@
use crate::{error::AppResult, redis_client::RedisClient};
use redis::AsyncCommands;
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use uuid::Uuid;
/// Lifecycle state of a background task.
///
/// The serialized names are fixed uppercase strings kept for wire
/// compatibility with the pre-existing clients (cf. the WebSocket format
/// "to match Python implementation" in this module).
/// NOTE(review): `Running` serializes as "DOWNLOADING", which looks
/// download-specific for a generic running state — confirm the frontend
/// expects this for all task types.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TaskStatus {
    #[serde(rename = "PENDING")]
    Pending,
    #[serde(rename = "DOWNLOADING")]
    Running,
    #[serde(rename = "SUCCESS")]
    Completed,
    #[serde(rename = "FAILED")]
    Failed,
}
/// Canonical task record, persisted as JSON in Redis under `task:{id}`
/// with a 7-day TTL (see `TaskManager::save_task`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskInfo {
    // UUID v4 string assigned at creation
    pub id: String,
    // Free-form task category tag
    pub task_type: String,
    // Owner of the task; used to filter in `get_user_tasks`
    pub user_id: i32,
    pub status: TaskStatus,
    // 0.0..=100.0 (clamped on update)
    pub progress: f64,
    // Latest human-readable status message, if any
    pub message: Option<String>,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub updated_at: chrono::DateTime<chrono::Utc>,
    // Arbitrary JSON payload set on completion
    pub result: Option<serde_json::Value>,
}
/// One progress event broadcast to WebSocket subscribers.
///
/// This is the outgoing wire shape (Serialize only); timestamps are
/// RFC 3339 strings rather than `DateTime` for client compatibility.
#[derive(Debug, Clone, Serialize)]
pub struct TaskUpdate {
    pub task_id: String,
    pub user_id: i32,
    // Serialized as "type" to match the expected client payload
    #[serde(rename = "type")]
    pub task_type: String,
    // Related item (e.g. an episode id) when the task targets one
    pub item_id: Option<i32>,
    pub progress: f64,
    pub status: TaskStatus,
    // Free-form JSON: status_text, episode_id/episode_title, result, error
    pub details: serde_json::Value,
    pub started_at: String,
    // Present only for terminal (completed/failed) updates
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<String>,
}
/// WebSocket message envelope; the field layout matches the legacy
/// Python implementation so existing clients keep working.
#[derive(Debug, Clone, Serialize)]
pub struct WebSocketMessage {
    // Event name understood by the frontend
    pub event: String,
    // Payload for single-task events
    pub task: Option<TaskUpdate>,
    // Payload for bulk task listings
    pub tasks: Option<Vec<TaskInfo>>,
}

// Broadcast handles used to fan task updates out to WebSocket sessions.
pub type TaskProgressSender = broadcast::Sender<TaskUpdate>;
pub type TaskProgressReceiver = broadcast::Receiver<TaskUpdate>;
/// Coordinates background-task bookkeeping: persists task state in
/// Redis and broadcasts progress updates to WebSocket subscribers.
/// Cheap to clone (shared client + shared broadcast sender).
#[derive(Clone)]
pub struct TaskManager {
    // Storage backend for task records
    redis: RedisClient,
    // Fan-out channel for live progress updates
    progress_sender: TaskProgressSender,
}
impl TaskManager {
    /// Build a manager on top of the shared Redis client.
    ///
    /// The broadcast channel buffers up to 1000 updates; receivers that
    /// lag further behind observe a `Lagged` error and skip ahead.
    pub fn new(redis: RedisClient) -> Self {
        let (progress_sender, _) = broadcast::channel(1000);
        Self {
            redis,
            progress_sender,
        }
    }

    /// Subscribe to the live stream of task progress updates.
    pub fn subscribe_to_progress(&self) -> TaskProgressReceiver {
        self.progress_sender.subscribe()
    }

    /// Create a task with no associated item. See
    /// `create_task_with_item_id` for details.
    pub async fn create_task(
        &self,
        task_type: String,
        user_id: i32,
    ) -> AppResult<String> {
        self.create_task_with_item_id(task_type, user_id, None).await
    }

    /// Create a new pending task, persist it, and broadcast the initial
    /// update. Returns the generated task id (UUID v4 string).
    ///
    /// `item_id` is only carried on the broadcast update (for frontend
    /// compatibility); it is not part of the persisted `TaskInfo`.
    pub async fn create_task_with_item_id(
        &self,
        task_type: String,
        user_id: i32,
        item_id: Option<i32>,
    ) -> AppResult<String> {
        let task_id = Uuid::new_v4().to_string();
        let task = TaskInfo {
            id: task_id.clone(),
            task_type: task_type.clone(),
            user_id,
            status: TaskStatus::Pending,
            progress: 0.0,
            message: None,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            result: None,
        };
        self.save_task(&task).await?;

        let update = TaskUpdate {
            task_id: task_id.clone(),
            user_id,
            task_type,
            item_id,
            progress: 0.0,
            status: TaskStatus::Pending,
            details: serde_json::json!({}),
            started_at: chrono::Utc::now().to_rfc3339(),
            completed_at: None,
        };
        // Send failures only mean there are currently no subscribers.
        let _ = self.progress_sender.send(update);

        Ok(task_id)
    }

    /// Update progress with no item/type override. See
    /// `update_task_progress_with_details`.
    pub async fn update_task_progress(
        &self,
        task_id: &str,
        progress: f64,
        message: Option<String>,
    ) -> AppResult<()> {
        self.update_task_progress_with_item_id(task_id, progress, message, None, None).await
    }

    /// Update progress with an optional item id and task-type override.
    /// See `update_task_progress_with_details`.
    pub async fn update_task_progress_with_item_id(
        &self,
        task_id: &str,
        progress: f64,
        message: Option<String>,
        item_id: Option<i32>,
        task_type: Option<String>,
    ) -> AppResult<()> {
        self.update_task_progress_with_details(task_id, progress, message, item_id, task_type, None).await
    }

    /// Persist a progress update and broadcast it.
    ///
    /// `progress` is clamped to 0..=100. `task_type`, when given,
    /// overrides the stored type on the broadcast only; `item_id` and
    /// `episode_title` are folded into the `details` payload.
    pub async fn update_task_progress_with_details(
        &self,
        task_id: &str,
        progress: f64,
        message: Option<String>,
        item_id: Option<i32>,
        task_type: Option<String>,
        episode_title: Option<String>,
    ) -> AppResult<()> {
        // Clamp once so the persisted record and the broadcast agree
        // (previously the broadcast carried the raw, un-clamped value).
        let progress = progress.clamp(0.0, 100.0);

        let mut task = self.get_task(task_id).await?;
        task.progress = progress;
        task.message = message.clone();
        task.updated_at = chrono::Utc::now();
        // First real progress moves a pending task into the running state.
        if progress > 0.0 && matches!(task.status, TaskStatus::Pending) {
            task.status = TaskStatus::Running;
        }
        self.save_task(&task).await?;

        let mut details = serde_json::json!({
            "status_text": message.as_deref().unwrap_or("Processing...")
        });
        // Attach episode context for clients that render per-episode rows.
        if let Some(episode_id) = item_id {
            details["episode_id"] = serde_json::json!(episode_id);
        }
        if let Some(title) = episode_title {
            details["episode_title"] = serde_json::json!(title);
        }

        let update = TaskUpdate {
            task_id: task_id.to_string(),
            user_id: task.user_id,
            task_type: task_type.unwrap_or_else(|| task.task_type.clone()),
            item_id,
            progress,
            status: task.status.clone(),
            details,
            started_at: task.created_at.to_rfc3339(),
            completed_at: None,
        };
        let _ = self.progress_sender.send(update);
        Ok(())
    }

    /// Mark a task completed (progress forced to 100), store the optional
    /// result payload, and broadcast the terminal update.
    pub async fn complete_task(
        &self,
        task_id: &str,
        result: Option<serde_json::Value>,
        message: Option<String>,
    ) -> AppResult<()> {
        let mut task = self.get_task(task_id).await?;
        task.status = TaskStatus::Completed;
        task.progress = 100.0;
        task.message = message.clone();
        task.result = result.clone();
        task.updated_at = chrono::Utc::now();
        self.save_task(&task).await?;

        let update = TaskUpdate {
            task_id: task_id.to_string(),
            user_id: task.user_id,
            task_type: task.task_type.clone(),
            item_id: None, // Completion updates don't need item_id
            progress: 100.0,
            status: TaskStatus::Completed,
            details: serde_json::json!({
                "status_text": message.as_deref().unwrap_or("Completed"),
                "result": result
            }),
            started_at: task.created_at.to_rfc3339(),
            completed_at: Some(chrono::Utc::now().to_rfc3339()),
        };
        let _ = self.progress_sender.send(update);
        Ok(())
    }

    /// Mark a task failed, keeping whatever progress it had reached, and
    /// broadcast the terminal update with the error text.
    pub async fn fail_task(
        &self,
        task_id: &str,
        error_message: String,
    ) -> AppResult<()> {
        let mut task = self.get_task(task_id).await?;
        task.status = TaskStatus::Failed;
        task.message = Some(error_message.clone());
        task.updated_at = chrono::Utc::now();
        self.save_task(&task).await?;

        let update = TaskUpdate {
            task_id: task_id.to_string(),
            user_id: task.user_id,
            task_type: task.task_type.clone(),
            item_id: None, // Failure updates don't need item_id
            progress: task.progress,
            status: TaskStatus::Failed,
            details: serde_json::json!({
                "status_text": error_message,
                "error": error_message
            }),
            started_at: task.created_at.to_rfc3339(),
            completed_at: Some(chrono::Utc::now().to_rfc3339()),
        };
        let _ = self.progress_sender.send(update);
        Ok(())
    }

    /// Fetch one task by id. Errors if the key is missing or its TTL
    /// expired (redis GET on a missing key fails the `String` conversion).
    pub async fn get_task(&self, task_id: &str) -> AppResult<TaskInfo> {
        let key = format!("task:{}", task_id);
        let mut conn = self.redis.get_connection().await?;
        let task_json: String = conn.get(&key).await?;
        let task: TaskInfo = serde_json::from_str(&task_json)?;
        Ok(task)
    }

    /// List all of a user's tasks, newest first. Entries that vanish or
    /// fail to parse mid-scan are skipped silently.
    ///
    /// NOTE(review): KEYS is O(N) over the whole keyspace and blocks
    /// Redis; consider SCAN/scan_match or a per-user index set if the
    /// task volume grows.
    pub async fn get_user_tasks(&self, user_id: i32) -> AppResult<Vec<TaskInfo>> {
        let pattern = "task:*";
        let mut conn = self.redis.get_connection().await?;
        let keys: Vec<String> = conn.keys(pattern).await?;
        let mut user_tasks = Vec::new();
        for key in keys {
            if let Ok(task_json) = conn.get::<_, String>(&key).await {
                if let Ok(task) = serde_json::from_str::<TaskInfo>(&task_json) {
                    if task.user_id == user_id {
                        user_tasks.push(task);
                    }
                }
            }
        }
        user_tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));
        Ok(user_tasks)
    }

    /// Serialize and store a task under `task:{id}` with a 7-day TTL.
    async fn save_task(&self, task: &TaskInfo) -> AppResult<()> {
        let key = format!("task:{}", task.id);
        let task_json = serde_json::to_string(task)?;
        let mut conn = self.redis.get_connection().await?;
        conn.set_ex::<_, _, ()>(&key, &task_json, 86400 * 7).await?; // 7 days TTL
        Ok(())
    }

    /// Delete task records older than 7 days.
    ///
    /// Largely a backstop: `save_task` already sets a 7-day TTL, so this
    /// only catches records whose TTL was lost (e.g. restored dumps).
    /// NOTE(review): uses blocking KEYS — see `get_user_tasks`.
    pub async fn cleanup_old_tasks(&self) -> AppResult<()> {
        let cutoff = chrono::Utc::now() - chrono::Duration::days(7);
        let pattern = "task:*";
        let mut conn = self.redis.get_connection().await?;
        let keys: Vec<String> = conn.keys(pattern).await?;
        for key in keys {
            if let Ok(task_json) = conn.get::<_, String>(&key).await {
                if let Ok(task) = serde_json::from_str::<TaskInfo>(&task_json) {
                    if task.created_at < cutoff {
                        let _: () = conn.del(&key).await?;
                    }
                }
            }
        }
        Ok(())
    }
}

File diff suppressed because it is too large Load Diff