diff --git a/Dockerfile b/Dockerfile index caa2c5f..5c6166e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,9 +48,16 @@ USER dusk # set working directory WORKDIR /data -# expose the default relay port +# expose the default relay port (libp2p) EXPOSE 4001 +# expose TURN server ports (UDP + TCP signaling) +EXPOSE 3478/udp +EXPOSE 3478/tcp + +# expose TURN relay allocation port range (UDP) +EXPOSE 49152-65535/udp + # persist keypair and data to the volume-mounted /data directory # XDG_DATA_HOME tells the directories crate to resolve paths under /data # so the keypair ends up at /data/dusk-relay/keypair instead of ~/.local/share @@ -61,6 +68,19 @@ VOLUME /data ENV RUST_LOG=info ENV DUSK_RELAY_PORT=4001 +# TURN server environment variables +ENV DUSK_TURN_ENABLED=true +ENV DUSK_TURN_PUBLIC_IP="" +ENV DUSK_TURN_SECRET="" +ENV DUSK_TURN_UDP_PORT=3478 +ENV DUSK_TURN_TCP_PORT=3478 +ENV DUSK_TURN_REALM=duskchat.app +ENV DUSK_TURN_PORT_RANGE_START=49152 +ENV DUSK_TURN_PORT_RANGE_END=65535 +ENV DUSK_TURN_MAX_ALLOCATIONS=1000 +ENV DUSK_TURN_MAX_PER_USER=10 +ENV DUSK_TURN_PUBLIC_HOST="" + # health check to verify the relay is listening HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ CMD timeout 5 bash -c 'cat < /dev/null > /dev/tcp/0.0.0.0/${DUSK_RELAY_PORT:-4001}' || exit 1 diff --git a/docker-compose.yml b/docker-compose.yml index b20e8af..6395ad7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,10 +8,32 @@ services: container_name: dusk-relay restart: unless-stopped ports: + # libp2p relay port - "4001:4001" + # TURN server signaling (UDP + TCP) + - "3478:3478/udp" + - "3478:3478/tcp" + # TURN relay allocation ports (UDP) + # NOTE: using a smaller range (49152-50000) for Docker; adjust as needed + - "49152-50000:49152-50000/udp" environment: - RUST_LOG=info + # libp2p relay - DUSK_RELAY_PORT=4001 + - DUSK_PEER_RELAYS= + - KLIPY_API_KEY= + # TURN server + - DUSK_TURN_ENABLED=true + - DUSK_TURN_PUBLIC_IP=${DUSK_TURN_PUBLIC_IP:?Set DUSK_TURN_PUBLIC_IP to your server public IP} + - DUSK_TURN_SECRET=${DUSK_TURN_SECRET:?Set DUSK_TURN_SECRET for HMAC credentials} + - DUSK_TURN_UDP_PORT=3478 + - DUSK_TURN_TCP_PORT=3478 + - DUSK_TURN_REALM=duskchat.app + - DUSK_TURN_PORT_RANGE_START=49152 + - DUSK_TURN_PORT_RANGE_END=50000 + - DUSK_TURN_MAX_ALLOCATIONS=1000 + - DUSK_TURN_MAX_PER_USER=10 + - DUSK_TURN_PUBLIC_HOST=${DUSK_TURN_PUBLIC_HOST:-} volumes: # persist the relay's keypair so peer id stays stable across restarts - dusk-relay-data:/data diff --git a/src/main.rs b/src/main.rs index a916078..1530201 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,6 +25,8 @@ // t3.large (8GB): 20,000 max connections // c6i.xlarge: 50,000 max connections (with kernel tuning) +mod turn; + use std::collections::{HashMap, VecDeque}; use std::path::PathBuf; use std::time::{Duration, Instant}; @@ -63,6 +65,8 @@ struct RelayBehaviour { gif_service: cbor::Behaviour, // persistent directory service - clients register/search peer profiles directory_service: cbor::Behaviour, + // TURN credential service - clients request time-limited TURN server credentials + turn_credentials: cbor::Behaviour, } // ---- gif protocol ---- @@ -227,6 +231,7 @@ pub enum DirectoryRequest { pub enum DirectoryResponse { Ok, Results(Vec), + Error(String), } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -237,6 +242,28 @@ pub struct DirectoryProfileEntry { } // ---- end directory protocol ---- +// ---- TURN credential protocol ---- +// clients request time-limited TURN credentials over libp2p request-response. +// the relay generates HMAC-SHA1 credentials that the client then uses when +// talking to the TURN server directly via UDP/TCP (separate from the libp2p swarm). + +const TURN_CREDENTIALS_PROTOCOL: StreamProtocol = + StreamProtocol::new("/dusk/turn-credentials/1.0.0"); + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TurnCredentialRequest { + pub peer_id: String, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TurnCredentialResponse { + pub username: String, + pub password: String, + pub ttl: u64, + pub uris: Vec, // TURN server URIs like ["turn:relay.duskchat.app:3478", "turn:relay.duskchat.app:3478?transport=tcp"] +} +// ---- end TURN credential protocol ---- + // fetch from klipy and normalize into our GifResult format async fn fetch_klipy( http: &reqwest::Client, @@ -342,6 +369,18 @@ fn load_or_generate_keypair() -> libp2p::identity::Keypair { if let Err(e) = std::fs::write(&path, &bytes) { log::warn!("failed to persist keypair: {}", e); } else { + // restrict keypair file to owner-only read/write (0600) so other + // users on a shared server cannot read the relay's private key + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Err(e) = std::fs::set_permissions( + &path, + std::fs::Permissions::from_mode(0o600), + ) { + log::warn!("failed to set keypair file permissions: {}", e); + } + } log::info!("saved new keypair to {}", path.display()); } } @@ -460,12 +499,6 @@ async fn main() -> Result<(), Box> { ) .expect("valid gossipsub behaviour"); - // read connection limit (max total concurrent connections across all peers) - let max_connections = std::env::var("DUSK_MAX_CONNECTIONS") - .ok() - .and_then(|c| c.parse().ok()) - .unwrap_or(10_000); - RelayBehaviour { relay: relay::Behaviour::new(peer_id, relay::Config::default()), rendezvous: rendezvous::server::Behaviour::new( @@ -497,6 +530,12 @@ async fn main() -> Result<(), Box> { request_response::Config::default() .with_request_timeout(Duration::from_secs(15)), ), + // TURN credential service - clients request time-limited TURN credentials + turn_credentials: cbor::Behaviour::new( + [(TURN_CREDENTIALS_PROTOCOL, ProtocolSupport::Full)], + request_response::Config::default() + .with_request_timeout(Duration::from_secs(15)), + ), } })? .with_swarm_config(|cfg| cfg.with_idle_connection_timeout(Duration::from_secs(300))) @@ -512,6 +551,40 @@ async fn main() -> Result<(), Box> { let canonical_addr = format!("/ip4/0.0.0.0/tcp/{}/p2p/{}", port, local_peer_id); println!("\n relay address: {}\n", canonical_addr); + // ---- TURN server startup ---- + // get the shared secret for credential generation (shared between TURN server and credential protocol) + let turn_shared_secret: Vec = std::env::var("DUSK_TURN_SECRET") + .unwrap_or_else(|_| { + eprintln!("[TURN] WARNING: DUSK_TURN_SECRET not set, generating random secret"); + eprintln!("[TURN] This means credentials won't survive restarts and won't work across multiple relay instances"); + let timestamp = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_nanos(); + format!("dusk-turn-random-{}", timestamp) + }) + .into_bytes(); + + // start TURN server if enabled (runs on separate UDP/TCP ports from the libp2p swarm) + let _turn_handle = if turn::TurnServerConfig::is_enabled() { + let turn_config = turn::TurnServerConfig::from_env(); + let turn_server = turn::TurnServer::new(turn_config); + match turn_server.run().await { + Ok(handle) => { + println!("[RELAY] TURN server started"); + Some(handle) + } + Err(e) => { + eprintln!("[RELAY] Failed to start TURN server: {}", e); + None + } + } + } else { + println!("[RELAY] TURN server disabled"); + None + }; + // ---- end TURN server startup ---- + // subscribe to the relay federation gossip topic let federation_topic = gossipsub::IdentTopic::new("dusk/relay/federation"); swarm @@ -896,23 +969,36 @@ async fn main() -> Result<(), Box> { )) => { let response = match request { DirectoryRequest::Register { display_name } => { - let ts = now_secs(); - let result = dir_db.execute( - "INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at) - VALUES (?1, ?2, ?3, ?3) - ON CONFLICT(peer_id) DO UPDATE SET - display_name = excluded.display_name, - last_seen = excluded.last_seen", - params![peer.to_string(), display_name, ts as i64], - ); - match result { - Ok(_) => { - log::info!("directory: registered peer {} as '{}'", peer, display_name); - DirectoryResponse::Ok - } - Err(e) => { - log::warn!("directory: failed to register {}: {}", peer, e); - DirectoryResponse::Ok + // validate display_name: reject empty or excessively long names + let trimmed_name = display_name.trim(); + if trimmed_name.is_empty() || trimmed_name.len() > 64 { + log::warn!( + "directory: rejected registration from {} (name length {})", + peer, + trimmed_name.len() + ); + DirectoryResponse::Error( + "display_name must be 1-64 characters".to_string(), + ) + } else { + let ts = now_secs(); + let result = dir_db.execute( + "INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at) + VALUES (?1, ?2, ?3, ?3) + ON CONFLICT(peer_id) DO UPDATE SET + display_name = excluded.display_name, + last_seen = excluded.last_seen", + params![peer.to_string(), trimmed_name, ts as i64], + ); + match result { + Ok(_) => { + log::info!("directory: registered peer {} as '{}'", peer, trimmed_name); + DirectoryResponse::Ok + } + Err(e) => { + log::warn!("directory: failed to register {}: {}", peer, e); + DirectoryResponse::Error(format!("registration failed: {}", e)) + } } } } @@ -974,6 +1060,85 @@ async fn main() -> Result<(), Box> { // ignore outbound and other directory service events SwarmEvent::Behaviour(RelayBehaviourEvent::DirectoryService(_)) => {} + // ---- TURN credential service ---- + // clients request time-limited credentials for the TURN server + SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials( + request_response::Event::Message { + peer, + message: + request_response::Message::Request { + request, channel, .. + }, + .. + }, + )) => { + let peer_id_str = request.peer_id.clone(); + + // generate time-limited credentials using the shared secret + let (username, password) = turn::credentials::generate_credentials( + &peer_id_str, + &turn_shared_secret, + 86400, // 24 hours + ); + + // build TURN URIs using the relay's public hostname/IP + let turn_host = std::env::var("DUSK_TURN_PUBLIC_HOST") + .unwrap_or_else(|_| { + std::env::var("DUSK_TURN_PUBLIC_IP") + .unwrap_or_else(|_| "relay.duskchat.app".to_string()) + }); + let turn_port: u16 = std::env::var("DUSK_TURN_UDP_PORT") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(3478); + + let response = TurnCredentialResponse { + username, + password, + ttl: 86400, + uris: vec![ + format!("turn:{}:{}", turn_host, turn_port), + format!("turn:{}:{}?transport=tcp", turn_host, turn_port), + ], + }; + + if swarm + .behaviour_mut() + .turn_credentials + .send_response(channel, response) + .is_err() + { + log::warn!("[TURN-CREDS] failed to send credentials to {:?}", peer); + } + log::info!( + "[TURN-CREDS] generated credentials for peer {} (requested by {:?})", + peer_id_str, + peer + ); + } + // we're the server so we don't send requests, but handle all variants + SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials( + request_response::Event::Message { + message: request_response::Message::Response { .. }, + .. + }, + )) => {} + SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials( + request_response::Event::OutboundFailure { .. }, + )) => {} + SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials( + request_response::Event::InboundFailure { peer, error, .. }, + )) => { + log::warn!( + "[TURN-CREDS] inbound failure from {:?}: {:?}", + peer, + error + ); + } + SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials( + request_response::Event::ResponseSent { .. }, + )) => {} + _ => {} } } diff --git a/src/turn/allocation.rs b/src/turn/allocation.rs new file mode 100644 index 0000000..4cbee1d --- /dev/null +++ b/src/turn/allocation.rs @@ -0,0 +1,1007 @@ +// TURN allocation state machine per RFC 5766 §5 +// +// Manages all active TURN allocations including their permissions and +// channel bindings. Each allocation is identified by a 5-tuple +// (client addr, server addr, protocol) and owns a dedicated relay +// UDP socket. +// +// Key RFC 5766 rules enforced: +// - Channel numbers must be in range 0x4000-0x7FFE +// - A channel can only be bound to one peer address at a time +// - A peer address can only be bound to one channel at a time +// - Creating a channel also creates a permission for that peer's IP +// - Default allocation lifetime is 600s, max is 3600s, lifetime=0 deletes +// - Permissions last 300s, channels last 600s + +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use tokio::net::UdpSocket; +use tokio::sync::RwLock; + +use crate::turn::error::TurnError; + +// --------------------------------------------------------------------------- +// FiveTuple — identifies a client's transport connection +// --------------------------------------------------------------------------- + +/// Identifies a client by their 5-tuple (client addr, server addr, protocol). +/// +/// Per RFC 5766 §2.2, a TURN allocation is uniquely identified by the 5-tuple +/// consisting of the client's IP address and port, the server's IP address and +/// port, and the transport protocol (UDP or TCP). +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub struct FiveTuple { + pub client_addr: SocketAddr, + pub server_addr: SocketAddr, + pub protocol: TransportProtocol, +} + +/// Transport protocol for the client-to-server connection. +#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] +pub enum TransportProtocol { + Udp, + Tcp, +} + +impl std::fmt::Display for TransportProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TransportProtocol::Udp => write!(f, "UDP"), + TransportProtocol::Tcp => write!(f, "TCP"), + } + } +} + +// --------------------------------------------------------------------------- +// Allocation — a single TURN relay allocation +// --------------------------------------------------------------------------- + +/// A single TURN allocation per RFC 5766 §5. +/// +/// Each allocation owns a dedicated UDP relay socket and tracks: +/// - The authenticated user who created it +/// - Current lifetime and expiry +/// - Installed permissions (peer IP → expiry) +/// - Channel bindings (channel number ↔ peer address) +#[derive(Debug)] +pub struct Allocation { + /// The 5-tuple identifying this allocation's client connection. + pub five_tuple: FiveTuple, + /// The relayed transport address (public IP + allocated port). + pub relay_addr: SocketAddr, + /// The UDP socket bound to the relay port for relaying data. + pub relay_socket: Arc, + /// The relay port number (kept separately for returning to the pool). + pub relay_port: u16, + /// The authenticated username that created this allocation. + pub username: String, + /// The authentication realm. + pub realm: String, + /// The current nonce used for this allocation's authentication. + pub nonce: String, + /// The current allocation lifetime. + pub lifetime: Duration, + /// When this allocation expires (monotonic clock). + pub expires_at: Instant, + /// Installed permissions: peer IP address → expiry time. + /// Per RFC 5766 §8, permissions last 300 seconds. + pub permissions: HashMap, + /// Channel bindings: channel number → binding info. + /// Per RFC 5766 §11, channel bindings last 600 seconds. + pub channels: HashMap, + /// Reverse lookup: peer address → channel number. + pub channel_by_peer: HashMap, +} + +/// A channel binding associates a channel number with a peer address. +#[derive(Debug, Clone)] +pub struct ChannelBinding { + /// The peer transport address bound to this channel. + pub peer_addr: SocketAddr, + /// When this channel binding expires (10 minutes per RFC 5766 §11). + pub expires_at: Instant, +} + +// --------------------------------------------------------------------------- +// AllocationConfig — tunable parameters +// --------------------------------------------------------------------------- + +/// Configuration for the [`AllocationManager`]. +/// +/// All durations and limits have sensible RFC-compliant defaults. +#[derive(Debug, Clone)] +pub struct AllocationConfig { + /// Maximum total allocations across all users (default: 1000). + pub max_allocations: usize, + /// Maximum allocations per username (default: 10). + pub max_per_user: usize, + /// Default allocation lifetime in seconds (default: 600s per RFC 5766 §6.2). + pub default_lifetime: Duration, + /// Maximum allowed allocation lifetime (default: 3600s per RFC 5766 §6.2). + pub max_lifetime: Duration, + /// Permission lifetime (default: 300s per RFC 5766 §8). + pub permission_lifetime: Duration, + /// Channel binding lifetime (default: 600s per RFC 5766 §11). + pub channel_lifetime: Duration, + /// TURN realm for authentication (default: "duskchat.app"). + pub realm: String, +} + +impl Default for AllocationConfig { + fn default() -> Self { + Self { + max_allocations: 1000, + max_per_user: 10, + default_lifetime: Duration::from_secs(600), + max_lifetime: Duration::from_secs(3600), + permission_lifetime: Duration::from_secs(300), + channel_lifetime: Duration::from_secs(600), + realm: "duskchat.app".to_string(), + } + } +} + +// --------------------------------------------------------------------------- +// AllocationInfo — read-only snapshot of allocation state +// --------------------------------------------------------------------------- + +/// A read-only snapshot of allocation data returned by +/// [`AllocationManager::get_allocation`]. +/// +/// This avoids holding the read lock while callers process the data. +#[derive(Debug, Clone)] +pub struct AllocationInfo { + pub five_tuple: FiveTuple, + pub relay_addr: SocketAddr, + pub relay_socket: Arc, + pub relay_port: u16, + pub username: String, + pub realm: String, + pub nonce: String, + pub lifetime: Duration, + pub expires_at: Instant, +} + +// --------------------------------------------------------------------------- +// AllocationManager — manages all active allocations +// --------------------------------------------------------------------------- + +/// Manages all active TURN allocations. +/// +/// Thread-safe via internal `RwLock`. All public methods are `async` and +/// acquire the lock as needed. The manager enforces per-user quotas, +/// global allocation limits, and RFC-mandated lifetimes. +pub struct AllocationManager { + allocations: RwLock>, + config: AllocationConfig, +} + +impl AllocationManager { + /// Create a new allocation manager with the given configuration. + pub fn new(config: AllocationConfig) -> Self { + Self { + allocations: RwLock::new(HashMap::new()), + config, + } + } + + /// Create a new TURN allocation. + /// + /// Returns the relay address on success. Fails if: + /// - An allocation already exists for this 5-tuple (437 Allocation Mismatch) + /// - The global allocation limit is reached (486 Allocation Quota Reached) + /// - The per-user allocation limit is reached (486 Allocation Quota Reached) + /// + /// The caller is responsible for allocating the port and binding the socket + /// before calling this method. + pub async fn create_allocation( + &self, + five_tuple: FiveTuple, + username: String, + realm: String, + nonce: String, + lifetime: Duration, + relay_socket: Arc, + relay_port: u16, + relay_addr: SocketAddr, + ) -> Result { + let mut allocs = self.allocations.write().await; + + // Check if allocation already exists for this 5-tuple + if allocs.contains_key(&five_tuple) { + return Err(TurnError::AllocationMismatch); + } + + // Check global allocation limit + if allocs.len() >= self.config.max_allocations { + return Err(TurnError::AllocationQuotaReached); + } + + // Check per-user allocation limit + let user_count = allocs + .values() + .filter(|a| a.username == username) + .count(); + if user_count >= self.config.max_per_user { + return Err(TurnError::AllocationQuotaReached); + } + + // Clamp lifetime to allowed range + let actual_lifetime = lifetime.min(self.config.max_lifetime); + let expires_at = Instant::now() + actual_lifetime; + + let allocation = Allocation { + five_tuple: five_tuple.clone(), + relay_addr, + relay_socket, + relay_port, + username, + realm, + nonce, + lifetime: actual_lifetime, + expires_at, + permissions: HashMap::new(), + channels: HashMap::new(), + channel_by_peer: HashMap::new(), + }; + + allocs.insert(five_tuple, allocation); + + Ok(relay_addr) + } + + /// Refresh an existing allocation's lifetime. + /// + /// Per RFC 5766 §7: + /// - If `lifetime` is zero, the allocation is deleted immediately + /// - Otherwise the lifetime is clamped to `max_lifetime` and the + /// expiry is updated + /// + /// Returns the actual granted lifetime, or an error if no allocation + /// exists for this 5-tuple. + pub async fn refresh_allocation( + &self, + five_tuple: &FiveTuple, + lifetime: Duration, + ) -> Result { + let mut allocs = self.allocations.write().await; + + let alloc = allocs + .get_mut(five_tuple) + .ok_or(TurnError::AllocationMismatch)?; + + if lifetime.is_zero() { + // lifetime=0 means delete the allocation + allocs.remove(five_tuple); + return Ok(Duration::ZERO); + } + + // Clamp lifetime to allowed max + let actual_lifetime = lifetime.min(self.config.max_lifetime); + alloc.lifetime = actual_lifetime; + alloc.expires_at = Instant::now() + actual_lifetime; + + Ok(actual_lifetime) + } + + /// Delete an allocation, returning the relay port number for recycling. + /// + /// Returns `None` if no allocation exists for this 5-tuple. + pub async fn delete_allocation(&self, five_tuple: &FiveTuple) -> Option { + let mut allocs = self.allocations.write().await; + allocs.remove(five_tuple).map(|a| a.relay_port) + } + + /// Get a read-only snapshot of an allocation's state. + /// + /// Returns `None` if no allocation exists for this 5-tuple. + pub async fn get_allocation(&self, five_tuple: &FiveTuple) -> Option { + let allocs = self.allocations.read().await; + allocs.get(five_tuple).map(|a| AllocationInfo { + five_tuple: a.five_tuple.clone(), + relay_addr: a.relay_addr, + relay_socket: Arc::clone(&a.relay_socket), + relay_port: a.relay_port, + username: a.username.clone(), + realm: a.realm.clone(), + nonce: a.nonce.clone(), + lifetime: a.lifetime, + expires_at: a.expires_at, + }) + } + + /// Create or refresh permissions for one or more peer IP addresses. + /// + /// Per RFC 5766 §9, each permission is installed for the peer's IP + /// address (ignoring port) and lasts for 300 seconds. Refreshing + /// an existing permission resets its timer. + pub async fn create_permission( + &self, + five_tuple: &FiveTuple, + peer_addrs: Vec, + ) -> Result<(), TurnError> { + let mut allocs = self.allocations.write().await; + let alloc = allocs + .get_mut(five_tuple) + .ok_or(TurnError::AllocationMismatch)?; + + let expires_at = Instant::now() + self.config.permission_lifetime; + + for addr in peer_addrs { + alloc.permissions.insert(addr, expires_at); + } + + Ok(()) + } + + /// Check if a permission exists for a peer IP address. + /// + /// Returns `true` if a non-expired permission exists for the given + /// peer IP on the specified allocation. + pub async fn has_permission( + &self, + five_tuple: &FiveTuple, + peer_addr: &IpAddr, + ) -> bool { + let allocs = self.allocations.read().await; + if let Some(alloc) = allocs.get(five_tuple) { + if let Some(expires_at) = alloc.permissions.get(peer_addr) { + return Instant::now() < *expires_at; + } + } + false + } + + /// Bind a channel number to a peer address. + /// + /// Per RFC 5766 §11: + /// - Channel numbers must be in range 0x4000-0x7FFE + /// - A channel number can only be bound to one peer address + /// - A peer address can only be bound to one channel number + /// - If the binding already exists with the same pair, it's refreshed + /// - Creating a channel also installs a permission for the peer's IP + pub async fn bind_channel( + &self, + five_tuple: &FiveTuple, + channel_number: u16, + peer_addr: SocketAddr, + ) -> Result<(), TurnError> { + // Validate channel number range (0x4000-0x7FFE) + if !(0x4000..=0x7FFE).contains(&channel_number) { + return Err(TurnError::StunParseError(format!( + "channel number 0x{:04x} out of range 0x4000-0x7FFE", + channel_number + ))); + } + + let mut allocs = self.allocations.write().await; + let alloc = allocs + .get_mut(five_tuple) + .ok_or(TurnError::AllocationMismatch)?; + + // Check if this channel number is already bound to a DIFFERENT peer + if let Some(existing) = alloc.channels.get(&channel_number) { + if existing.peer_addr != peer_addr { + return Err(TurnError::StunParseError( + "channel number already bound to a different peer".into(), + )); + } + // Same binding — this is a refresh, fall through + } + + // Check if this peer address is already bound to a DIFFERENT channel + if let Some(&existing_channel) = alloc.channel_by_peer.get(&peer_addr) { + if existing_channel != channel_number { + return Err(TurnError::StunParseError( + "peer address already bound to a different channel".into(), + )); + } + // Same binding — this is a refresh, fall through + } + + let expires_at = Instant::now() + self.config.channel_lifetime; + + // Install or refresh the channel binding + alloc.channels.insert( + channel_number, + ChannelBinding { + peer_addr, + expires_at, + }, + ); + alloc.channel_by_peer.insert(peer_addr, channel_number); + + // Also install a permission for the peer's IP (RFC 5766 §11.1) + let perm_expires = Instant::now() + self.config.permission_lifetime; + alloc.permissions.insert(peer_addr.ip(), perm_expires); + + Ok(()) + } + + /// Look up the peer address for a channel binding. + /// + /// Returns `None` if the channel is not bound or has expired. + pub async fn get_channel_binding( + &self, + five_tuple: &FiveTuple, + channel_number: u16, + ) -> Option { + let allocs = self.allocations.read().await; + let alloc = allocs.get(five_tuple)?; + let binding = alloc.channels.get(&channel_number)?; + + if Instant::now() >= binding.expires_at { + return None; + } + + Some(binding.peer_addr) + } + + /// Look up the channel number for a peer address (reverse lookup). + /// + /// Returns `None` if no channel is bound to this peer or the binding expired. + pub async fn get_channel_for_peer( + &self, + five_tuple: &FiveTuple, + peer_addr: &SocketAddr, + ) -> Option { + let allocs = self.allocations.read().await; + let alloc = allocs.get(five_tuple)?; + let &channel_number = alloc.channel_by_peer.get(peer_addr)?; + + // Check that the binding hasn't expired + if let Some(binding) = alloc.channels.get(&channel_number) { + if Instant::now() < binding.expires_at { + return Some(channel_number); + } + } + + None + } + + /// Find an allocation by its relay address. + /// + /// This is used when data arrives on a relay socket and needs to be + /// forwarded to the client. Returns the 5-tuple and relay socket. + pub async fn get_allocation_by_relay_addr( + &self, + relay_addr: &SocketAddr, + ) -> Option<(FiveTuple, Arc)> { + let allocs = self.allocations.read().await; + for alloc in allocs.values() { + if alloc.relay_addr == *relay_addr { + return Some(( + alloc.five_tuple.clone(), + Arc::clone(&alloc.relay_socket), + )); + } + } + None + } + + /// Clean up expired allocations, permissions, and channel bindings. + /// + /// Returns a list of relay port numbers that were freed and should + /// be returned to the port pool. + pub async fn cleanup_expired(&self) -> Vec { + let mut allocs = self.allocations.write().await; + let now = Instant::now(); + let mut freed_ports = Vec::new(); + + // Collect expired allocation keys + let expired_keys: Vec = allocs + .iter() + .filter(|(_, a)| now >= a.expires_at) + .map(|(k, _)| k.clone()) + .collect(); + + // Remove expired allocations + for key in expired_keys { + if let Some(alloc) = allocs.remove(&key) { + freed_ports.push(alloc.relay_port); + log::debug!( + "cleaned up expired allocation for {} (port {})", + key.client_addr, + alloc.relay_port, + ); + } + } + + // Clean up expired permissions and channels in remaining allocations + for alloc in allocs.values_mut() { + // Remove expired permissions + alloc.permissions.retain(|_ip, expires| now < *expires); + + // Remove expired channel bindings + let expired_channels: Vec = alloc + .channels + .iter() + .filter(|(_, binding)| now >= binding.expires_at) + .map(|(&num, _)| num) + .collect(); + + for channel_num in expired_channels { + if let Some(binding) = alloc.channels.remove(&channel_num) { + alloc.channel_by_peer.remove(&binding.peer_addr); + } + } + } + + freed_ports + } + + /// Get the total number of active allocations. + pub async fn allocation_count(&self) -> usize { + self.allocations.read().await.len() + } + + /// Get the number of allocations for a specific username. + pub async fn allocations_for_user(&self, username: &str) -> usize { + self.allocations + .read() + .await + .values() + .filter(|a| a.username == username) + .count() + } + + /// Get the allocation config (for use by the handler). + pub fn config(&self) -> &AllocationConfig { + &self.config + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, SocketAddrV4}; + + fn test_five_tuple() -> FiveTuple { + FiveTuple { + client_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345)), + server_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478)), + protocol: TransportProtocol::Udp, + } + } + + fn test_config() -> AllocationConfig { + AllocationConfig { + max_allocations: 10, + max_per_user: 3, + ..Default::default() + } + } + + async fn create_test_socket() -> (Arc, u16) { + let socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let port = socket.local_addr().unwrap().port(); + (Arc::new(socket), port) + } + + #[tokio::test] + async fn test_create_allocation() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + let result = mgr + .create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce123".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), relay_addr); + assert_eq!(mgr.allocation_count().await, 1); + } + + #[tokio::test] + async fn test_duplicate_allocation_rejected() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket1, port1) = create_test_socket().await; + let relay_addr1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port1)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket1, + port1, + relay_addr1, + ) + .await + .unwrap(); + + let (socket2, port2) = create_test_socket().await; + let relay_addr2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port2)); + + let result = mgr + .create_allocation( + ft, + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket2, + port2, + relay_addr2, + ) + .await; + + assert!(matches!(result, Err(TurnError::AllocationMismatch))); + } + + #[tokio::test] + async fn test_per_user_quota() { + let mgr = AllocationManager::new(test_config()); // max_per_user = 3 + + for i in 0..3 { + let ft = FiveTuple { + client_addr: SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(192, 168, 1, i as u8), + 12345, + )), + server_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478)), + protocol: TransportProtocol::Udp, + }; + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft, + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + } + + // 4th allocation for same user should fail + let ft = FiveTuple { + client_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 10), 12345)), + server_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478)), + protocol: TransportProtocol::Udp, + }; + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + let result = mgr + .create_allocation( + ft, + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await; + + assert!(matches!(result, Err(TurnError::AllocationQuotaReached))); + } + + #[tokio::test] + async fn test_refresh_allocation() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let result = mgr + .refresh_allocation(&ft, Duration::from_secs(1200)) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Duration::from_secs(1200)); + } + + #[tokio::test] + async fn test_refresh_lifetime_clamped() { + let mgr = AllocationManager::new(test_config()); // max_lifetime = 3600s + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let result = mgr + .refresh_allocation(&ft, Duration::from_secs(99999)) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Duration::from_secs(3600)); + } + + #[tokio::test] + async fn test_refresh_zero_deletes() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let result = mgr.refresh_allocation(&ft, Duration::ZERO).await; + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Duration::ZERO); + assert_eq!(mgr.allocation_count().await, 0); + } + + #[tokio::test] + async fn test_permissions() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let peer_ip: IpAddr = "10.0.0.2".parse().unwrap(); + + // No permission initially + assert!(!mgr.has_permission(&ft, &peer_ip).await); + + // Install permission + mgr.create_permission(&ft, vec![peer_ip]).await.unwrap(); + + // Now it should exist + assert!(mgr.has_permission(&ft, &peer_ip).await); + } + + #[tokio::test] + async fn test_channel_binding() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 9999)); + + // Bind channel + mgr.bind_channel(&ft, 0x4000, peer_addr).await.unwrap(); + + // Look up by channel number + let result = mgr.get_channel_binding(&ft, 0x4000).await; + assert_eq!(result, Some(peer_addr)); + + // Look up by peer address (reverse) + let result = mgr.get_channel_for_peer(&ft, &peer_addr).await; + assert_eq!(result, Some(0x4000)); + + // Channel binding should also install permission + assert!(mgr.has_permission(&ft, &peer_addr.ip()).await); + } + + #[tokio::test] + async fn test_channel_number_validation() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let peer_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 9999)); + + // Too low + assert!(mgr.bind_channel(&ft, 0x3FFF, peer_addr).await.is_err()); + // Too high + assert!(mgr.bind_channel(&ft, 0x7FFF, peer_addr).await.is_err()); + // Valid range boundary + assert!(mgr.bind_channel(&ft, 0x4000, peer_addr).await.is_ok()); + } + + #[tokio::test] + async fn test_channel_conflict_different_peer() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let peer1 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 9999)); + let peer2 = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 3), 9999)); + + // Bind channel 0x4000 to peer1 + mgr.bind_channel(&ft, 0x4000, peer1).await.unwrap(); + + // Try to bind same channel to different peer → error + assert!(mgr.bind_channel(&ft, 0x4000, peer2).await.is_err()); + + // Try to bind different channel to peer1 → error (peer already bound) + assert!(mgr.bind_channel(&ft, 0x4001, peer1).await.is_err()); + } + + #[tokio::test] + async fn test_delete_allocation() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let freed_port = mgr.delete_allocation(&ft).await; + assert_eq!(freed_port, Some(port)); + assert_eq!(mgr.allocation_count().await, 0); + } + + #[tokio::test] + async fn test_get_allocation_by_relay_addr() { + let mgr = AllocationManager::new(test_config()); + let ft = test_five_tuple(); + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft.clone(), + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + + let result = mgr.get_allocation_by_relay_addr(&relay_addr).await; + assert!(result.is_some()); + let (found_ft, _) = result.unwrap(); + assert_eq!(found_ft, ft); + + // Non-existent relay addr + let bad_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(9, 9, 9, 9), 1234)); + assert!(mgr.get_allocation_by_relay_addr(&bad_addr).await.is_none()); + } + + #[tokio::test] + async fn test_allocations_for_user() { + let mgr = AllocationManager::new(test_config()); + + for i in 0..2 { + let ft = FiveTuple { + client_addr: SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::new(192, 168, 1, i as u8), + 12345, + )), + server_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478)), + protocol: TransportProtocol::Udp, + }; + let (socket, port) = create_test_socket().await; + let relay_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(1, 2, 3, 4), port)); + + mgr.create_allocation( + ft, + "alice".into(), + "duskchat.app".into(), + "nonce".into(), + Duration::from_secs(600), + socket, + port, + relay_addr, + ) + .await + .unwrap(); + } + + assert_eq!(mgr.allocations_for_user("alice").await, 2); + assert_eq!(mgr.allocations_for_user("bob").await, 0); + } +} diff --git a/src/turn/attributes.rs b/src/turn/attributes.rs new file mode 100644 index 0000000..c14df05 --- /dev/null +++ b/src/turn/attributes.rs @@ -0,0 +1,799 @@ +// STUN/TURN attribute types and TLV encoding/decoding +// +// Implements attributes per RFC 5389 (STUN) and RFC 5766 (TURN). +// Each attribute is a Type-Length-Value (TLV) structure: +// Type: 2 bytes (attribute type code) +// Length: 2 bytes (value length, excluding padding) +// Value: variable (padded to 4-byte boundary with zeros) +// +// XOR-encoded address attributes (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS, +// XOR-RELAYED-ADDRESS) use the STUN magic cookie and transaction ID +// to XOR the address bytes, preventing NAT ALGs from rewriting them. + +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use crate::turn::error::TurnError; +use crate::turn::stun::MAGIC_COOKIE; + +// --------------------------------------------------------------------------- +// Attribute type codes +// --------------------------------------------------------------------------- + +/// RFC 5389 STUN attribute type codes +pub const ATTR_MAPPED_ADDRESS: u16 = 0x0001; +pub const ATTR_USERNAME: u16 = 0x0006; +pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008; +pub const ATTR_ERROR_CODE: u16 = 0x0009; +pub const ATTR_UNKNOWN_ATTRIBUTES: u16 = 0x000A; +pub const ATTR_REALM: u16 = 0x0014; +pub const ATTR_NONCE: u16 = 0x0015; +pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020; +pub const ATTR_SOFTWARE: u16 = 0x8022; +pub const ATTR_FINGERPRINT: u16 = 0x8028; + +/// RFC 5766 TURN attribute type codes +pub const ATTR_CHANNEL_NUMBER: u16 = 0x000C; +pub const ATTR_LIFETIME: u16 = 0x000D; +pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012; +pub const ATTR_DATA: u16 = 0x0013; +pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016; +pub const ATTR_REQUESTED_ADDRESS_FAMILY: u16 = 0x0017; +pub const ATTR_EVEN_PORT: u16 = 0x0018; +pub const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019; +pub const ATTR_DONT_FRAGMENT: u16 = 0x001A; + +// --------------------------------------------------------------------------- +// Address family constants +// --------------------------------------------------------------------------- + +const ADDR_FAMILY_IPV4: u8 = 0x01; +const ADDR_FAMILY_IPV6: u8 = 0x02; + +// --------------------------------------------------------------------------- +// StunAttribute enum +// --------------------------------------------------------------------------- + +/// A parsed STUN or TURN attribute. +/// +/// Each variant corresponds to a specific attribute type. Unknown attributes +/// are preserved as raw bytes for forwarding or diagnostic purposes. +#[derive(Debug, Clone)] +pub enum StunAttribute { + // ---- RFC 5389 STUN attributes ---- + + /// MAPPED-ADDRESS (0x0001): The reflexive transport address of the client + /// as seen by the server. Uses plain (non-XOR) encoding. + MappedAddress(SocketAddr), + + /// USERNAME (0x0006): The username for message integrity, encoded as UTF-8. + Username(String), + + /// MESSAGE-INTEGRITY (0x0008): 20-byte HMAC-SHA1 over the STUN message + /// (up to but not including this attribute). + MessageIntegrity([u8; 20]), + + /// ERROR-CODE (0x0009): Error response code and human-readable reason phrase. + /// Code is in range 300-699 per RFC 5389 §15.6. + ErrorCode { code: u16, reason: String }, + + /// UNKNOWN-ATTRIBUTES (0x000A): List of attribute types that the server + /// did not understand. Used in 420 Unknown Attribute error responses. + UnknownAttributes(Vec), + + /// REALM (0x0014): The authentication realm, encoded as UTF-8. + /// Used with long-term credential mechanism. + Realm(String), + + /// NONCE (0x0015): A server-generated nonce for replay protection. + Nonce(String), + + /// XOR-MAPPED-ADDRESS (0x0020): Same as MAPPED-ADDRESS but XOR-encoded + /// with the magic cookie (and transaction ID for IPv6) to prevent + /// NAT ALG interference. + XorMappedAddress(SocketAddr), + + /// SOFTWARE (0x8022): Textual description of the software being used. + /// Informational only. + Software(String), + + /// FINGERPRINT (0x8028): CRC32 of the STUN message XORed with 0x5354554e. + /// Used to demultiplex STUN from other protocols on the same port. + Fingerprint(u32), + + // ---- RFC 5766 TURN attributes ---- + + /// CHANNEL-NUMBER (0x000C): Channel number for ChannelBind. + /// Must be in range 0x4000-0x7FFF. + ChannelNumber(u16), + + /// LIFETIME (0x000D): Requested or granted allocation lifetime in seconds. + Lifetime(u32), + + /// XOR-PEER-ADDRESS (0x0012): The peer address for Send/Data indications, + /// CreatePermission, and ChannelBind. XOR-encoded like XOR-MAPPED-ADDRESS. + XorPeerAddress(SocketAddr), + + /// DATA (0x0013): The application data payload in Send/Data indications. + Data(Vec), + + /// XOR-RELAYED-ADDRESS (0x0016): The relayed transport address allocated + /// by the server. XOR-encoded. + XorRelayedAddress(SocketAddr), + + /// EVEN-PORT (0x0018): Requests an even port number for the relay address. + /// The boolean indicates the R bit (reserve next-higher port). + EvenPort(bool), + + /// REQUESTED-TRANSPORT (0x0019): The transport protocol for the relay. + /// Value is an IANA protocol number (17 = UDP, 6 = TCP). + RequestedTransport(u8), + + /// DONT-FRAGMENT (0x001A): Requests that the server set the DF bit + /// in outgoing UDP packets. No value (zero-length attribute). + DontFragment, + + /// REQUESTED-ADDRESS-FAMILY (0x0017): Requests a specific address family + /// for the relayed address (0x01 = IPv4, 0x02 = IPv6). + RequestedAddressFamily(u8), + + /// Unknown/unsupported attribute preserved as raw bytes. + Unknown { attr_type: u16, value: Vec }, +} + +// --------------------------------------------------------------------------- +// Decoding +// --------------------------------------------------------------------------- + +/// Decode a single attribute from its type code and raw value bytes. +/// +/// The `transaction_id` is needed for XOR address decoding. +pub fn decode_attribute( + attr_type: u16, + value: &[u8], + transaction_id: &[u8; 12], +) -> Result { + match attr_type { + ATTR_MAPPED_ADDRESS => decode_mapped_address(value), + ATTR_USERNAME => decode_utf8_string(value).map(StunAttribute::Username), + ATTR_MESSAGE_INTEGRITY => decode_message_integrity(value), + ATTR_ERROR_CODE => decode_error_code(value), + ATTR_UNKNOWN_ATTRIBUTES => decode_unknown_attributes(value), + ATTR_REALM => decode_utf8_string(value).map(StunAttribute::Realm), + ATTR_NONCE => decode_utf8_string(value).map(StunAttribute::Nonce), + ATTR_XOR_MAPPED_ADDRESS => { + decode_xor_address(value, transaction_id).map(StunAttribute::XorMappedAddress) + } + ATTR_SOFTWARE => decode_utf8_string(value).map(StunAttribute::Software), + ATTR_FINGERPRINT => decode_fingerprint(value), + ATTR_CHANNEL_NUMBER => decode_channel_number(value), + ATTR_LIFETIME => decode_lifetime(value), + ATTR_XOR_PEER_ADDRESS => { + decode_xor_address(value, transaction_id).map(StunAttribute::XorPeerAddress) + } + ATTR_DATA => Ok(StunAttribute::Data(value.to_vec())), + ATTR_XOR_RELAYED_ADDRESS => { + decode_xor_address(value, transaction_id).map(StunAttribute::XorRelayedAddress) + } + ATTR_REQUESTED_ADDRESS_FAMILY => decode_requested_address_family(value), + ATTR_EVEN_PORT => decode_even_port(value), + ATTR_REQUESTED_TRANSPORT => decode_requested_transport(value), + ATTR_DONT_FRAGMENT => Ok(StunAttribute::DontFragment), + _ => Ok(StunAttribute::Unknown { + attr_type, + value: value.to_vec(), + }), + } +} + +// --------------------------------------------------------------------------- +// Encoding +// --------------------------------------------------------------------------- + +/// Encode a single attribute into TLV wire format with 4-byte padding. +/// +/// Returns the complete TLV bytes: type (2) + length (2) + value + padding. +pub fn encode_attribute(attr: &StunAttribute, transaction_id: &[u8; 12]) -> Vec { + let (attr_type, value) = encode_attribute_value(attr, transaction_id); + let value_len = value.len(); + let padded_len = (value_len + 3) & !3; + + let mut buf = Vec::with_capacity(4 + padded_len); + buf.extend_from_slice(&attr_type.to_be_bytes()); + buf.extend_from_slice(&(value_len as u16).to_be_bytes()); + buf.extend_from_slice(&value); + + // Pad to 4-byte boundary + let padding = padded_len - value_len; + for _ in 0..padding { + buf.push(0); + } + + buf +} + +/// Encode an attribute's value bytes (without the TLV header or padding). +/// Returns (attribute_type_code, value_bytes). +fn encode_attribute_value(attr: &StunAttribute, transaction_id: &[u8; 12]) -> (u16, Vec) { + match attr { + StunAttribute::MappedAddress(addr) => { + (ATTR_MAPPED_ADDRESS, encode_plain_address(addr)) + } + StunAttribute::Username(s) => (ATTR_USERNAME, s.as_bytes().to_vec()), + StunAttribute::MessageIntegrity(hmac) => (ATTR_MESSAGE_INTEGRITY, hmac.to_vec()), + StunAttribute::ErrorCode { code, reason } => { + (ATTR_ERROR_CODE, encode_error_code(*code, reason)) + } + StunAttribute::UnknownAttributes(types) => { + let mut buf = Vec::with_capacity(types.len() * 2); + for &t in types { + buf.extend_from_slice(&t.to_be_bytes()); + } + (ATTR_UNKNOWN_ATTRIBUTES, buf) + } + StunAttribute::Realm(s) => (ATTR_REALM, s.as_bytes().to_vec()), + StunAttribute::Nonce(s) => (ATTR_NONCE, s.as_bytes().to_vec()), + StunAttribute::XorMappedAddress(addr) => { + (ATTR_XOR_MAPPED_ADDRESS, encode_xor_address(addr, transaction_id)) + } + StunAttribute::Software(s) => (ATTR_SOFTWARE, s.as_bytes().to_vec()), + StunAttribute::Fingerprint(val) => (ATTR_FINGERPRINT, val.to_be_bytes().to_vec()), + StunAttribute::ChannelNumber(num) => { + // Channel number is 16 bits followed by 16 bits of RFFU (reserved) + let mut buf = vec![0u8; 4]; + buf[0..2].copy_from_slice(&num.to_be_bytes()); + (ATTR_CHANNEL_NUMBER, buf) + } + StunAttribute::Lifetime(secs) => (ATTR_LIFETIME, secs.to_be_bytes().to_vec()), + StunAttribute::XorPeerAddress(addr) => { + (ATTR_XOR_PEER_ADDRESS, encode_xor_address(addr, transaction_id)) + } + StunAttribute::Data(data) => (ATTR_DATA, data.clone()), + StunAttribute::XorRelayedAddress(addr) => { + (ATTR_XOR_RELAYED_ADDRESS, encode_xor_address(addr, transaction_id)) + } + StunAttribute::RequestedAddressFamily(family) => { + let mut buf = vec![0u8; 4]; + buf[0] = *family; + (ATTR_REQUESTED_ADDRESS_FAMILY, buf) + } + StunAttribute::EvenPort(reserve) => { + let byte = if *reserve { 0x80 } else { 0x00 }; + (ATTR_EVEN_PORT, vec![byte]) + } + StunAttribute::RequestedTransport(proto) => { + // Protocol number in first byte, followed by 3 bytes RFFU + let mut buf = vec![0u8; 4]; + buf[0] = *proto; + (ATTR_REQUESTED_TRANSPORT, buf) + } + StunAttribute::DontFragment => (ATTR_DONT_FRAGMENT, vec![]), + StunAttribute::Unknown { attr_type, value } => (*attr_type, value.clone()), + } +} + +// --------------------------------------------------------------------------- +// Address encoding/decoding helpers +// --------------------------------------------------------------------------- + +/// Decode a plain (non-XOR) MAPPED-ADDRESS attribute value. +/// +/// Format: 1 byte reserved, 1 byte family, 2 bytes port, 4/16 bytes address +fn decode_mapped_address(value: &[u8]) -> Result { + if value.len() < 4 { + return Err(TurnError::StunParseError( + "MAPPED-ADDRESS too short".into(), + )); + } + + let family = value[1]; + let port = u16::from_be_bytes([value[2], value[3]]); + + let addr = match family { + ADDR_FAMILY_IPV4 => { + if value.len() < 8 { + return Err(TurnError::StunParseError( + "MAPPED-ADDRESS IPv4 too short".into(), + )); + } + let ip = Ipv4Addr::new(value[4], value[5], value[6], value[7]); + SocketAddr::new(IpAddr::V4(ip), port) + } + ADDR_FAMILY_IPV6 => { + if value.len() < 20 { + return Err(TurnError::StunParseError( + "MAPPED-ADDRESS IPv6 too short".into(), + )); + } + let mut octets = [0u8; 16]; + octets.copy_from_slice(&value[4..20]); + let ip = Ipv6Addr::from(octets); + SocketAddr::new(IpAddr::V6(ip), port) + } + _ => { + return Err(TurnError::StunParseError(format!( + "unknown address family: 0x{:02x}", + family + ))); + } + }; + + Ok(StunAttribute::MappedAddress(addr)) +} + +/// Encode a plain (non-XOR) address into wire format. +fn encode_plain_address(addr: &SocketAddr) -> Vec { + match addr { + SocketAddr::V4(v4) => { + let mut buf = vec![0u8; 8]; + buf[0] = 0; // reserved + buf[1] = ADDR_FAMILY_IPV4; + buf[2..4].copy_from_slice(&v4.port().to_be_bytes()); + buf[4..8].copy_from_slice(&v4.ip().octets()); + buf + } + SocketAddr::V6(v6) => { + let mut buf = vec![0u8; 20]; + buf[0] = 0; // reserved + buf[1] = ADDR_FAMILY_IPV6; + buf[2..4].copy_from_slice(&v6.port().to_be_bytes()); + buf[4..20].copy_from_slice(&v6.ip().octets()); + buf + } + } +} + +/// Decode an XOR-encoded address (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS, +/// XOR-RELAYED-ADDRESS) per RFC 5389 §15.2. +/// +/// For IPv4: port is XORed with top 16 bits of magic cookie; +/// address is XORed with magic cookie. +/// For IPv6: port is XORed with top 16 bits of magic cookie; +/// address is XORed with magic cookie || transaction ID (16 bytes). +fn decode_xor_address( + value: &[u8], + transaction_id: &[u8; 12], +) -> Result { + if value.len() < 4 { + return Err(TurnError::StunParseError( + "XOR address attribute too short".into(), + )); + } + + let family = value[1]; + let x_port = u16::from_be_bytes([value[2], value[3]]); + let cookie_bytes = MAGIC_COOKIE.to_be_bytes(); + let port = x_port ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]); + + match family { + ADDR_FAMILY_IPV4 => { + if value.len() < 8 { + return Err(TurnError::StunParseError( + "XOR-MAPPED-ADDRESS IPv4 too short".into(), + )); + } + let x_addr = u32::from_be_bytes([value[4], value[5], value[6], value[7]]); + let addr = x_addr ^ MAGIC_COOKIE; + let ip = Ipv4Addr::from(addr); + Ok(SocketAddr::new(IpAddr::V4(ip), port)) + } + ADDR_FAMILY_IPV6 => { + if value.len() < 20 { + return Err(TurnError::StunParseError( + "XOR-MAPPED-ADDRESS IPv6 too short".into(), + )); + } + // Build the 16-byte XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes) + let mut xor_mask = [0u8; 16]; + xor_mask[0..4].copy_from_slice(&cookie_bytes); + xor_mask[4..16].copy_from_slice(transaction_id); + + let mut addr_bytes = [0u8; 16]; + for i in 0..16 { + addr_bytes[i] = value[4 + i] ^ xor_mask[i]; + } + let ip = Ipv6Addr::from(addr_bytes); + Ok(SocketAddr::new(IpAddr::V6(ip), port)) + } + _ => Err(TurnError::StunParseError(format!( + "unknown address family in XOR address: 0x{:02x}", + family + ))), + } +} + +/// Encode an address using XOR encoding per RFC 5389 §15.2. +fn encode_xor_address(addr: &SocketAddr, transaction_id: &[u8; 12]) -> Vec { + let cookie_bytes = MAGIC_COOKIE.to_be_bytes(); + let x_port = addr.port() ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]); + + match addr { + SocketAddr::V4(v4) => { + let mut buf = vec![0u8; 8]; + buf[0] = 0; // reserved + buf[1] = ADDR_FAMILY_IPV4; + buf[2..4].copy_from_slice(&x_port.to_be_bytes()); + let x_addr = u32::from_be_bytes(v4.ip().octets()) ^ MAGIC_COOKIE; + buf[4..8].copy_from_slice(&x_addr.to_be_bytes()); + buf + } + SocketAddr::V6(v6) => { + let mut buf = vec![0u8; 20]; + buf[0] = 0; // reserved + buf[1] = ADDR_FAMILY_IPV6; + buf[2..4].copy_from_slice(&x_port.to_be_bytes()); + + // XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes) + let mut xor_mask = [0u8; 16]; + xor_mask[0..4].copy_from_slice(&cookie_bytes); + xor_mask[4..16].copy_from_slice(transaction_id); + + let octets = v6.ip().octets(); + for i in 0..16 { + buf[4 + i] = octets[i] ^ xor_mask[i]; + } + buf + } + } +} + +// --------------------------------------------------------------------------- +// Specific attribute decoders +// --------------------------------------------------------------------------- + +/// Decode UTF-8 string from raw bytes. +fn decode_utf8_string(value: &[u8]) -> Result { + String::from_utf8(value.to_vec()).map_err(|e| { + TurnError::StunParseError(format!("invalid UTF-8 in attribute: {}", e)) + }) +} + +/// Decode MESSAGE-INTEGRITY (exactly 20 bytes HMAC-SHA1). +fn decode_message_integrity(value: &[u8]) -> Result { + if value.len() != 20 { + return Err(TurnError::StunParseError(format!( + "MESSAGE-INTEGRITY must be 20 bytes, got {}", + value.len() + ))); + } + let mut hmac = [0u8; 20]; + hmac.copy_from_slice(value); + Ok(StunAttribute::MessageIntegrity(hmac)) +} + +/// Decode ERROR-CODE attribute per RFC 5389 §15.6. +/// +/// Format: 2 bytes reserved, 1 byte with class (hundreds digit) in bits 0-2, +/// 1 byte with number (tens+units), followed by UTF-8 reason phrase. +fn decode_error_code(value: &[u8]) -> Result { + if value.len() < 4 { + return Err(TurnError::StunParseError( + "ERROR-CODE too short".into(), + )); + } + + let class = (value[2] & 0x07) as u16; + let number = value[3] as u16; + let code = class * 100 + number; + + let reason = if value.len() > 4 { + String::from_utf8_lossy(&value[4..]).into_owned() + } else { + String::new() + }; + + Ok(StunAttribute::ErrorCode { code, reason }) +} + +/// Encode ERROR-CODE value bytes per RFC 5389 §15.6. +fn encode_error_code(code: u16, reason: &str) -> Vec { + let class = (code / 100) as u8; + let number = (code % 100) as u8; + + let reason_bytes = reason.as_bytes(); + let mut buf = Vec::with_capacity(4 + reason_bytes.len()); + buf.push(0); // reserved + buf.push(0); // reserved + buf.push(class & 0x07); + buf.push(number); + buf.extend_from_slice(reason_bytes); + buf +} + +/// Decode UNKNOWN-ATTRIBUTES (list of 16-bit attribute types). +fn decode_unknown_attributes(value: &[u8]) -> Result { + if value.len() % 2 != 0 { + return Err(TurnError::StunParseError( + "UNKNOWN-ATTRIBUTES length must be even".into(), + )); + } + let mut types = Vec::with_capacity(value.len() / 2); + for chunk in value.chunks_exact(2) { + types.push(u16::from_be_bytes([chunk[0], chunk[1]])); + } + Ok(StunAttribute::UnknownAttributes(types)) +} + +/// Decode FINGERPRINT (4-byte CRC32 XOR value). +fn decode_fingerprint(value: &[u8]) -> Result { + if value.len() != 4 { + return Err(TurnError::StunParseError(format!( + "FINGERPRINT must be 4 bytes, got {}", + value.len() + ))); + } + let val = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); + Ok(StunAttribute::Fingerprint(val)) +} + +/// Decode CHANNEL-NUMBER (16-bit number + 16-bit RFFU). +fn decode_channel_number(value: &[u8]) -> Result { + if value.len() < 4 { + return Err(TurnError::StunParseError( + "CHANNEL-NUMBER too short".into(), + )); + } + let num = u16::from_be_bytes([value[0], value[1]]); + Ok(StunAttribute::ChannelNumber(num)) +} + +/// Decode LIFETIME (32-bit seconds). +fn decode_lifetime(value: &[u8]) -> Result { + if value.len() < 4 { + return Err(TurnError::StunParseError("LIFETIME too short".into())); + } + let secs = u32::from_be_bytes([value[0], value[1], value[2], value[3]]); + Ok(StunAttribute::Lifetime(secs)) +} + +/// Decode REQUESTED-ADDRESS-FAMILY (1 byte family + 3 bytes RFFU). +fn decode_requested_address_family(value: &[u8]) -> Result { + if value.is_empty() { + return Err(TurnError::StunParseError( + "REQUESTED-ADDRESS-FAMILY empty".into(), + )); + } + Ok(StunAttribute::RequestedAddressFamily(value[0])) +} + +/// Decode EVEN-PORT (1 byte, R bit in most significant bit). +fn decode_even_port(value: &[u8]) -> Result { + if value.is_empty() { + return Err(TurnError::StunParseError("EVEN-PORT empty".into())); + } + let reserve = (value[0] & 0x80) != 0; + Ok(StunAttribute::EvenPort(reserve)) +} + +/// Decode REQUESTED-TRANSPORT (1 byte protocol + 3 bytes RFFU). +fn decode_requested_transport(value: &[u8]) -> Result { + if value.is_empty() { + return Err(TurnError::StunParseError( + "REQUESTED-TRANSPORT empty".into(), + )); + } + Ok(StunAttribute::RequestedTransport(value[0])) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6}; + + #[test] + fn test_xor_mapped_address_ipv4_roundtrip() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345)); + let txn_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]; + + let encoded = encode_xor_address(&addr, &txn_id); + let decoded = decode_xor_address(&encoded, &txn_id).unwrap(); + assert_eq!(decoded, addr); + } + + #[test] + fn test_xor_mapped_address_ipv6_roundtrip() { + let ip = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334); + let addr = SocketAddr::V6(SocketAddrV6::new(ip, 54321, 0, 0)); + let txn_id = [0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6]; + + let encoded = encode_xor_address(&addr, &txn_id); + let decoded = decode_xor_address(&encoded, &txn_id).unwrap(); + assert_eq!(decoded, addr); + } + + #[test] + fn test_plain_address_ipv4_roundtrip() { + let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080)); + let encoded = encode_plain_address(&addr); + let decoded = decode_mapped_address(&encoded).unwrap(); + if let StunAttribute::MappedAddress(decoded_addr) = decoded { + assert_eq!(decoded_addr, addr); + } else { + panic!("expected MappedAddress"); + } + } + + #[test] + fn test_error_code_roundtrip() { + let encoded = encode_error_code(401, "Unauthorized"); + let decoded = decode_error_code(&encoded).unwrap(); + if let StunAttribute::ErrorCode { code, reason } = decoded { + assert_eq!(code, 401); + assert_eq!(reason, "Unauthorized"); + } else { + panic!("expected ErrorCode"); + } + } + + #[test] + fn test_error_code_438() { + let encoded = encode_error_code(438, "Stale Nonce"); + let decoded = decode_error_code(&encoded).unwrap(); + if let StunAttribute::ErrorCode { code, reason } = decoded { + assert_eq!(code, 438); + assert_eq!(reason, "Stale Nonce"); + } else { + panic!("expected ErrorCode"); + } + } + + #[test] + fn test_attribute_tlv_encoding() { + let txn_id = [0u8; 12]; + let attr = StunAttribute::Username("alice".into()); + let encoded = encode_attribute(&attr, &txn_id); + + // Type (2) + Length (2) + "alice" (5) + padding (3) = 12 + assert_eq!(encoded.len(), 12); + assert_eq!(encoded[0..2], ATTR_USERNAME.to_be_bytes()); + assert_eq!(encoded[2..4], 5u16.to_be_bytes()); + assert_eq!(&encoded[4..9], b"alice"); + assert_eq!(encoded[9], 0); // padding + assert_eq!(encoded[10], 0); + assert_eq!(encoded[11], 0); + } + + #[test] + fn test_attribute_tlv_encoding_4_byte_aligned() { + let txn_id = [0u8; 12]; + let attr = StunAttribute::Username("test".into()); + let encoded = encode_attribute(&attr, &txn_id); + + // Type (2) + Length (2) + "test" (4) = 8, already aligned + assert_eq!(encoded.len(), 8); + } + + #[test] + fn test_lifetime_roundtrip() { + let attr = StunAttribute::Lifetime(600); + let txn_id = [0u8; 12]; + let encoded = encode_attribute(&attr, &txn_id); + + // Type (2) + Length (2) + value (4) = 8 + assert_eq!(encoded.len(), 8); + + let decoded = decode_attribute(ATTR_LIFETIME, &encoded[4..8], &txn_id).unwrap(); + if let StunAttribute::Lifetime(secs) = decoded { + assert_eq!(secs, 600); + } else { + panic!("expected Lifetime"); + } + } + + #[test] + fn test_channel_number_roundtrip() { + let attr = StunAttribute::ChannelNumber(0x4000); + let txn_id = [0u8; 12]; + let encoded = encode_attribute(&attr, &txn_id); + + // Type (2) + Length (2) + value (4) = 8 + assert_eq!(encoded.len(), 8); + + let decoded = decode_attribute(ATTR_CHANNEL_NUMBER, &encoded[4..8], &txn_id).unwrap(); + if let StunAttribute::ChannelNumber(num) = decoded { + assert_eq!(num, 0x4000); + } else { + panic!("expected ChannelNumber"); + } + } + + #[test] + fn test_requested_transport_udp() { + let attr = StunAttribute::RequestedTransport(17); // UDP + let txn_id = [0u8; 12]; + let encoded = encode_attribute(&attr, &txn_id); + + let decoded = decode_attribute(ATTR_REQUESTED_TRANSPORT, &encoded[4..8], &txn_id).unwrap(); + if let StunAttribute::RequestedTransport(proto) = decoded { + assert_eq!(proto, 17); + } else { + panic!("expected RequestedTransport"); + } + } + + #[test] + fn test_dont_fragment() { + let attr = StunAttribute::DontFragment; + let txn_id = [0u8; 12]; + let encoded = encode_attribute(&attr, &txn_id); + + // Type (2) + Length (2) + no value = 4 + assert_eq!(encoded.len(), 4); + assert_eq!(encoded[2..4], 0u16.to_be_bytes()); // length = 0 + } + + #[test] + fn test_unknown_attribute_preserved() { + let txn_id = [0u8; 12]; + let decoded = + decode_attribute(0xFFFF, &[0x01, 0x02, 0x03], &txn_id).unwrap(); + if let StunAttribute::Unknown { attr_type, value } = decoded { + assert_eq!(attr_type, 0xFFFF); + assert_eq!(value, vec![0x01, 0x02, 0x03]); + } else { + panic!("expected Unknown"); + } + } + + #[test] + fn test_message_integrity_decode() { + let hmac = [0xAA; 20]; + let decoded = decode_message_integrity(&hmac).unwrap(); + if let StunAttribute::MessageIntegrity(h) = decoded { + assert_eq!(h, [0xAA; 20]); + } else { + panic!("expected MessageIntegrity"); + } + } + + #[test] + fn test_message_integrity_wrong_length() { + let result = decode_message_integrity(&[0; 10]); + assert!(result.is_err()); + } + + #[test] + fn test_fingerprint_roundtrip() { + let val = 0xDEADBEEF_u32; + let encoded_value = val.to_be_bytes().to_vec(); + let decoded = decode_fingerprint(&encoded_value).unwrap(); + if let StunAttribute::Fingerprint(v) = decoded { + assert_eq!(v, val); + } else { + panic!("expected Fingerprint"); + } + } + + #[test] + fn test_even_port_reserve_bit() { + // R bit set + let decoded = decode_even_port(&[0x80]).unwrap(); + if let StunAttribute::EvenPort(reserve) = decoded { + assert!(reserve); + } else { + panic!("expected EvenPort"); + } + + // R bit not set + let decoded = decode_even_port(&[0x00]).unwrap(); + if let StunAttribute::EvenPort(reserve) = decoded { + assert!(!reserve); + } else { + panic!("expected EvenPort"); + } + } + + #[test] + fn test_unknown_attributes_list() { + let value = vec![0x00, 0x01, 0x00, 0x20]; // MAPPED-ADDRESS, XOR-MAPPED-ADDRESS + let decoded = decode_unknown_attributes(&value).unwrap(); + if let StunAttribute::UnknownAttributes(types) = decoded { + assert_eq!(types, vec![0x0001, 0x0020]); + } else { + panic!("expected UnknownAttributes"); + } + } +} diff --git a/src/turn/credentials.rs b/src/turn/credentials.rs new file mode 100644 index 0000000..fc1d42d --- /dev/null +++ b/src/turn/credentials.rs @@ -0,0 +1,758 @@ +// TURN credential generation and validation +// +// Implements the HMAC-SHA1 time-limited credential mechanism per +// draft-uberti-behave-turn-rest-00, which is the standard used by +// all major WebRTC implementations (Chrome, Firefox, Safari). +// +// Credential format: +// Username: "{expiry_unix_timestamp}:{peer_id}" +// Password: Base64(HMAC-SHA1(shared_secret, username)) +// +// Also implements MESSAGE-INTEGRITY computation per RFC 5389 §15.4 +// using long-term credentials (key = MD5(username:realm:password)). +// +// Nonce generation uses HMAC-SHA1 with an embedded timestamp for +// stateless stale-nonce detection. + +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::turn::error::TurnError; + +// --------------------------------------------------------------------------- +// HMAC-SHA1 (RFC 2104) +// --------------------------------------------------------------------------- + +/// Compute HMAC-SHA1(key, message) returning a 20-byte digest. +/// +/// This is a self-contained implementation of HMAC-SHA1 that doesn't +/// depend on external crates. In production, replace with `hmac` + `sha1` +/// crates for better performance and constant-time comparison. +pub fn hmac_sha1(key: &[u8], message: &[u8]) -> [u8; 20] { + const BLOCK_SIZE: usize = 64; + const IPAD: u8 = 0x36; + const OPAD: u8 = 0x5C; + + // If key > block size, hash it first + let key_block = if key.len() > BLOCK_SIZE { + let hashed = sha1_digest(key); + let mut block = [0u8; BLOCK_SIZE]; + block[..20].copy_from_slice(&hashed); + block + } else { + let mut block = [0u8; BLOCK_SIZE]; + block[..key.len()].copy_from_slice(key); + block + }; + + // Inner hash: SHA1(key XOR ipad || message) + let mut inner_input = Vec::with_capacity(BLOCK_SIZE + message.len()); + for i in 0..BLOCK_SIZE { + inner_input.push(key_block[i] ^ IPAD); + } + inner_input.extend_from_slice(message); + let inner_hash = sha1_digest(&inner_input); + + // Outer hash: SHA1(key XOR opad || inner_hash) + let mut outer_input = Vec::with_capacity(BLOCK_SIZE + 20); + for i in 0..BLOCK_SIZE { + outer_input.push(key_block[i] ^ OPAD); + } + outer_input.extend_from_slice(&inner_hash); + sha1_digest(&outer_input) +} + +// --------------------------------------------------------------------------- +// SHA-1 (FIPS 180-4) +// --------------------------------------------------------------------------- + +/// Compute SHA-1 digest of the input data, returning a 20-byte hash. +/// +/// This is a self-contained implementation. In production, use the `sha1` crate. +fn sha1_digest(data: &[u8]) -> [u8; 20] { + let mut h0: u32 = 0x67452301; + let mut h1: u32 = 0xEFCDAB89; + let mut h2: u32 = 0x98BADCFE; + let mut h3: u32 = 0x10325476; + let mut h4: u32 = 0xC3D2E1F0; + + // Pre-processing: add padding + let bit_len = (data.len() as u64) * 8; + let mut padded = data.to_vec(); + padded.push(0x80); + while padded.len() % 64 != 56 { + padded.push(0x00); + } + padded.extend_from_slice(&bit_len.to_be_bytes()); + + // Process each 512-bit (64-byte) block + for block in padded.chunks_exact(64) { + let mut w = [0u32; 80]; + for i in 0..16 { + w[i] = u32::from_be_bytes([ + block[i * 4], + block[i * 4 + 1], + block[i * 4 + 2], + block[i * 4 + 3], + ]); + } + for i in 16..80 { + w[i] = (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1); + } + + let mut a = h0; + let mut b = h1; + let mut c = h2; + let mut d = h3; + let mut e = h4; + + for i in 0..80 { + let (f, k) = match i { + 0..=19 => ((b & c) | ((!b) & d), 0x5A827999_u32), + 20..=39 => (b ^ c ^ d, 0x6ED9EBA1_u32), + 40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC_u32), + 60..=79 => (b ^ c ^ d, 0xCA62C1D6_u32), + _ => unreachable!(), + }; + + let temp = a + .rotate_left(5) + .wrapping_add(f) + .wrapping_add(e) + .wrapping_add(k) + .wrapping_add(w[i]); + e = d; + d = c; + c = b.rotate_left(30); + b = a; + a = temp; + } + + h0 = h0.wrapping_add(a); + h1 = h1.wrapping_add(b); + h2 = h2.wrapping_add(c); + h3 = h3.wrapping_add(d); + h4 = h4.wrapping_add(e); + } + + let mut result = [0u8; 20]; + result[0..4].copy_from_slice(&h0.to_be_bytes()); + result[4..8].copy_from_slice(&h1.to_be_bytes()); + result[8..12].copy_from_slice(&h2.to_be_bytes()); + result[12..16].copy_from_slice(&h3.to_be_bytes()); + result[16..20].copy_from_slice(&h4.to_be_bytes()); + result +} + +// --------------------------------------------------------------------------- +// MD5 (RFC 1321) - needed for long-term credential key derivation +// --------------------------------------------------------------------------- + +/// Compute MD5 digest of the input data, returning a 16-byte hash. +/// +/// Used for the long-term credential key: `MD5(username:realm:password)`. +/// Self-contained implementation; in production use the `md-5` crate. +fn md5_digest(data: &[u8]) -> [u8; 16] { + // Initial state + let mut a0: u32 = 0x67452301; + let mut b0: u32 = 0xefcdab89; + let mut c0: u32 = 0x98badcfe; + let mut d0: u32 = 0x10325476; + + // Per-round shift amounts + const S: [u32; 64] = [ + 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, + 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, + 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, + 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, + ]; + + // Pre-computed constants: floor(2^32 * |sin(i + 1)|) + const K: [u32; 64] = [ + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, + 0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501, + 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, + 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, + 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa, + 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, + 0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a, + 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, + 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, + 0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05, + 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, + 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1, + 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, + 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, + ]; + + // Pre-processing: add padding + let bit_len = (data.len() as u64) * 8; + let mut padded = data.to_vec(); + padded.push(0x80); + while padded.len() % 64 != 56 { + padded.push(0x00); + } + padded.extend_from_slice(&bit_len.to_le_bytes()); + + // Process each 512-bit block + for block in padded.chunks_exact(64) { + let mut m = [0u32; 16]; + for i in 0..16 { + m[i] = u32::from_le_bytes([ + block[i * 4], + block[i * 4 + 1], + block[i * 4 + 2], + block[i * 4 + 3], + ]); + } + + let mut a = a0; + let mut b = b0; + let mut c = c0; + let mut d = d0; + + for i in 0..64 { + let (f, g) = match i { + 0..=15 => ((b & c) | ((!b) & d), i), + 16..=31 => ((d & b) | ((!d) & c), (5 * i + 1) % 16), + 32..=47 => (b ^ c ^ d, (3 * i + 5) % 16), + 48..=63 => (c ^ (b | (!d)), (7 * i) % 16), + _ => unreachable!(), + }; + + let temp = d; + d = c; + c = b; + b = b.wrapping_add( + (a.wrapping_add(f).wrapping_add(K[i]).wrapping_add(m[g])) + .rotate_left(S[i]), + ); + a = temp; + } + + a0 = a0.wrapping_add(a); + b0 = b0.wrapping_add(b); + c0 = c0.wrapping_add(c); + d0 = d0.wrapping_add(d); + } + + let mut result = [0u8; 16]; + result[0..4].copy_from_slice(&a0.to_le_bytes()); + result[4..8].copy_from_slice(&b0.to_le_bytes()); + result[8..12].copy_from_slice(&c0.to_le_bytes()); + result[12..16].copy_from_slice(&d0.to_le_bytes()); + result +} + +// --------------------------------------------------------------------------- +// Base64 encoding (RFC 4648) +// --------------------------------------------------------------------------- + +/// Encode bytes to base64 string using standard alphabet with padding. +fn base64_encode(data: &[u8]) -> String { + const ALPHABET: &[u8; 64] = + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + let mut result = String::with_capacity((data.len() + 2) / 3 * 4); + let chunks = data.chunks(3); + + for chunk in chunks { + let b0 = chunk[0] as u32; + let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 }; + let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 }; + + let triple = (b0 << 16) | (b1 << 8) | b2; + + result.push(ALPHABET[((triple >> 18) & 0x3F) as usize] as char); + result.push(ALPHABET[((triple >> 12) & 0x3F) as usize] as char); + + if chunk.len() > 1 { + result.push(ALPHABET[((triple >> 6) & 0x3F) as usize] as char); + } else { + result.push('='); + } + + if chunk.len() > 2 { + result.push(ALPHABET[(triple & 0x3F) as usize] as char); + } else { + result.push('='); + } + } + + result +} + +/// Decode a base64 string to bytes. Returns None on invalid input. +fn base64_decode(input: &str) -> Option> { + fn char_val(c: u8) -> Option { + const ALPHABET: &[u8; 64] = + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + ALPHABET.iter().position(|&x| x == c).map(|p| p as u8) + } + + let input = input.trim_end_matches('='); + let bytes = input.as_bytes(); + + let mut result = Vec::with_capacity(bytes.len() * 3 / 4); + let chunks = bytes.chunks(4); + + for chunk in chunks { + let vals: Vec = chunk.iter().filter_map(|&b| char_val(b)).collect(); + if vals.len() != chunk.len() { + return None; + } + + let triple = match vals.len() { + 4 => { + ((vals[0] as u32) << 18) + | ((vals[1] as u32) << 12) + | ((vals[2] as u32) << 6) + | (vals[3] as u32) + } + 3 => { + ((vals[0] as u32) << 18) | ((vals[1] as u32) << 12) | ((vals[2] as u32) << 6) + } + 2 => ((vals[0] as u32) << 18) | ((vals[1] as u32) << 12), + _ => return None, + }; + + result.push((triple >> 16) as u8); + if vals.len() > 2 { + result.push((triple >> 8 & 0xFF) as u8); + } + if vals.len() > 3 { + result.push((triple & 0xFF) as u8); + } + } + + Some(result) +} + +// --------------------------------------------------------------------------- +// Hex encoding (for nonces) +// --------------------------------------------------------------------------- + +/// Encode bytes as lowercase hexadecimal string. +fn hex_encode(data: &[u8]) -> String { + let mut s = String::with_capacity(data.len() * 2); + for &b in data { + s.push_str(&format!("{:02x}", b)); + } + s +} + +/// Decode a hexadecimal string to bytes. Returns None on invalid input. +fn hex_decode(s: &str) -> Option> { + if s.len() % 2 != 0 { + return None; + } + let mut result = Vec::with_capacity(s.len() / 2); + for i in (0..s.len()).step_by(2) { + let byte = u8::from_str_radix(&s[i..i + 2], 16).ok()?; + result.push(byte); + } + Some(result) +} + +// --------------------------------------------------------------------------- +// Credential generation / validation +// --------------------------------------------------------------------------- + +/// Generate time-limited TURN credentials per draft-uberti-behave-turn-rest. +/// +/// Returns `(username, password)` where: +/// - username = `"{expiry_timestamp}:{peer_id}"` +/// - password = `Base64(HMAC-SHA1(shared_secret, username))` +/// +/// The credentials expire `ttl_secs` seconds from now. +pub fn generate_credentials( + peer_id: &str, + shared_secret: &[u8], + ttl_secs: u64, +) -> (String, String) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let expiry = now + ttl_secs; + + let username = format!("{}:{}", expiry, peer_id); + let hmac = hmac_sha1(shared_secret, username.as_bytes()); + let password = base64_encode(&hmac); + + (username, password) +} + +/// Validate time-limited TURN credentials. +/// +/// 1. Parse the expiry timestamp from the username +/// 2. Check that the credentials have not expired +/// 3. Recompute HMAC-SHA1 and compare with the provided password +pub fn validate_credentials( + username: &str, + password: &str, + shared_secret: &[u8], +) -> Result<(), TurnError> { + // Parse expiry timestamp from username (format: "{timestamp}:{peer_id}") + let colon_pos = username + .find(':') + .ok_or(TurnError::Unauthorized)?; + + let expiry_str = &username[..colon_pos]; + let expiry: u64 = expiry_str + .parse() + .map_err(|_| TurnError::Unauthorized)?; + + // Check expiry + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if now > expiry { + return Err(TurnError::Unauthorized); + } + + // Recompute HMAC-SHA1 and compare + let expected_hmac = hmac_sha1(shared_secret, username.as_bytes()); + let expected_password = base64_encode(&expected_hmac); + + if !constant_time_eq(password.as_bytes(), expected_password.as_bytes()) { + return Err(TurnError::Unauthorized); + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// MESSAGE-INTEGRITY computation (RFC 5389 §15.4) +// --------------------------------------------------------------------------- + +/// Compute the long-term credential key per RFC 5389 §15.4: +/// `key = MD5(username ":" realm ":" password)` +pub fn compute_long_term_key(username: &str, realm: &str, password: &str) -> [u8; 16] { + let input = format!("{}:{}:{}", username, realm, password); + md5_digest(input.as_bytes()) +} + +/// Compute the MESSAGE-INTEGRITY attribute value (HMAC-SHA1). +/// +/// `key` is the long-term credential key (output of [`compute_long_term_key`]). +/// `message_bytes` should be the STUN message up to (but not including) the +/// MESSAGE-INTEGRITY attribute, with the message length field adjusted to +/// include the MESSAGE-INTEGRITY attribute. +pub fn compute_message_integrity(key: &[u8], message_bytes: &[u8]) -> [u8; 20] { + hmac_sha1(key, message_bytes) +} + +/// Validate the MESSAGE-INTEGRITY attribute of a STUN message. +/// +/// Recomputes the HMAC-SHA1 over the message bytes and compares it +/// with the provided MESSAGE-INTEGRITY value in constant time. +pub fn validate_message_integrity( + message_integrity: &[u8; 20], + key: &[u8], + message_bytes: &[u8], +) -> bool { + let expected = compute_message_integrity(key, message_bytes); + constant_time_eq(&expected, message_integrity) +} + +// --------------------------------------------------------------------------- +// Nonce generation / validation +// --------------------------------------------------------------------------- + +/// Generate a NONCE that embeds a timestamp for stateless staleness detection. +/// +/// Format: `{hex_timestamp}-{hex_hmac_of_timestamp}` +/// +/// The server can validate the nonce without storing state by recomputing +/// the HMAC. The timestamp allows detecting stale nonces. +pub fn compute_nonce(timestamp: u64, secret: &[u8]) -> String { + let ts_bytes = timestamp.to_be_bytes(); + let hmac = hmac_sha1(secret, &ts_bytes); + // Use first 8 bytes of HMAC for a shorter nonce + let hmac_short = &hmac[..8]; + format!("{}-{}", hex_encode(&ts_bytes), hex_encode(hmac_short)) +} + +/// Validate a NONCE and check that it hasn't expired. +/// +/// 1. Parse the timestamp from the nonce +/// 2. Recompute HMAC and verify it matches +/// 3. Check that the nonce age doesn't exceed `max_age_secs` +pub fn validate_nonce( + nonce: &str, + secret: &[u8], + max_age_secs: u64, +) -> Result<(), TurnError> { + let parts: Vec<&str> = nonce.splitn(2, '-').collect(); + if parts.len() != 2 { + return Err(TurnError::StaleNonce); + } + + let ts_hex = parts[0]; + let hmac_hex = parts[1]; + + // Decode timestamp + let ts_bytes = hex_decode(ts_hex).ok_or(TurnError::StaleNonce)?; + if ts_bytes.len() != 8 { + return Err(TurnError::StaleNonce); + } + + let timestamp = u64::from_be_bytes([ + ts_bytes[0], ts_bytes[1], ts_bytes[2], ts_bytes[3], + ts_bytes[4], ts_bytes[5], ts_bytes[6], ts_bytes[7], + ]); + + // Verify HMAC + let expected_hmac = hmac_sha1(secret, &ts_bytes); + let expected_hmac_hex = hex_encode(&expected_hmac[..8]); + + if !constant_time_eq(hmac_hex.as_bytes(), expected_hmac_hex.as_bytes()) { + return Err(TurnError::StaleNonce); + } + + // Check age + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + if now > timestamp + max_age_secs { + return Err(TurnError::StaleNonce); + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Constant-time comparison +// --------------------------------------------------------------------------- + +/// Compare two byte slices in constant time to prevent timing attacks. +fn constant_time_eq(a: &[u8], b: &[u8]) -> bool { + if a.len() != b.len() { + return false; + } + let mut diff = 0u8; + for (x, y) in a.iter().zip(b.iter()) { + diff |= x ^ y; + } + diff == 0 +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sha1_empty() { + // SHA-1("") = da39a3ee5e6b4b0d3255bfef95601890afd80709 + let hash = sha1_digest(b""); + let hex = hex_encode(&hash); + assert_eq!(hex, "da39a3ee5e6b4b0d3255bfef95601890afd80709"); + } + + #[test] + fn test_sha1_abc() { + // SHA-1("abc") = a9993e364706816aba3e25717850c26c9cd0d89d + let hash = sha1_digest(b"abc"); + let hex = hex_encode(&hash); + assert_eq!(hex, "a9993e364706816aba3e25717850c26c9cd0d89d"); + } + + #[test] + fn test_md5_empty() { + // MD5("") = d41d8cd98f00b204e9800998ecf8427e + let hash = md5_digest(b""); + let hex = hex_encode(&hash); + assert_eq!(hex, "d41d8cd98f00b204e9800998ecf8427e"); + } + + #[test] + fn test_md5_abc() { + // MD5("abc") = 900150983cd24fb0d6963f7d28e17f72 + let hash = md5_digest(b"abc"); + let hex = hex_encode(&hash); + assert_eq!(hex, "900150983cd24fb0d6963f7d28e17f72"); + } + + #[test] + fn test_hmac_sha1_rfc2202_test1() { + // RFC 2202 Test Case 1: + // Key = 0x0b repeated 20 times + // Data = "Hi There" + // HMAC = b617318655057264e28bc0b6fb378c8ef146be00 + let key = [0x0bu8; 20]; + let data = b"Hi There"; + let hmac = hmac_sha1(&key, data); + let hex = hex_encode(&hmac); + assert_eq!(hex, "b617318655057264e28bc0b6fb378c8ef146be00"); + } + + #[test] + fn test_hmac_sha1_rfc2202_test2() { + // RFC 2202 Test Case 2: + // Key = "Jefe" + // Data = "what do ya want for nothing?" + // HMAC = effcdf6ae5eb2fa2d27416d5f184df9c259a7c79 + let key = b"Jefe"; + let data = b"what do ya want for nothing?"; + let hmac = hmac_sha1(key, data); + let hex = hex_encode(&hmac); + assert_eq!(hex, "effcdf6ae5eb2fa2d27416d5f184df9c259a7c79"); + } + + #[test] + fn test_base64_encode() { + assert_eq!(base64_encode(b""), ""); + assert_eq!(base64_encode(b"f"), "Zg=="); + assert_eq!(base64_encode(b"fo"), "Zm8="); + assert_eq!(base64_encode(b"foo"), "Zm9v"); + assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy"); + } + + #[test] + fn test_base64_decode() { + assert_eq!(base64_decode("").unwrap(), b""); + assert_eq!(base64_decode("Zg==").unwrap(), b"f"); + assert_eq!(base64_decode("Zm8=").unwrap(), b"fo"); + assert_eq!(base64_decode("Zm9v").unwrap(), b"foo"); + assert_eq!(base64_decode("Zm9vYmFy").unwrap(), b"foobar"); + } + + #[test] + fn test_base64_roundtrip() { + let data = b"Hello, TURN server!"; + let encoded = base64_encode(data); + let decoded = base64_decode(&encoded).unwrap(); + assert_eq!(&decoded, data); + } + + #[test] + fn test_generate_validate_credentials() { + let secret = b"supersecretkey"; + let peer_id = "12D3KooWTestPeerId"; + let ttl = 3600; // 1 hour + + let (username, password) = generate_credentials(peer_id, secret, ttl); + + // Username should contain the peer_id + assert!(username.contains(peer_id)); + assert!(username.contains(':')); + + // Password should be valid base64 + assert!(base64_decode(&password).is_some()); + + // Validation should succeed + let result = validate_credentials(&username, &password, secret); + assert!(result.is_ok()); + } + + #[test] + fn test_validate_credentials_wrong_password() { + let secret = b"supersecretkey"; + let (username, _) = generate_credentials("test", secret, 3600); + + let result = validate_credentials(&username, "wrongpassword", secret); + assert!(result.is_err()); + } + + #[test] + fn test_validate_credentials_wrong_secret() { + let secret = b"supersecretkey"; + let (username, password) = generate_credentials("test", secret, 3600); + + let result = validate_credentials(&username, &password, b"wrongsecret"); + assert!(result.is_err()); + } + + #[test] + fn test_long_term_key() { + // RFC 5389 example-ish: key = MD5("user:realm:pass") + let key = compute_long_term_key("user", "realm", "pass"); + assert_eq!(key.len(), 16); + + // Should be deterministic + let key2 = compute_long_term_key("user", "realm", "pass"); + assert_eq!(key, key2); + } + + #[test] + fn test_message_integrity_roundtrip() { + let key = compute_long_term_key("alice", "duskchat.app", "password123"); + let message = b"fake stun message bytes for testing"; + + let integrity = compute_message_integrity(&key, message); + assert!(validate_message_integrity(&integrity, &key, message)); + + // Tampered message should fail + let tampered = b"tampered stun message bytes for testing"; + assert!(!validate_message_integrity(&integrity, &key, tampered)); + } + + #[test] + fn test_nonce_roundtrip() { + let secret = b"nonce_secret_key"; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let nonce = compute_nonce(now, secret); + assert!(nonce.contains('-')); + + // Should validate successfully with generous max age + let result = validate_nonce(&nonce, secret, 3600); + assert!(result.is_ok()); + } + + #[test] + fn test_nonce_stale() { + let secret = b"nonce_secret_key"; + // Use a timestamp from long ago + let old_timestamp = 1000000; + + let nonce = compute_nonce(old_timestamp, secret); + let result = validate_nonce(&nonce, secret, 3600); + assert!(result.is_err()); + } + + #[test] + fn test_nonce_tampered() { + let secret = b"nonce_secret_key"; + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let nonce = compute_nonce(now, secret); + + // Validate with wrong secret + let result = validate_nonce(&nonce, b"wrong_secret", 3600); + assert!(result.is_err()); + } + + #[test] + fn test_hex_roundtrip() { + let data = &[0xDE, 0xAD, 0xBE, 0xEF]; + let encoded = hex_encode(data); + assert_eq!(encoded, "deadbeef"); + let decoded = hex_decode(&encoded).unwrap(); + assert_eq!(&decoded, data); + } + + #[test] + fn test_constant_time_eq() { + assert!(constant_time_eq(b"hello", b"hello")); + assert!(!constant_time_eq(b"hello", b"world")); + assert!(!constant_time_eq(b"hello", b"hell")); + assert!(constant_time_eq(b"", b"")); + } +} diff --git a/src/turn/error.rs b/src/turn/error.rs new file mode 100644 index 0000000..f088784 --- /dev/null +++ b/src/turn/error.rs @@ -0,0 +1,102 @@ +// TURN server error types +// +// Covers all error conditions that can arise during STUN/TURN message +// processing, credential validation, and allocation management. + +use std::fmt; + +/// Comprehensive error type for TURN server operations. +/// +/// Each variant maps to a specific failure condition in the STUN/TURN +/// protocol stack, from low-level parse errors to high-level allocation +/// policy violations. +#[derive(Debug, Clone)] +pub enum TurnError { + /// Invalid STUN message format: bad header, truncated message, + /// missing magic cookie, or malformed TLV attributes. + StunParseError(String), + + /// MESSAGE-INTEGRITY attribute does not match the computed HMAC-SHA1. + /// This means either the password is wrong or the message was tampered with. + InvalidMessageIntegrity, + + /// The NONCE has expired. The client should retry with a fresh nonce + /// from the 438 Stale Nonce error response. + StaleNonce, + + /// The client already has an allocation on this 5-tuple, or is + /// attempting an operation that conflicts with existing allocation state + /// (RFC 5766 §6.2). + AllocationMismatch, + + /// The server has reached its per-user or global allocation quota. + /// Maps to TURN error code 486. + AllocationQuotaReached, + + /// The server cannot fulfill the request due to resource constraints + /// (e.g., no relay ports available). Maps to TURN error code 508. + InsufficientCapacity, + + /// The request lacks valid credentials, or the credentials have expired. + /// Maps to STUN error code 401. + Unauthorized, + + /// The peer address in the request is forbidden by server policy + /// (e.g., loopback or private IP filtering). Maps to TURN error code 403. + ForbiddenIp, + + /// The REQUESTED-TRANSPORT attribute specifies a transport protocol + /// that the server does not support (e.g., TCP relay when only UDP + /// is available). Maps to TURN error code 442. + UnsupportedTransport, + + /// An I/O error occurred on a socket or file operation. + IoError(String), +} + +impl fmt::Display for TurnError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TurnError::StunParseError(msg) => write!(f, "STUN parse error: {}", msg), + TurnError::InvalidMessageIntegrity => write!(f, "invalid MESSAGE-INTEGRITY"), + TurnError::StaleNonce => write!(f, "stale nonce"), + TurnError::AllocationMismatch => write!(f, "allocation mismatch"), + TurnError::AllocationQuotaReached => write!(f, "allocation quota reached"), + TurnError::InsufficientCapacity => write!(f, "insufficient capacity"), + TurnError::Unauthorized => write!(f, "unauthorized"), + TurnError::ForbiddenIp => write!(f, "forbidden IP address"), + TurnError::UnsupportedTransport => write!(f, "unsupported transport protocol"), + TurnError::IoError(msg) => write!(f, "I/O error: {}", msg), + } + } +} + +impl std::error::Error for TurnError {} + +impl From for TurnError { + fn from(err: std::io::Error) -> Self { + TurnError::IoError(err.to_string()) + } +} + +/// Maps a [`TurnError`] to the corresponding STUN/TURN error code +/// for use in error response messages. +/// +/// Error codes follow RFC 5389 §15.6 and RFC 5766 §6. +impl TurnError { + /// Returns the STUN/TURN error code and default reason phrase for this error. + pub fn to_error_code(&self) -> (u16, &'static str) { + match self { + TurnError::StunParseError(_) => (400, "Bad Request"), + TurnError::InvalidMessageIntegrity => (401, "Unauthorized"), + TurnError::StaleNonce => (438, "Stale Nonce"), + TurnError::AllocationMismatch => (437, "Allocation Mismatch"), + TurnError::AllocationQuotaReached => (486, "Allocation Quota Reached"), + TurnError::InsufficientCapacity => (508, "Insufficient Capacity"), + TurnError::Unauthorized => (401, "Unauthorized"), + TurnError::ForbiddenIp => (403, "Forbidden"), + TurnError::UnsupportedTransport => (442, "Unsupported Transport Protocol"), + TurnError::IoError(_) => (500, "Server Error"), + } + } +} diff --git a/src/turn/handler.rs b/src/turn/handler.rs new file mode 100644 index 0000000..c4d4845 --- /dev/null +++ b/src/turn/handler.rs @@ -0,0 +1,1147 @@ +// TURN message handler per RFC 5766 +// +// This is the core protocol logic that processes incoming STUN/TURN messages +// and produces response actions. It is completely I/O-free — the handler +// takes a parsed message + context and returns actions (response bytes, +// relay instructions) that the listener layer actually executes. +// +// Supported STUN/TURN methods: +// - Binding Request (RFC 5389) — NAT discovery, no auth +// - Allocate Request (RFC 5766 §6) — create relay allocation +// - Refresh Request (RFC 5766 §7) — refresh/delete allocation +// - CreatePermission Request (RFC 5766 §9) — install permissions +// - ChannelBind Request (RFC 5766 §11) — bind channel to peer +// - Send Indication (RFC 5766 §10) — relay data to peer +// - ChannelData (RFC 5766 §11.4) — compact channel relay +// +// Authentication uses the long-term credential mechanism (RFC 5389 §10.2.2) +// with time-limited credentials per draft-uberti-behave-turn-rest-00. + +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use tokio::net::UdpSocket; + +use crate::turn::allocation::{ + AllocationManager, FiveTuple, TransportProtocol, +}; +use crate::turn::attributes::StunAttribute; +use crate::turn::credentials::{ + compute_long_term_key, compute_message_integrity, compute_nonce, + hmac_sha1, validate_message_integrity, validate_nonce, +}; +use crate::turn::port_pool::PortPool; +use crate::turn::stun::{ + ChannelData, Class, MessageType, Method, StunMessage, compute_fingerprint, +}; + +// --------------------------------------------------------------------------- +// SOFTWARE attribute value +// --------------------------------------------------------------------------- + +const SOFTWARE_NAME: &str = "Dusk TURN Server 0.1"; + +/// Maximum nonce age in seconds (10 minutes). +const NONCE_MAX_AGE_SECS: u64 = 600; + +// --------------------------------------------------------------------------- +// MessageContext — transport-layer info for a received message +// --------------------------------------------------------------------------- + +/// Context information about the transport connection on which a message +/// was received. Provided by the listener layer to the handler. +pub struct MessageContext { + /// The client's source address (IP + port). + pub client_addr: SocketAddr, + /// The server address the message was received on. + pub server_addr: SocketAddr, + /// The transport protocol (UDP or TCP). + pub protocol: TransportProtocol, + /// The server's public IP (for XOR-RELAYED-ADDRESS). + pub server_public_ip: IpAddr, +} + +// --------------------------------------------------------------------------- +// HandleResult — what the listener should do after handling a message +// --------------------------------------------------------------------------- + +/// The result of processing a STUN/TURN message. +/// +/// The listener layer inspects these actions and performs the actual I/O. +pub enum HandleResult { + /// Send this response back to the client. + Response(Vec), + /// Relay data to a peer via the allocation's relay socket. + RelayToPeer { + peer_addr: SocketAddr, + data: Vec, + relay_socket: Arc, + }, + /// Send ChannelData to a peer via the relay socket. + ChannelDataToPeer { + peer_addr: SocketAddr, + data: Vec, + relay_socket: Arc, + }, + /// A new allocation was created. The listener should send the response + /// to the client AND spawn a relay receiver task for the relay socket. + AllocationCreated { + /// The encoded STUN success response to send back to the client. + response: Vec, + /// The relay socket for the new allocation (used to receive peer data). + relay_socket: Arc, + /// The relay address (public IP + allocated port). + relay_addr: SocketAddr, + /// The 5-tuple identifying this allocation's client connection. + five_tuple: FiveTuple, + }, + /// No response needed (e.g., invalid message silently dropped). + None, +} + +// --------------------------------------------------------------------------- +// TurnHandler — the core message handler +// --------------------------------------------------------------------------- + +/// Processes incoming STUN/TURN messages and returns I/O actions. +/// +/// All state is accessed through `Arc`-wrapped shared references, making +/// the handler safe to share across tasks. +pub struct TurnHandler { + /// Shared allocation state. + allocations: Arc, + /// Port pool for allocating relay ports. + port_pool: Arc>, + /// Shared secret for time-limited credentials. + shared_secret: Vec, + /// Authentication realm (e.g., "duskchat.app"). + realm: String, + /// Secret used for HMAC-based nonce generation. + nonce_secret: Vec, + /// The server's public IP address for relay addresses. + server_public_ip: IpAddr, + /// Software name for the SOFTWARE attribute. + software: String, +} + +impl TurnHandler { + /// Create a new TURN message handler. + pub fn new( + allocations: Arc, + port_pool: Arc>, + shared_secret: Vec, + realm: String, + nonce_secret: Vec, + server_public_ip: IpAddr, + ) -> Self { + Self { + allocations, + port_pool, + shared_secret, + realm, + nonce_secret, + server_public_ip, + software: SOFTWARE_NAME.to_string(), + } + } + + /// Handle an incoming STUN message. + /// + /// Dispatches to the appropriate method handler based on the message + /// type. Returns a list of actions for the listener to execute. + pub async fn handle_message( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> Vec { + let result = match (msg.msg_type.method, msg.msg_type.class) { + (Method::Binding, Class::Request) => { + self.handle_binding_request(msg, ctx).await + } + (Method::Allocate, Class::Request) => { + self.handle_allocate_request(msg, ctx).await + } + (Method::Refresh, Class::Request) => { + self.handle_refresh_request(msg, ctx).await + } + (Method::CreatePermission, Class::Request) => { + self.handle_create_permission_request(msg, ctx).await + } + (Method::ChannelBind, Class::Request) => { + self.handle_channel_bind_request(msg, ctx).await + } + (Method::Send, Class::Indication) => { + self.handle_send_indication(msg, ctx).await + } + _ => { + log::debug!( + "ignoring unsupported message: {:?} {:?}", + msg.msg_type.method, + msg.msg_type.class + ); + HandleResult::None + } + }; + + vec![result] + } + + /// Handle incoming ChannelData (compact framing for channel bindings). + pub async fn handle_channel_data( + &self, + channel_data: &ChannelData, + ctx: &MessageContext, + ) -> Option { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // Look up the channel binding + let peer_addr = self + .allocations + .get_channel_binding(&five_tuple, channel_data.channel_number) + .await?; + + // Get the allocation's relay socket + let alloc_info = self.allocations.get_allocation(&five_tuple).await?; + + Some(HandleResult::ChannelDataToPeer { + peer_addr, + data: channel_data.data.clone(), + relay_socket: alloc_info.relay_socket, + }) + } + + // ----------------------------------------------------------------------- + // Binding Request (STUN, RFC 5389 §10.1) + // ----------------------------------------------------------------------- + + /// Handle a STUN Binding Request. + /// + /// Returns XOR-MAPPED-ADDRESS containing the client's reflexive address. + /// No authentication required. + async fn handle_binding_request( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let mut response = StunMessage::new( + MessageType::new(Method::Binding, Class::SuccessResponse), + msg.transaction_id, + ); + + response.add_attribute(StunAttribute::XorMappedAddress(ctx.client_addr)); + response.add_attribute(StunAttribute::Software(self.software.clone())); + + // Binding responses don't require MESSAGE-INTEGRITY (no auth) + // Add FINGERPRINT for demultiplexing + let encoded = response.encode_for_fingerprint(); + let fingerprint = compute_fingerprint(&encoded); + response.add_attribute(StunAttribute::Fingerprint(fingerprint)); + + HandleResult::Response(response.encode()) + } + + // ----------------------------------------------------------------------- + // Allocate Request (RFC 5766 §6) + // ----------------------------------------------------------------------- + + /// Handle a TURN Allocate Request per RFC 5766 §6. + /// + /// 1. Check no existing allocation for this 5-tuple + /// 2. Authenticate (challenge if needed) + /// 3. Validate REQUESTED-TRANSPORT (must be UDP/17) + /// 4. Check quotas + /// 5. Allocate port, bind socket + /// 6. Create allocation + /// 7. Return success with relay address and lifetime + async fn handle_allocate_request( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // §6.2 step 1: Check if allocation already exists + if self.allocations.get_allocation(&five_tuple).await.is_some() { + return self.build_error_response(msg, 437, "Allocation Mismatch"); + } + + // §6.2 step 2: Authenticate + let (username, key) = match self.authenticate_request(msg) { + Ok(creds) => creds, + Err(error_response) => return HandleResult::Response(error_response), + }; + + // §6.2 step 3: Check REQUESTED-TRANSPORT + let mut has_requested_transport = false; + for attr in &msg.attributes { + if let StunAttribute::RequestedTransport(proto) = attr { + if *proto != 17 { + // Only UDP relay is supported + return self.build_error_response( + msg, + 442, + "Unsupported Transport Protocol", + ); + } + has_requested_transport = true; + break; + } + } + + if !has_requested_transport { + return self.build_error_response(msg, 400, "Missing REQUESTED-TRANSPORT"); + } + + // Extract requested lifetime (or use default) + let requested_lifetime = msg + .attributes + .iter() + .find_map(|a| { + if let StunAttribute::Lifetime(secs) = a { + Some(Duration::from_secs(*secs as u64)) + } else { + None + } + }) + .unwrap_or(self.allocations.config().default_lifetime); + + // §6.2 step 4: Allocate a port from the pool + let port = { + let mut pool = self.port_pool.lock().await; + match pool.allocate() { + Some(p) => p, + None => { + return self.build_error_response(msg, 508, "Insufficient Capacity"); + } + } + }; + + // §6.2 step 5: Bind a UDP socket on the relay port + let bind_addr = format!("0.0.0.0:{}", port); + let relay_socket = match UdpSocket::bind(&bind_addr).await { + Ok(sock) => Arc::new(sock), + Err(e) => { + log::error!("failed to bind relay socket on port {}: {}", port, e); + // Return port to pool + let mut pool = self.port_pool.lock().await; + pool.release(port); + return self.build_error_response(msg, 508, "Insufficient Capacity"); + } + }; + + // The relay address uses the server's public IP + let relay_addr = SocketAddr::new(ctx.server_public_ip, port); + + // Generate nonce for this allocation + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let nonce = compute_nonce(now_secs, &self.nonce_secret); + + // §6.2 step 6: Create the allocation + let actual_lifetime = requested_lifetime.min(self.allocations.config().max_lifetime); + let relay_socket_clone = Arc::clone(&relay_socket); + let five_tuple_clone = five_tuple.clone(); + match self + .allocations + .create_allocation( + five_tuple, + username.clone(), + self.realm.clone(), + nonce, + actual_lifetime, + relay_socket, + port, + relay_addr, + ) + .await + { + Ok(_) => { + log::info!( + "created allocation for {} → {} (port {}, lifetime {}s)", + ctx.client_addr, + relay_addr, + port, + actual_lifetime.as_secs(), + ); + } + Err(e) => { + // Return port to pool on failure + let mut pool = self.port_pool.lock().await; + pool.release(port); + let (code, reason) = e.to_error_code(); + return self.build_error_response(msg, code, reason); + } + } + + // §6.2 step 7: Build success response + let attrs = vec![ + StunAttribute::XorRelayedAddress(relay_addr), + StunAttribute::Lifetime(actual_lifetime.as_secs() as u32), + StunAttribute::XorMappedAddress(ctx.client_addr), + StunAttribute::Software(self.software.clone()), + ]; + + HandleResult::AllocationCreated { + response: self.build_success_response(msg, attrs, &key), + relay_socket: relay_socket_clone, + relay_addr, + five_tuple: five_tuple_clone, + } + } + + // ----------------------------------------------------------------------- + // Refresh Request (RFC 5766 §7) + // ----------------------------------------------------------------------- + + /// Handle a TURN Refresh Request per RFC 5766 §7. + /// + /// 1. Find existing allocation + /// 2. Authenticate + /// 3. If LIFETIME=0, delete allocation + /// 4. Otherwise refresh lifetime (clamped to max) + async fn handle_refresh_request( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // Check allocation exists + if self.allocations.get_allocation(&five_tuple).await.is_none() { + return self.build_error_response(msg, 437, "Allocation Mismatch"); + } + + // Authenticate + let (_username, key) = match self.authenticate_request(msg) { + Ok(creds) => creds, + Err(error_response) => return HandleResult::Response(error_response), + }; + + // Extract requested lifetime (default to config default) + let requested_lifetime = msg + .attributes + .iter() + .find_map(|a| { + if let StunAttribute::Lifetime(secs) = a { + Some(Duration::from_secs(*secs as u64)) + } else { + None + } + }) + .unwrap_or(self.allocations.config().default_lifetime); + + // Refresh (or delete if lifetime=0) + match self + .allocations + .refresh_allocation(&five_tuple, requested_lifetime) + .await + { + Ok(actual_lifetime) => { + if actual_lifetime.is_zero() { + // Allocation was deleted — return port to pool + if let Some(port) = self.allocations.delete_allocation(&five_tuple).await { + let mut pool = self.port_pool.lock().await; + pool.release(port); + } + log::info!("deleted allocation for {} (lifetime=0 refresh)", ctx.client_addr); + } else { + log::debug!( + "refreshed allocation for {} (lifetime={}s)", + ctx.client_addr, + actual_lifetime.as_secs() + ); + } + + let attrs = vec![ + StunAttribute::Lifetime(actual_lifetime.as_secs() as u32), + StunAttribute::Software(self.software.clone()), + ]; + + HandleResult::Response(self.build_success_response(msg, attrs, &key)) + } + Err(e) => { + let (code, reason) = e.to_error_code(); + self.build_error_response(msg, code, reason) + } + } + } + + // ----------------------------------------------------------------------- + // CreatePermission Request (RFC 5766 §9) + // ----------------------------------------------------------------------- + + /// Handle a TURN CreatePermission Request per RFC 5766 §9. + /// + /// 1. Find existing allocation + /// 2. Authenticate + /// 3. Extract XOR-PEER-ADDRESS attributes (can have multiple) + /// 4. Install/refresh permissions for each peer IP + /// 5. Return empty success + async fn handle_create_permission_request( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // Check allocation exists + if self.allocations.get_allocation(&five_tuple).await.is_none() { + return self.build_error_response(msg, 437, "Allocation Mismatch"); + } + + // Authenticate + let (_username, key) = match self.authenticate_request(msg) { + Ok(creds) => creds, + Err(error_response) => return HandleResult::Response(error_response), + }; + + // Extract all XOR-PEER-ADDRESS attributes + let peer_ips: Vec = msg + .attributes + .iter() + .filter_map(|a| { + if let StunAttribute::XorPeerAddress(addr) = a { + Some(addr.ip()) + } else { + None + } + }) + .collect(); + + if peer_ips.is_empty() { + return self.build_error_response(msg, 400, "Missing XOR-PEER-ADDRESS"); + } + + // Install permissions + match self + .allocations + .create_permission(&five_tuple, peer_ips) + .await + { + Ok(()) => { + let attrs = vec![ + StunAttribute::Software(self.software.clone()), + ]; + + HandleResult::Response(self.build_success_response(msg, attrs, &key)) + } + Err(e) => { + let (code, reason) = e.to_error_code(); + self.build_error_response(msg, code, reason) + } + } + } + + // ----------------------------------------------------------------------- + // ChannelBind Request (RFC 5766 §11) + // ----------------------------------------------------------------------- + + /// Handle a TURN ChannelBind Request per RFC 5766 §11. + /// + /// 1. Find existing allocation + /// 2. Authenticate + /// 3. Validate channel number (0x4000-0x7FFE) + /// 4. Check for conflicting bindings + /// 5. Bind channel and install permission + /// 6. Return empty success + async fn handle_channel_bind_request( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // Check allocation exists + if self.allocations.get_allocation(&five_tuple).await.is_none() { + return self.build_error_response(msg, 437, "Allocation Mismatch"); + } + + // Authenticate + let (_username, key) = match self.authenticate_request(msg) { + Ok(creds) => creds, + Err(error_response) => return HandleResult::Response(error_response), + }; + + // Extract CHANNEL-NUMBER + let channel_number = match msg.attributes.iter().find_map(|a| { + if let StunAttribute::ChannelNumber(num) = a { + Some(*num) + } else { + None + } + }) { + Some(n) => n, + None => { + return self.build_error_response(msg, 400, "Missing CHANNEL-NUMBER"); + } + }; + + // Validate range + if !(0x4000..=0x7FFE).contains(&channel_number) { + return self.build_error_response(msg, 400, "Invalid channel number"); + } + + // Extract XOR-PEER-ADDRESS + let peer_addr = match msg.attributes.iter().find_map(|a| { + if let StunAttribute::XorPeerAddress(addr) = a { + Some(*addr) + } else { + None + } + }) { + Some(addr) => addr, + None => { + return self.build_error_response(msg, 400, "Missing XOR-PEER-ADDRESS"); + } + }; + + // Bind the channel (also installs permission) + match self + .allocations + .bind_channel(&five_tuple, channel_number, peer_addr) + .await + { + Ok(()) => { + log::debug!( + "bound channel 0x{:04x} to {} for {}", + channel_number, + peer_addr, + ctx.client_addr + ); + + let attrs = vec![ + StunAttribute::Software(self.software.clone()), + ]; + + HandleResult::Response(self.build_success_response(msg, attrs, &key)) + } + Err(e) => { + let (code, reason) = e.to_error_code(); + self.build_error_response(msg, code, reason) + } + } + } + + // ----------------------------------------------------------------------- + // Send Indication (RFC 5766 §10) + // ----------------------------------------------------------------------- + + /// Handle a TURN Send Indication per RFC 5766 §10. + /// + /// 1. Find existing allocation + /// 2. Extract XOR-PEER-ADDRESS and DATA attributes + /// 3. Check permission exists for peer IP + /// 4. Return RelayToPeer action + /// + /// Indications do not generate responses (fire-and-forget). + async fn handle_send_indication( + &self, + msg: &StunMessage, + ctx: &MessageContext, + ) -> HandleResult { + let five_tuple = FiveTuple { + client_addr: ctx.client_addr, + server_addr: ctx.server_addr, + protocol: ctx.protocol, + }; + + // Get the allocation + let alloc_info = match self.allocations.get_allocation(&five_tuple).await { + Some(info) => info, + None => { + log::debug!("send indication for non-existent allocation from {}", ctx.client_addr); + return HandleResult::None; + } + }; + + // Extract XOR-PEER-ADDRESS + let peer_addr = match msg.attributes.iter().find_map(|a| { + if let StunAttribute::XorPeerAddress(addr) = a { + Some(*addr) + } else { + None + } + }) { + Some(addr) => addr, + None => { + log::debug!("send indication missing XOR-PEER-ADDRESS from {}", ctx.client_addr); + return HandleResult::None; + } + }; + + // Extract DATA + let data = match msg.attributes.iter().find_map(|a| { + if let StunAttribute::Data(d) = a { + Some(d.clone()) + } else { + None + } + }) { + Some(d) => d, + None => { + log::debug!("send indication missing DATA from {}", ctx.client_addr); + return HandleResult::None; + } + }; + + // Check permission + if !self + .allocations + .has_permission(&five_tuple, &peer_addr.ip()) + .await + { + log::debug!( + "send indication from {} to {} denied: no permission", + ctx.client_addr, + peer_addr + ); + return HandleResult::None; + } + + HandleResult::RelayToPeer { + peer_addr, + data, + relay_socket: alloc_info.relay_socket, + } + } + + // ----------------------------------------------------------------------- + // Authentication (RFC 5389 §10.2.2, long-term credentials) + // ----------------------------------------------------------------------- + + /// Authenticate a STUN request using long-term credentials. + /// + /// Returns `Ok((username, key))` on success, or `Err(encoded_error_bytes)` + /// on failure. The error bytes are a complete STUN error response ready + /// to send. + /// + /// Authentication flow: + /// 1. No MESSAGE-INTEGRITY → 401 challenge with REALM + NONCE + /// 2. Stale NONCE → 438 with fresh NONCE + /// 3. Compute key and validate HMAC + /// 4. Invalid → 401 + fn authenticate_request( + &self, + msg: &StunMessage, + ) -> Result<(String, Vec), Vec> { + // Step 1: Check for MESSAGE-INTEGRITY + let message_integrity = match msg.get_message_integrity() { + Some(mi) => mi, + None => { + // No auth at all — send challenge + return Err(self.build_challenge_response(msg)); + } + }; + + // Get USERNAME, REALM, NONCE + let username = match msg.get_username() { + Some(u) => u.to_string(), + None => { + return Err(self.encode_error_response(msg, 400, "Missing USERNAME")); + } + }; + + let _realm = match msg.get_realm() { + Some(r) => r.to_string(), + None => { + return Err(self.encode_error_response(msg, 400, "Missing REALM")); + } + }; + + let nonce = match msg.get_nonce() { + Some(n) => n.to_string(), + None => { + return Err(self.encode_error_response(msg, 400, "Missing NONCE")); + } + }; + + // Step 2: Validate nonce (check for staleness) + if validate_nonce(&nonce, &self.nonce_secret, NONCE_MAX_AGE_SECS).is_err() { + // Stale nonce — respond with 438 and a fresh nonce + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let fresh_nonce = compute_nonce(now_secs, &self.nonce_secret); + + let mut response = StunMessage::new( + MessageType::new(msg.msg_type.method, Class::ErrorResponse), + msg.transaction_id, + ); + response.add_attribute(StunAttribute::ErrorCode { + code: 438, + reason: "Stale Nonce".to_string(), + }); + response.add_attribute(StunAttribute::Realm(self.realm.clone())); + response.add_attribute(StunAttribute::Nonce(fresh_nonce)); + response.add_attribute(StunAttribute::Software(self.software.clone())); + + return Err(response.encode()); + } + + // Step 3: Compute the long-term credential key + // Password = Base64(HMAC-SHA1(shared_secret, username)) + let password_hmac = hmac_sha1(&self.shared_secret, username.as_bytes()); + let password = base64_encode_simple(&password_hmac); + let key = compute_long_term_key(&username, &self.realm, &password); + + // Step 4: Validate MESSAGE-INTEGRITY + let integrity_bytes = msg.encode_for_integrity(); + if !validate_message_integrity(message_integrity, &key, &integrity_bytes) { + return Err(self.encode_error_response(msg, 401, "Unauthorized")); + } + + Ok((username, key.to_vec())) + } + + // ----------------------------------------------------------------------- + // Response building helpers + // ----------------------------------------------------------------------- + + /// Build an error response and return it as a [`HandleResult::Response`]. + fn build_error_response( + &self, + msg: &StunMessage, + code: u16, + reason: &str, + ) -> HandleResult { + HandleResult::Response(self.encode_error_response(msg, code, reason)) + } + + /// Encode a STUN error response as wire-format bytes. + /// + /// Error responses include ERROR-CODE, SOFTWARE, and FINGERPRINT + /// but not MESSAGE-INTEGRITY (since the client may not have credentials yet). + fn encode_error_response( + &self, + msg: &StunMessage, + code: u16, + reason: &str, + ) -> Vec { + let mut response = StunMessage::new( + MessageType::new(msg.msg_type.method, Class::ErrorResponse), + msg.transaction_id, + ); + + response.add_attribute(StunAttribute::ErrorCode { + code, + reason: reason.to_string(), + }); + response.add_attribute(StunAttribute::Software(self.software.clone())); + + // Add FINGERPRINT + let fp_bytes = response.encode_for_fingerprint(); + let fingerprint = compute_fingerprint(&fp_bytes); + response.add_attribute(StunAttribute::Fingerprint(fingerprint)); + + response.encode() + } + + /// Build a STUN success response with MESSAGE-INTEGRITY and FINGERPRINT. + /// + /// The `key` is the long-term credential key used for MESSAGE-INTEGRITY. + fn build_success_response( + &self, + msg: &StunMessage, + attrs: Vec, + key: &[u8], + ) -> Vec { + let mut response = StunMessage::new( + MessageType::new(msg.msg_type.method, Class::SuccessResponse), + msg.transaction_id, + ); + + for attr in attrs { + response.add_attribute(attr); + } + + // Compute MESSAGE-INTEGRITY over the message with adjusted length + let integrity_bytes = response.encode_for_integrity(); + let hmac = compute_message_integrity(key, &integrity_bytes); + response.add_attribute(StunAttribute::MessageIntegrity(hmac)); + + // Compute FINGERPRINT over the message including MESSAGE-INTEGRITY + let fp_bytes = response.encode_for_fingerprint(); + let fingerprint = compute_fingerprint(&fp_bytes); + response.add_attribute(StunAttribute::Fingerprint(fingerprint)); + + response.encode() + } + + /// Build a 401 challenge response with REALM and NONCE. + /// + /// This is sent when a client sends a request without credentials, + /// prompting them to retry with authentication. + fn build_challenge_response(&self, msg: &StunMessage) -> Vec { + let now_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let nonce = compute_nonce(now_secs, &self.nonce_secret); + + let mut response = StunMessage::new( + MessageType::new(msg.msg_type.method, Class::ErrorResponse), + msg.transaction_id, + ); + + response.add_attribute(StunAttribute::ErrorCode { + code: 401, + reason: "Unauthorized".to_string(), + }); + response.add_attribute(StunAttribute::Realm(self.realm.clone())); + response.add_attribute(StunAttribute::Nonce(nonce)); + response.add_attribute(StunAttribute::Software(self.software.clone())); + + // Add FINGERPRINT (no MESSAGE-INTEGRITY on challenge responses) + let fp_bytes = response.encode_for_fingerprint(); + let fingerprint = compute_fingerprint(&fp_bytes); + response.add_attribute(StunAttribute::Fingerprint(fingerprint)); + + response.encode() + } +} + +// --------------------------------------------------------------------------- +// Base64 helper (simple encode only, for password generation) +// --------------------------------------------------------------------------- + +/// Simple Base64 encoding using standard alphabet with padding. +/// +/// This duplicates the private function in credentials.rs to avoid +/// exposing it as a public API. In a future refactor, the base64 +/// utilities should be extracted to a shared utility module. +fn base64_encode_simple(data: &[u8]) -> String { + const ALPHABET: &[u8; 64] = + b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + let mut result = String::with_capacity((data.len() + 2) / 3 * 4); + let chunks = data.chunks(3); + + for chunk in chunks { + let b0 = chunk[0] as u32; + let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 }; + let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 }; + + let triple = (b0 << 16) | (b1 << 8) | b2; + + result.push(ALPHABET[((triple >> 18) & 0x3F) as usize] as char); + result.push(ALPHABET[((triple >> 12) & 0x3F) as usize] as char); + + if chunk.len() > 1 { + result.push(ALPHABET[((triple >> 6) & 0x3F) as usize] as char); + } else { + result.push('='); + } + + if chunk.len() > 2 { + result.push(ALPHABET[(triple & 0x3F) as usize] as char); + } else { + result.push('='); + } + } + + result +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::turn::allocation::{AllocationConfig, AllocationManager}; + use crate::turn::port_pool::PortPool; + use crate::turn::stun::{Class, MessageType, Method}; + use std::net::{Ipv4Addr, SocketAddrV4}; + + fn test_handler() -> TurnHandler { + let config = AllocationConfig { + max_allocations: 100, + max_per_user: 10, + realm: "test.example.com".to_string(), + ..Default::default() + }; + let allocations = Arc::new(AllocationManager::new(config)); + let port_pool = Arc::new(tokio::sync::Mutex::new(PortPool::new(50000, 50100))); + + TurnHandler::new( + allocations, + port_pool, + b"test_shared_secret".to_vec(), + "test.example.com".to_string(), + b"test_nonce_secret".to_vec(), + IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + ) + } + + fn test_ctx() -> MessageContext { + MessageContext { + client_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345)), + server_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 3478)), + protocol: TransportProtocol::Udp, + server_public_ip: IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)), + } + } + + #[tokio::test] + async fn test_binding_request() { + let handler = test_handler(); + let ctx = test_ctx(); + + let msg = StunMessage::new( + MessageType::new(Method::Binding, Class::Request), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ); + + let results = handler.handle_message(&msg, &ctx).await; + assert_eq!(results.len(), 1); + + match &results[0] { + HandleResult::Response(bytes) => { + let response = StunMessage::decode(bytes).unwrap(); + assert_eq!(response.msg_type.method, Method::Binding); + assert_eq!(response.msg_type.class, Class::SuccessResponse); + assert_eq!(response.transaction_id, msg.transaction_id); + + // Should contain XOR-MAPPED-ADDRESS + let has_xor_mapped = response.attributes.iter().any(|a| { + matches!(a, StunAttribute::XorMappedAddress(_)) + }); + assert!(has_xor_mapped, "response should contain XOR-MAPPED-ADDRESS"); + } + _ => panic!("expected Response"), + } + } + + #[tokio::test] + async fn test_allocate_request_challenge() { + let handler = test_handler(); + let ctx = test_ctx(); + + // Send Allocate without credentials → should get 401 challenge + let mut msg = StunMessage::new( + MessageType::new(Method::Allocate, Class::Request), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ); + msg.add_attribute(StunAttribute::RequestedTransport(17)); + + let results = handler.handle_message(&msg, &ctx).await; + assert_eq!(results.len(), 1); + + match &results[0] { + HandleResult::Response(bytes) => { + let response = StunMessage::decode(bytes).unwrap(); + assert_eq!(response.msg_type.method, Method::Allocate); + assert_eq!(response.msg_type.class, Class::ErrorResponse); + + // Should have 401 error code + let error = response.attributes.iter().find_map(|a| { + if let StunAttribute::ErrorCode { code, .. } = a { + Some(*code) + } else { + None + } + }); + assert_eq!(error, Some(401)); + + // Should have REALM and NONCE for challenge + assert!(response.get_realm().is_some(), "challenge should include REALM"); + assert!(response.get_nonce().is_some(), "challenge should include NONCE"); + } + _ => panic!("expected Response"), + } + } + + #[tokio::test] + async fn test_allocate_missing_transport() { + let handler = test_handler(); + let ctx = test_ctx(); + + // Allocate without REQUESTED-TRANSPORT and with (fake) auth + // The handler should challenge first since there's no MESSAGE-INTEGRITY + let msg = StunMessage::new( + MessageType::new(Method::Allocate, Class::Request), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ); + + let results = handler.handle_message(&msg, &ctx).await; + assert_eq!(results.len(), 1); + + // Without auth, should get 401 challenge first + match &results[0] { + HandleResult::Response(bytes) => { + let response = StunMessage::decode(bytes).unwrap(); + assert_eq!(response.msg_type.class, Class::ErrorResponse); + } + _ => panic!("expected Response"), + } + } + + #[tokio::test] + async fn test_send_indication_no_allocation() { + let handler = test_handler(); + let ctx = test_ctx(); + + let mut msg = StunMessage::new( + MessageType::new(Method::Send, Class::Indication), + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + ); + msg.add_attribute(StunAttribute::XorPeerAddress( + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 2), 9999)), + )); + msg.add_attribute(StunAttribute::Data(vec![1, 2, 3])); + + let results = handler.handle_message(&msg, &ctx).await; + assert_eq!(results.len(), 1); + + // No allocation → should return None (indications don't get error responses) + assert!(matches!(results[0], HandleResult::None)); + } + + #[tokio::test] + async fn test_channel_data_no_binding() { + let handler = test_handler(); + let ctx = test_ctx(); + + let channel_data = ChannelData { + channel_number: 0x4000, + data: vec![1, 2, 3, 4], + }; + + let result = handler.handle_channel_data(&channel_data, &ctx).await; + assert!(result.is_none()); + } + + #[test] + fn test_base64_encode_simple_matches() { + assert_eq!(base64_encode_simple(b""), ""); + assert_eq!(base64_encode_simple(b"f"), "Zg=="); + assert_eq!(base64_encode_simple(b"fo"), "Zm8="); + assert_eq!(base64_encode_simple(b"foo"), "Zm9v"); + assert_eq!(base64_encode_simple(b"foobar"), "Zm9vYmFy"); + } +} diff --git a/src/turn/mod.rs b/src/turn/mod.rs new file mode 100644 index 0000000..210bc14 --- /dev/null +++ b/src/turn/mod.rs @@ -0,0 +1,30 @@ +// TURN server module for Dusk relay +// +// Implements RFC 5389 (STUN) and RFC 5766 (TURN) protocol types, +// message parsing/serialization, credential management, port allocation, +// and the full server with UDP/TCP listeners. +// +// This module is organized as follows: +// - stun: STUN message format, parsing, serialization +// - attributes: STUN/TURN attribute types and encoding +// - credentials: HMAC-SHA1 time-limited credential generation/validation +// - allocation: TURN allocation state machine +// - port_pool: Relay port allocation pool +// - handler: TURN message handler (request dispatch + response building) +// - udp_listener: UDP listener task (receives datagrams, spawns relay receivers) +// - tcp_listener: TCP listener task (accepts connections, frames STUN/ChannelData) +// - server: Top-level TURN server orchestration (config, startup, handle) +// - error: Error types + +pub mod stun; +pub mod attributes; +pub mod credentials; +pub mod error; +pub mod port_pool; +pub mod allocation; +pub mod handler; +pub mod udp_listener; +pub mod tcp_listener; +pub mod server; + +pub use server::{TurnServer, TurnServerConfig, TurnServerHandle}; diff --git a/src/turn/port_pool.rs b/src/turn/port_pool.rs new file mode 100644 index 0000000..074e3cd --- /dev/null +++ b/src/turn/port_pool.rs @@ -0,0 +1,254 @@ +// Relay port allocation pool for TURN allocations +// +// Manages a pool of UDP port numbers available for TURN relay transport +// addresses. Each TURN allocation requires a dedicated relay port. +// +// The pool pre-shuffles available ports on creation to avoid predictable +// allocation patterns. Ports are allocated from the front of the queue +// and returned to the back on release. +// +// Default range: 49152-65535 (IANA dynamic/private port range) +// This gives approximately 16,383 relay ports. + +use std::collections::HashSet; + +/// Manages a pool of available relay port numbers. +/// +/// Ports are pre-shuffled on construction to avoid predictable patterns. +/// Allocation is O(1) from a Vec (pop), release is O(1) (push). +/// +/// # Example +/// +/// ``` +/// use relay::turn::port_pool::PortPool; +/// +/// let mut pool = PortPool::new(49152, 49160); +/// assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive +/// +/// let port = pool.allocate().unwrap(); +/// assert!(port >= 49152 && port <= 49160); +/// assert_eq!(pool.available_count(), 8); +/// +/// pool.release(port); +/// assert_eq!(pool.available_count(), 9); +/// ``` +#[derive(Debug, Clone)] +pub struct PortPool { + /// Ports available for allocation (shuffled on construction). + available: Vec, + /// Currently allocated ports (for tracking and preventing double-release). + allocated: HashSet, +} + +impl PortPool { + /// Create a new port pool with ports in the range `[range_start, range_end]` inclusive. + /// + /// The available ports are shuffled using a simple Fisher-Yates-like shuffle + /// seeded from the system clock. For cryptographic randomness, use the + /// `rand` crate version below. + /// + /// # Panics + /// + /// Panics if `range_start > range_end`. + pub fn new(range_start: u16, range_end: u16) -> Self { + assert!( + range_start <= range_end, + "range_start ({}) must be <= range_end ({})", + range_start, + range_end + ); + + let mut available: Vec = (range_start..=range_end).collect(); + + // Shuffle using a simple PRNG seeded from system time. + // This is sufficient for port randomization (not security-critical). + let seed = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_nanos(); + + let mut state = seed as u64; + let len = available.len(); + if len > 1 { + for i in (1..len).rev() { + // Simple xorshift64 PRNG + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + let j = (state as usize) % (i + 1); + available.swap(i, j); + } + } + + PortPool { + available, + allocated: HashSet::new(), + } + } + + /// Allocate a port from the pool. + /// + /// Returns `None` if no ports are available. + /// Allocated ports are tracked to prevent double-allocation. + pub fn allocate(&mut self) -> Option { + let port = self.available.pop()?; + self.allocated.insert(port); + Some(port) + } + + /// Release a previously allocated port back to the pool. + /// + /// If the port was not currently allocated (double-release or unknown port), + /// this is a no-op to prevent pool corruption. + pub fn release(&mut self, port: u16) { + if self.allocated.remove(&port) { + self.available.push(port); + } + } + + /// Returns the number of ports currently available for allocation. + pub fn available_count(&self) -> usize { + self.available.len() + } + + /// Returns the number of ports currently allocated. + pub fn allocated_count(&self) -> usize { + self.allocated.len() + } + + /// Returns the total capacity of the pool (available + allocated). + pub fn total_capacity(&self) -> usize { + self.available.len() + self.allocated.len() + } + + /// Returns `true` if the given port is currently allocated. + pub fn is_allocated(&self, port: u16) -> bool { + self.allocated.contains(&port) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_pool_correct_count() { + let pool = PortPool::new(49152, 49160); + assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive + assert_eq!(pool.allocated_count(), 0); + assert_eq!(pool.total_capacity(), 9); + } + + #[test] + fn test_single_port_pool() { + let mut pool = PortPool::new(5000, 5000); + assert_eq!(pool.available_count(), 1); + + let port = pool.allocate().unwrap(); + assert_eq!(port, 5000); + assert_eq!(pool.available_count(), 0); + assert!(pool.is_allocated(5000)); + + assert!(pool.allocate().is_none()); + + pool.release(5000); + assert_eq!(pool.available_count(), 1); + assert!(!pool.is_allocated(5000)); + } + + #[test] + fn test_allocate_all_ports() { + let mut pool = PortPool::new(10000, 10004); + let mut allocated = Vec::new(); + + for _ in 0..5 { + let port = pool.allocate().unwrap(); + assert!((10000..=10004).contains(&port)); + allocated.push(port); + } + + assert_eq!(pool.available_count(), 0); + assert_eq!(pool.allocated_count(), 5); + assert!(pool.allocate().is_none()); + + // All ports should be unique + let unique: HashSet = allocated.iter().copied().collect(); + assert_eq!(unique.len(), 5); + } + + #[test] + fn test_release_and_reallocate() { + let mut pool = PortPool::new(8000, 8002); + + let p1 = pool.allocate().unwrap(); + let p2 = pool.allocate().unwrap(); + let p3 = pool.allocate().unwrap(); + assert!(pool.allocate().is_none()); + + pool.release(p2); + assert_eq!(pool.available_count(), 1); + + let p4 = pool.allocate().unwrap(); + assert_eq!(p4, p2); // released port gets reused + assert!(pool.allocate().is_none()); + + pool.release(p1); + pool.release(p3); + pool.release(p4); + assert_eq!(pool.available_count(), 3); + } + + #[test] + fn test_double_release_is_noop() { + let mut pool = PortPool::new(7000, 7002); + + let port = pool.allocate().unwrap(); + pool.release(port); + assert_eq!(pool.available_count(), 3); + + // Double release should not add a duplicate + pool.release(port); + assert_eq!(pool.available_count(), 3); + } + + #[test] + fn test_release_unknown_port_is_noop() { + let mut pool = PortPool::new(7000, 7002); + + // Release a port that was never allocated + pool.release(9999); + assert_eq!(pool.available_count(), 3); + assert_eq!(pool.allocated_count(), 0); + } + + #[test] + #[should_panic(expected = "range_start")] + fn test_invalid_range_panics() { + let _ = PortPool::new(100, 50); + } + + #[test] + fn test_large_pool() { + let pool = PortPool::new(49152, 65535); + assert_eq!(pool.available_count(), 16384); + assert_eq!(pool.total_capacity(), 16384); + } + + #[test] + fn test_is_allocated() { + let mut pool = PortPool::new(6000, 6005); + + let port = pool.allocate().unwrap(); + assert!(pool.is_allocated(port)); + + pool.release(port); + assert!(!pool.is_allocated(port)); + + // Never-allocated port + assert!(!pool.is_allocated(9999)); + } +} diff --git a/src/turn/server.rs b/src/turn/server.rs new file mode 100644 index 0000000..1ad7a74 --- /dev/null +++ b/src/turn/server.rs @@ -0,0 +1,377 @@ +// TURN server entry point +// +// Ties together the TurnHandler, AllocationManager, PortPool, and +// UDP/TCP listeners into a clean, configurable server. +// +// Configuration can be loaded from environment variables via +// `TurnServerConfig::from_env()` or constructed programmatically. + +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::sync::Arc; + +use super::allocation::{AllocationConfig, AllocationManager}; +use super::handler::TurnHandler; +use super::port_pool::PortPool; +use super::tcp_listener::TcpTurnListener; +use super::udp_listener::UdpTurnListener; + +// --------------------------------------------------------------------------- +// TurnServerConfig +// --------------------------------------------------------------------------- + +/// Configuration for the TURN server. +/// +/// All fields have sensible defaults. The only required field for production +/// use is `public_ip` — without it, relay addresses will use 0.0.0.0 which +/// won't work for external peers. +#[derive(Debug, Clone)] +pub struct TurnServerConfig { + /// UDP listen address (default: 0.0.0.0:3478). + pub udp_addr: SocketAddr, + /// TCP listen address (default: 0.0.0.0:3478). + pub tcp_addr: SocketAddr, + /// Public IP address for XOR-RELAYED-ADDRESS (required for production). + pub public_ip: IpAddr, + /// Shared secret for HMAC credential generation. + pub shared_secret: Vec, + /// Authentication realm (default: "duskchat.app"). + pub realm: String, + /// Relay port range start (default: 49152). + pub relay_port_start: u16, + /// Relay port range end (default: 65535). + pub relay_port_end: u16, + /// Allocation configuration (lifetimes, quotas, etc.). + pub allocation_config: AllocationConfig, +} + +impl Default for TurnServerConfig { + fn default() -> Self { + Self { + udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478), + tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478), + public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + shared_secret: Vec::new(), + realm: "duskchat.app".to_string(), + relay_port_start: 49152, + relay_port_end: 65535, + allocation_config: AllocationConfig::default(), + } + } +} + +impl TurnServerConfig { + /// Create a configuration from environment variables with sensible defaults. + /// + /// Supported environment variables: + /// - `DUSK_TURN_UDP_PORT` — UDP listen port (default: 3478) + /// - `DUSK_TURN_TCP_PORT` — TCP listen port (default: 3478) + /// - `DUSK_TURN_PUBLIC_IP` — Public IP for relay addresses (required in prod) + /// - `DUSK_TURN_SECRET` — Shared secret for credentials (auto-generated if unset) + /// - `DUSK_TURN_REALM` — Authentication realm (default: "duskchat.app") + /// - `DUSK_TURN_PORT_RANGE_START` — Relay port range start (default: 49152) + /// - `DUSK_TURN_PORT_RANGE_END` — Relay port range end (default: 65535) + /// - `DUSK_TURN_MAX_ALLOCATIONS` — Global allocation limit (default: 1000) + /// - `DUSK_TURN_MAX_PER_USER` — Per-user allocation limit (default: 10) + pub fn from_env() -> Self { + let udp_port: u16 = std::env::var("DUSK_TURN_UDP_PORT") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(3478); + + let tcp_port: u16 = std::env::var("DUSK_TURN_TCP_PORT") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(3478); + + let public_ip: IpAddr = std::env::var("DUSK_TURN_PUBLIC_IP") + .ok() + .and_then(|ip| ip.parse().ok()) + .unwrap_or_else(|| { + log::warn!( + "[TURN] DUSK_TURN_PUBLIC_IP not set — relay addresses will use 0.0.0.0. \ + Set this to the server's public IP for production use." + ); + IpAddr::V4(Ipv4Addr::UNSPECIFIED) + }); + + let shared_secret = std::env::var("DUSK_TURN_SECRET") + .ok() + .filter(|s| !s.is_empty()) + .map(|s| s.into_bytes()) + .unwrap_or_else(|| { + let secret = generate_random_bytes(32); + log::warn!( + "[TURN] DUSK_TURN_SECRET not set — using random secret. \ + This won't work across multiple relay instances." + ); + secret + }); + + let realm = std::env::var("DUSK_TURN_REALM") + .ok() + .filter(|r| !r.is_empty()) + .unwrap_or_else(|| "duskchat.app".to_string()); + + let port_range_start: u16 = std::env::var("DUSK_TURN_PORT_RANGE_START") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(49152); + + let port_range_end: u16 = std::env::var("DUSK_TURN_PORT_RANGE_END") + .ok() + .and_then(|p| p.parse().ok()) + .unwrap_or(65535); + + let max_allocations: usize = std::env::var("DUSK_TURN_MAX_ALLOCATIONS") + .ok() + .and_then(|n| n.parse().ok()) + .unwrap_or(1000); + + let max_per_user: usize = std::env::var("DUSK_TURN_MAX_PER_USER") + .ok() + .and_then(|n| n.parse().ok()) + .unwrap_or(10); + + let allocation_config = AllocationConfig { + max_allocations, + max_per_user, + realm: realm.clone(), + ..Default::default() + }; + + Self { + udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), udp_port), + tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), tcp_port), + public_ip, + shared_secret, + realm, + relay_port_start: port_range_start, + relay_port_end: port_range_end, + allocation_config, + } + } + + /// Check if the TURN server is enabled via environment variable. + /// + /// Returns `false` only if `DUSK_TURN_ENABLED=false`. Defaults to `true`. + pub fn is_enabled() -> bool { + std::env::var("DUSK_TURN_ENABLED") + .map(|v| v != "false" && v != "0") + .unwrap_or(true) + } +} + +// --------------------------------------------------------------------------- +// TurnServer +// --------------------------------------------------------------------------- + +/// The TURN server. Owns the configuration and provides a [`run`](TurnServer::run) +/// method that starts all listener tasks and returns a [`TurnServerHandle`]. +pub struct TurnServer { + config: TurnServerConfig, +} + +impl TurnServer { + /// Create a new TURN server with the given configuration. + pub fn new(config: TurnServerConfig) -> Self { + Self { config } + } + + /// Start the TURN server. + /// + /// Binds UDP and TCP sockets, creates the handler and allocation manager, + /// spawns listener and cleanup tasks, and returns a handle for monitoring. + pub async fn run(self) -> Result> { + let port_pool = Arc::new(tokio::sync::Mutex::new(PortPool::new( + self.config.relay_port_start, + self.config.relay_port_end, + ))); + + let alloc_mgr = Arc::new(AllocationManager::new( + self.config.allocation_config.clone(), + )); + + // Generate nonce secret (separate from shared secret) + let nonce_secret = generate_random_bytes(32); + + let handler = Arc::new(TurnHandler::new( + Arc::clone(&alloc_mgr), + Arc::clone(&port_pool), + self.config.shared_secret.clone(), + self.config.realm.clone(), + nonce_secret, + self.config.public_ip, + )); + + // Bind UDP socket + let udp_socket = Arc::new( + tokio::net::UdpSocket::bind(self.config.udp_addr).await?, + ); + log::info!( + "[TURN] UDP listening on {}", + udp_socket.local_addr().unwrap_or(self.config.udp_addr) + ); + + let udp_listener = Arc::new(UdpTurnListener::new( + Arc::clone(&udp_socket), + Arc::clone(&handler), + Arc::clone(&alloc_mgr), + self.config.udp_addr, + self.config.public_ip, + )); + + // Spawn UDP listener + let udp_handle = tokio::spawn({ + let listener = Arc::clone(&udp_listener); + async move { + listener.run().await; + } + }); + + // Bind TCP listener + let mut tcp_listener = TcpTurnListener::bind( + self.config.tcp_addr, + Arc::clone(&handler), + Arc::clone(&alloc_mgr), + self.config.public_ip, + ) + .await?; + + // Give the TCP listener access to the UDP socket for relay receiver tasks + tcp_listener.set_udp_socket(Arc::clone(&udp_socket)); + + log::info!( + "[TURN] TCP listening on {}", + self.config.tcp_addr + ); + + // Spawn TCP listener + let tcp_handle = tokio::spawn(async move { + tcp_listener.run().await; + }); + + // Spawn cleanup task (runs every 30 seconds) + let cleanup_alloc_mgr = Arc::clone(&alloc_mgr); + let cleanup_port_pool = Arc::clone(&port_pool); + let cleanup_handle = tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + let freed_ports = cleanup_alloc_mgr.cleanup_expired().await; + if !freed_ports.is_empty() { + let mut pool = cleanup_port_pool.lock().await; + for port in &freed_ports { + pool.release(*port); + } + log::info!( + "[TURN] cleaned up {} expired allocations", + freed_ports.len() + ); + } + } + }); + + log::info!( + "[TURN] server started (public_ip={}, realm={}, relay_ports={}-{})", + self.config.public_ip, + self.config.realm, + self.config.relay_port_start, + self.config.relay_port_end, + ); + + Ok(TurnServerHandle { + udp_handle, + tcp_handle, + cleanup_handle, + handler, + alloc_mgr, + port_pool, + shared_secret: self.config.shared_secret, + }) + } + + /// Get the shared secret (for credential generation by the libp2p protocol). + pub fn shared_secret(&self) -> &[u8] { + &self.config.shared_secret + } +} + +// --------------------------------------------------------------------------- +// TurnServerHandle +// --------------------------------------------------------------------------- + +/// Handle to a running TURN server. +/// +/// Provides access to the server's shared state and credential generation. +/// The server runs in the background via spawned tasks; dropping the handle +/// does NOT stop the server (the tasks keep running). +pub struct TurnServerHandle { + /// UDP listener task handle. + pub udp_handle: tokio::task::JoinHandle<()>, + /// TCP listener task handle. + pub tcp_handle: tokio::task::JoinHandle<()>, + /// Cleanup task handle. + pub cleanup_handle: tokio::task::JoinHandle<()>, + /// The shared TURN message handler. + pub handler: Arc, + /// The shared allocation manager. + pub alloc_mgr: Arc, + /// The shared port pool. + pub port_pool: Arc>, + /// The shared secret (for credential generation). + shared_secret: Vec, +} + +impl TurnServerHandle { + /// Generate time-limited TURN credentials for a peer. + /// + /// Returns `(username, password)` suitable for use with WebRTC ICE servers. + /// The credentials expire after `ttl_secs` seconds. + pub fn generate_credentials(&self, peer_id: &str, ttl_secs: u64) -> (String, String) { + super::credentials::generate_credentials(peer_id, &self.shared_secret, ttl_secs) + } + + /// Get the current number of active allocations. + pub async fn allocation_count(&self) -> usize { + self.alloc_mgr.allocation_count().await + } + + /// Get the shared secret bytes (for external credential generation). + pub fn shared_secret(&self) -> &[u8] { + &self.shared_secret + } + + /// Abort all server tasks. The server will stop after current in-flight + /// operations complete. + pub fn shutdown(&self) { + self.udp_handle.abort(); + self.tcp_handle.abort(); + self.cleanup_handle.abort(); + } +} + +// --------------------------------------------------------------------------- +// Utility: random byte generation +// --------------------------------------------------------------------------- + +/// Generate `len` pseudo-random bytes using a simple xorshift64 PRNG +/// seeded from system time. +/// +/// This is NOT cryptographically secure. It's used for nonce secrets +/// and auto-generated shared secrets when none is configured. In production, +/// configure `DUSK_TURN_SECRET` explicitly. +fn generate_random_bytes(len: usize) -> Vec { + use std::time::{SystemTime, UNIX_EPOCH}; + let seed = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_nanos() as u64; + let mut state = seed; + let mut bytes = Vec::with_capacity(len); + for _ in 0..len { + state ^= state << 13; + state ^= state >> 7; + state ^= state << 17; + bytes.push(state as u8); + } + bytes +} diff --git a/src/turn/stun.rs b/src/turn/stun.rs new file mode 100644 index 0000000..8dc1bf0 --- /dev/null +++ b/src/turn/stun.rs @@ -0,0 +1,799 @@ +// STUN message parser and serializer per RFC 5389 +// +// Implements the STUN message format with support for TURN methods (RFC 5766). +// Also handles ChannelData framing for TURN channel bindings. +// +// STUN message header layout (20 bytes): +// Bytes 0-1: Message Type (method + class encoding) +// Bytes 2-3: Message Length (excludes 20-byte header) +// Bytes 4-7: Magic Cookie (0x2112A442) +// Bytes 8-19: Transaction ID (96 bits) +// +// Message Type encoding (RFC 5389 §6): +// Bits: 13 12 11 10 9 8 7 6 5 4 3 2 1 0 +// M11 M10 M9 M8 M7 C1 M6 M5 M4 C0 M3 M2 M1 M0 +// Where M0-M11 = method bits, C0-C1 = class bits. + +use crate::turn::attributes::StunAttribute; +use crate::turn::error::TurnError; + +/// STUN magic cookie value (RFC 5389 §6). +pub const MAGIC_COOKIE: u32 = 0x2112A442; + +/// STUN header size in bytes. +pub const STUN_HEADER_SIZE: usize = 20; + +/// XOR value for FINGERPRINT CRC32 (RFC 5389 §15.5). +pub const FINGERPRINT_XOR: u32 = 0x5354554e; + +// --------------------------------------------------------------------------- +// Method +// --------------------------------------------------------------------------- + +/// STUN/TURN method identifiers. +/// +/// Methods 0x0001 (Binding) are defined in RFC 5389; the rest are +/// TURN extensions from RFC 5766. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Method { + /// STUN Binding (0x0001) — NAT discovery + Binding, + /// TURN Allocate (0x0003) — create relay allocation + Allocate, + /// TURN Refresh (0x0004) — refresh allocation lifetime + Refresh, + /// TURN Send (0x0006) — send data indication to peer + Send, + /// TURN Data (0x0007) — data indication from peer + Data, + /// TURN CreatePermission (0x0008) — install relay permission + CreatePermission, + /// TURN ChannelBind (0x0009) — bind channel number to peer + ChannelBind, +} + +impl Method { + /// Returns the 12-bit method number (M0-M11). + pub fn as_u16(self) -> u16 { + match self { + Method::Binding => 0x0001, + Method::Allocate => 0x0003, + Method::Refresh => 0x0004, + Method::Send => 0x0006, + Method::Data => 0x0007, + Method::CreatePermission => 0x0008, + Method::ChannelBind => 0x0009, + } + } + + /// Parses a 12-bit method number into a [`Method`]. + pub fn from_u16(val: u16) -> Result { + match val { + 0x0001 => Ok(Method::Binding), + 0x0003 => Ok(Method::Allocate), + 0x0004 => Ok(Method::Refresh), + 0x0006 => Ok(Method::Send), + 0x0007 => Ok(Method::Data), + 0x0008 => Ok(Method::CreatePermission), + 0x0009 => Ok(Method::ChannelBind), + _ => Err(TurnError::StunParseError(format!( + "unknown STUN method: 0x{:04x}", + val + ))), + } + } +} + +// --------------------------------------------------------------------------- +// Class +// --------------------------------------------------------------------------- + +/// STUN message class (RFC 5389 §6). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Class { + /// Request (C0=0, C1=0) + Request, + /// Indication (C0=1, C1=0) + Indication, + /// Success Response (C0=0, C1=1) + SuccessResponse, + /// Error Response (C0=1, C1=1) + ErrorResponse, +} + +impl Class { + /// Returns the 2-bit class value (C0 in bit 0, C1 in bit 1). + pub fn as_u8(self) -> u8 { + match self { + Class::Request => 0b00, + Class::Indication => 0b01, + Class::SuccessResponse => 0b10, + Class::ErrorResponse => 0b11, + } + } + + /// Parses a 2-bit class value. + pub fn from_u8(val: u8) -> Result { + match val & 0x03 { + 0b00 => Ok(Class::Request), + 0b01 => Ok(Class::Indication), + 0b10 => Ok(Class::SuccessResponse), + 0b11 => Ok(Class::ErrorResponse), + _ => unreachable!(), + } + } +} + +// --------------------------------------------------------------------------- +// MessageType +// --------------------------------------------------------------------------- + +/// Combined STUN message type (method + class). +/// +/// The 14-bit message type field encodes both the method and class with +/// an interleaved bit layout per RFC 5389 §6. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct MessageType { + pub method: Method, + pub class: Class, +} + +impl MessageType { + pub fn new(method: Method, class: Class) -> Self { + Self { method, class } + } + + /// Encode method + class into the 14-bit STUN message type field. + /// + /// Bit layout of the 16-bit type field (top 2 bits always 0 for STUN): + /// ```text + /// 15 14 | 13 12 11 10 9 | 8 | 7 6 5 | 4 | 3 2 1 0 + /// 0 0 | M11 M10 M9 M8 M7 | C1 | M6 M5 M4 | C0 | M3 M2 M1 M0 + /// ``` + pub fn to_u16(self) -> u16 { + let m = self.method.as_u16(); + let c = self.class.as_u8() as u16; + + // Extract method bit groups + let m0_3 = m & 0x000F; // bits 0-3 + let m4_6 = (m >> 4) & 0x0007; // bits 4-6 + let m7_11 = (m >> 7) & 0x001F; // bits 7-11 + + // Extract class bits + let c0 = c & 0x01; + let c1 = (c >> 1) & 0x01; + + // Assemble: M0-M3 in bits 0-3, C0 in bit 4, M4-M6 in bits 5-7, + // C1 in bit 8, M7-M11 in bits 9-13 + m0_3 | (c0 << 4) | (m4_6 << 5) | (c1 << 8) | (m7_11 << 9) + } + + /// Decode the 14-bit STUN message type field into method + class. + pub fn from_u16(val: u16) -> Result { + // Top 2 bits must be 00 for STUN messages + if val & 0xC000 != 0 { + return Err(TurnError::StunParseError( + "top 2 bits of message type must be 00 for STUN".into(), + )); + } + + // Extract class bits + let c0 = (val >> 4) & 0x01; + let c1 = (val >> 8) & 0x01; + let class_bits = (c0 | (c1 << 1)) as u8; + + // Extract method bits + let m0_3 = val & 0x000F; + let m4_6 = (val >> 5) & 0x0007; + let m7_11 = (val >> 9) & 0x001F; + let method_bits = m0_3 | (m4_6 << 4) | (m7_11 << 7); + + Ok(MessageType { + method: Method::from_u16(method_bits)?, + class: Class::from_u8(class_bits)?, + }) + } +} + +// --------------------------------------------------------------------------- +// StunMessage +// --------------------------------------------------------------------------- + +/// A parsed STUN message with header fields and attributes. +/// +/// The wire format is a 20-byte header followed by zero or more TLV-encoded +/// attributes. The message length field in the header covers only the +/// attribute portion (not the 20-byte header itself). +#[derive(Debug, Clone)] +pub struct StunMessage { + pub msg_type: MessageType, + pub transaction_id: [u8; 12], + pub attributes: Vec, +} + +impl StunMessage { + /// Create a new STUN message with the given type and transaction ID. + pub fn new(msg_type: MessageType, transaction_id: [u8; 12]) -> Self { + Self { + msg_type, + transaction_id, + attributes: Vec::new(), + } + } + + /// Create a new STUN message with a random transaction ID. + pub fn new_random(msg_type: MessageType) -> Self { + let mut transaction_id = [0u8; 12]; + // Use simple random fill; callers with `rand` can use proper RNG + #[cfg(feature = "rand")] + { + use rand::RngCore; + rand::thread_rng().fill_bytes(&mut transaction_id); + } + #[cfg(not(feature = "rand"))] + { + // Fallback: use std time-based entropy (not cryptographically secure, + // but sufficient for transaction IDs in development/testing) + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default(); + let seed = now.as_nanos(); + for (i, byte) in transaction_id.iter_mut().enumerate() { + *byte = ((seed >> (i * 5)) & 0xFF) as u8; + } + } + Self { + msg_type, + transaction_id, + attributes: Vec::new(), + } + } + + /// Add an attribute to the message. + pub fn add_attribute(&mut self, attr: StunAttribute) { + self.attributes.push(attr); + } + + /// Decode a STUN message from raw bytes. + /// + /// Validates the magic cookie and parses the header and all TLV attributes. + /// Returns an error if the message is truncated, has an invalid cookie, + /// or contains malformed attributes. + pub fn decode(bytes: &[u8]) -> Result { + if bytes.len() < STUN_HEADER_SIZE { + return Err(TurnError::StunParseError(format!( + "message too short: {} bytes, need at least {}", + bytes.len(), + STUN_HEADER_SIZE + ))); + } + + // Parse message type (first 2 bytes) + let type_val = u16::from_be_bytes([bytes[0], bytes[1]]); + let msg_type = MessageType::from_u16(type_val)?; + + // Parse message length (bytes 2-3) — length of attributes only + let msg_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize; + + // Validate magic cookie (bytes 4-7) + let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + if cookie != MAGIC_COOKIE { + return Err(TurnError::StunParseError(format!( + "invalid magic cookie: 0x{:08x}, expected 0x{:08x}", + cookie, MAGIC_COOKIE + ))); + } + + // Extract transaction ID (bytes 8-19) + let mut transaction_id = [0u8; 12]; + transaction_id.copy_from_slice(&bytes[8..20]); + + // Validate total length + let total_expected = STUN_HEADER_SIZE + msg_length; + if bytes.len() < total_expected { + return Err(TurnError::StunParseError(format!( + "message truncated: have {} bytes, header says {}", + bytes.len(), + total_expected + ))); + } + + // Parse attributes from the body + let attr_bytes = &bytes[STUN_HEADER_SIZE..total_expected]; + let attributes = Self::decode_attributes(attr_bytes, &transaction_id)?; + + Ok(StunMessage { + msg_type, + transaction_id, + attributes, + }) + } + + /// Parse the TLV attribute list from the message body. + fn decode_attributes( + mut data: &[u8], + transaction_id: &[u8; 12], + ) -> Result, TurnError> { + let mut attributes = Vec::new(); + + while data.len() >= 4 { + // Each attribute: type (2 bytes) + length (2 bytes) + value + padding + let attr_type = u16::from_be_bytes([data[0], data[1]]); + let attr_len = u16::from_be_bytes([data[2], data[3]]) as usize; + + if data.len() < 4 + attr_len { + return Err(TurnError::StunParseError(format!( + "attribute 0x{:04x} truncated: need {} bytes, have {}", + attr_type, + attr_len, + data.len() - 4 + ))); + } + + let attr_value = &data[4..4 + attr_len]; + let attr = crate::turn::attributes::decode_attribute( + attr_type, + attr_value, + transaction_id, + )?; + attributes.push(attr); + + // Advance past value + padding to 4-byte boundary + let padded_len = (attr_len + 3) & !3; + let total_attr_size = 4 + padded_len; + if total_attr_size > data.len() { + break; + } + data = &data[total_attr_size..]; + } + + Ok(attributes) + } + + /// Encode this STUN message into wire format bytes. + /// + /// The message length field is computed from the encoded attributes. + /// Attributes are TLV-encoded with 4-byte padding. + pub fn encode(&self) -> Vec { + // Encode all attributes first to determine total length + let mut attr_bytes = Vec::new(); + for attr in &self.attributes { + let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id); + attr_bytes.extend_from_slice(&encoded); + } + + let msg_length = attr_bytes.len() as u16; + + // Build the 20-byte header + let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len()); + + // Message type (2 bytes) + buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes()); + // Message length (2 bytes) + buf.extend_from_slice(&msg_length.to_be_bytes()); + // Magic cookie (4 bytes) + buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes()); + // Transaction ID (12 bytes) + buf.extend_from_slice(&self.transaction_id); + // Attributes + buf.extend_from_slice(&attr_bytes); + + buf + } + + /// Encode the message, computing the correct length field as if a + /// MESSAGE-INTEGRITY attribute of 24 bytes (4-byte TLV header + 20-byte HMAC) + /// will be appended. This is used to generate the bytes over which + /// MESSAGE-INTEGRITY is computed per RFC 5389 §15.4. + /// + /// The returned bytes do NOT include the MESSAGE-INTEGRITY attribute itself. + pub fn encode_for_integrity(&self) -> Vec { + // Encode attributes up to (but not including) MESSAGE-INTEGRITY + let mut attr_bytes = Vec::new(); + for attr in &self.attributes { + if matches!(attr, StunAttribute::MessageIntegrity(_)) { + break; + } + if matches!(attr, StunAttribute::Fingerprint(_)) { + break; + } + let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id); + attr_bytes.extend_from_slice(&encoded); + } + + // The length field must include the MESSAGE-INTEGRITY attribute that will follow + // (4 bytes TLV header + 20 bytes HMAC = 24 bytes) + let msg_length = (attr_bytes.len() + 24) as u16; + + let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len()); + buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes()); + buf.extend_from_slice(&msg_length.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes()); + buf.extend_from_slice(&self.transaction_id); + buf.extend_from_slice(&attr_bytes); + + buf + } + + /// Encode the message for FINGERPRINT computation per RFC 5389 §15.5. + /// + /// Returns all bytes up to (but not including) the FINGERPRINT attribute, + /// with the length field adjusted to include the FINGERPRINT (8 bytes). + pub fn encode_for_fingerprint(&self) -> Vec { + // Encode all attributes except FINGERPRINT + let mut attr_bytes = Vec::new(); + for attr in &self.attributes { + if matches!(attr, StunAttribute::Fingerprint(_)) { + break; + } + let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id); + attr_bytes.extend_from_slice(&encoded); + } + + // Length includes the FINGERPRINT attribute (4 header + 4 value = 8 bytes) + let msg_length = (attr_bytes.len() + 8) as u16; + + let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len()); + buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes()); + buf.extend_from_slice(&msg_length.to_be_bytes()); + buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes()); + buf.extend_from_slice(&self.transaction_id); + buf.extend_from_slice(&attr_bytes); + + buf + } + + /// Find the first attribute of a given type in the message. + pub fn get_attribute(&self, predicate: impl Fn(&StunAttribute) -> bool) -> Option<&StunAttribute> { + self.attributes.iter().find(|a| predicate(a)) + } + + /// Get the USERNAME attribute value, if present. + pub fn get_username(&self) -> Option<&str> { + for attr in &self.attributes { + if let StunAttribute::Username(ref u) = attr { + return Some(u.as_str()); + } + } + None + } + + /// Get the REALM attribute value, if present. + pub fn get_realm(&self) -> Option<&str> { + for attr in &self.attributes { + if let StunAttribute::Realm(ref r) = attr { + return Some(r.as_str()); + } + } + None + } + + /// Get the NONCE attribute value, if present. + pub fn get_nonce(&self) -> Option<&str> { + for attr in &self.attributes { + if let StunAttribute::Nonce(ref n) = attr { + return Some(n.as_str()); + } + } + None + } + + /// Get the MESSAGE-INTEGRITY attribute value, if present. + pub fn get_message_integrity(&self) -> Option<&[u8; 20]> { + for attr in &self.attributes { + if let StunAttribute::MessageIntegrity(ref mi) = attr { + return Some(mi); + } + } + None + } +} + +// --------------------------------------------------------------------------- +// ChannelData +// --------------------------------------------------------------------------- + +/// ChannelData message for TURN channel bindings (RFC 5766 §11.4). +/// +/// ChannelData uses a compact 4-byte header instead of the STUN format: +/// Bytes 0-1: Channel Number (0x4000-0x7FFF) +/// Bytes 2-3: Data Length +/// Bytes 4+: Application Data (padded to 4 bytes over UDP) +#[derive(Debug, Clone)] +pub struct ChannelData { + pub channel_number: u16, + pub data: Vec, +} + +impl ChannelData { + /// Decode a ChannelData message from raw bytes. + pub fn decode(bytes: &[u8]) -> Result { + if bytes.len() < 4 { + return Err(TurnError::StunParseError( + "ChannelData message too short".into(), + )); + } + + let channel_number = u16::from_be_bytes([bytes[0], bytes[1]]); + let data_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize; + + // Channel numbers must be in range 0x4000-0x7FFF + if !(0x4000..=0x7FFF).contains(&channel_number) { + return Err(TurnError::StunParseError(format!( + "invalid channel number: 0x{:04x}, must be in 0x4000-0x7FFF", + channel_number + ))); + } + + if bytes.len() < 4 + data_length { + return Err(TurnError::StunParseError(format!( + "ChannelData truncated: need {} data bytes, have {}", + data_length, + bytes.len() - 4 + ))); + } + + let data = bytes[4..4 + data_length].to_vec(); + + Ok(ChannelData { + channel_number, + data, + }) + } + + /// Encode this ChannelData message into wire format. + /// + /// The data portion is padded to a 4-byte boundary (for UDP transport). + pub fn encode(&self) -> Vec { + let data_len = self.data.len(); + let padded_len = (data_len + 3) & !3; + let mut buf = Vec::with_capacity(4 + padded_len); + + buf.extend_from_slice(&self.channel_number.to_be_bytes()); + buf.extend_from_slice(&(data_len as u16).to_be_bytes()); + buf.extend_from_slice(&self.data); + + // Pad to 4-byte boundary with zeros + let padding = padded_len - data_len; + for _ in 0..padding { + buf.push(0); + } + + buf + } +} + +// --------------------------------------------------------------------------- +// Detection helpers +// --------------------------------------------------------------------------- + +/// Returns `true` if the buffer begins with a ChannelData message +/// (first two bits are NOT `00`). +/// +/// STUN messages always have the first two bits as `00` (message type field). +/// ChannelData starts with channel numbers 0x4000-0x7FFF, which have the +/// first two bits as `01`. +pub fn is_channel_data(bytes: &[u8]) -> bool { + if bytes.is_empty() { + return false; + } + // First two bits: 00 = STUN, 01 = ChannelData + (bytes[0] & 0xC0) != 0x00 +} + +/// Returns `true` if the buffer looks like a valid STUN message +/// (first two bits are `00` and magic cookie is present). +pub fn is_stun_message(bytes: &[u8]) -> bool { + if bytes.len() < STUN_HEADER_SIZE { + return false; + } + // First two bits must be 00 + if (bytes[0] & 0xC0) != 0x00 { + return false; + } + // Check magic cookie + let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + cookie == MAGIC_COOKIE +} + +/// Compute the CRC32 fingerprint value for a STUN message. +/// +/// The fingerprint is `CRC32(message_bytes) XOR 0x5354554e` per RFC 5389 §15.5. +/// `message_bytes` should be the entire message up to (but not including) +/// the FINGERPRINT attribute, with the length field adjusted to include it. +pub fn compute_fingerprint(message_bytes: &[u8]) -> u32 { + let crc = crc32_compute(message_bytes); + crc ^ FINGERPRINT_XOR +} + +/// Simple CRC32 (ISO 3309 / ITU-T V.42) implementation. +/// +/// This uses the standard polynomial 0xEDB88320 (reflected form). +/// In production, the `crc32fast` crate should be used instead for +/// SIMD-accelerated performance. +fn crc32_compute(data: &[u8]) -> u32 { + let mut crc: u32 = 0xFFFFFFFF; + for &byte in data { + crc ^= byte as u32; + for _ in 0..8 { + if crc & 1 != 0 { + crc = (crc >> 1) ^ 0xEDB88320; + } else { + crc >>= 1; + } + } + } + !crc +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_message_type_encoding_binding_request() { + let mt = MessageType::new(Method::Binding, Class::Request); + let encoded = mt.to_u16(); + assert_eq!(encoded, 0x0001); + let decoded = MessageType::from_u16(encoded).unwrap(); + assert_eq!(decoded.method, Method::Binding); + assert_eq!(decoded.class, Class::Request); + } + + #[test] + fn test_message_type_encoding_binding_success() { + let mt = MessageType::new(Method::Binding, Class::SuccessResponse); + let encoded = mt.to_u16(); + assert_eq!(encoded, 0x0101); + let decoded = MessageType::from_u16(encoded).unwrap(); + assert_eq!(decoded.method, Method::Binding); + assert_eq!(decoded.class, Class::SuccessResponse); + } + + #[test] + fn test_message_type_encoding_allocate_request() { + let mt = MessageType::new(Method::Allocate, Class::Request); + let encoded = mt.to_u16(); + assert_eq!(encoded, 0x0003); + let decoded = MessageType::from_u16(encoded).unwrap(); + assert_eq!(decoded.method, Method::Allocate); + assert_eq!(decoded.class, Class::Request); + } + + #[test] + fn test_message_type_encoding_allocate_error() { + let mt = MessageType::new(Method::Allocate, Class::ErrorResponse); + let encoded = mt.to_u16(); + assert_eq!(encoded, 0x0113); + let decoded = MessageType::from_u16(encoded).unwrap(); + assert_eq!(decoded.method, Method::Allocate); + assert_eq!(decoded.class, Class::ErrorResponse); + } + + #[test] + fn test_message_type_roundtrip_all_methods() { + let methods = [ + Method::Binding, + Method::Allocate, + Method::Refresh, + Method::Send, + Method::Data, + Method::CreatePermission, + Method::ChannelBind, + ]; + let classes = [ + Class::Request, + Class::Indication, + Class::SuccessResponse, + Class::ErrorResponse, + ]; + + for method in &methods { + for class in &classes { + let mt = MessageType::new(*method, *class); + let encoded = mt.to_u16(); + let decoded = MessageType::from_u16(encoded).unwrap(); + assert_eq!(decoded.method, *method, "method mismatch for {:?}/{:?}", method, class); + assert_eq!(decoded.class, *class, "class mismatch for {:?}/{:?}", method, class); + } + } + } + + #[test] + fn test_is_channel_data() { + // STUN message: first byte has top 2 bits = 00 + assert!(!is_channel_data(&[0x00, 0x01])); + // ChannelData: channel 0x4000 -> first byte = 0x40 -> top 2 bits = 01 + assert!(is_channel_data(&[0x40, 0x00])); + // Empty + assert!(!is_channel_data(&[])); + } + + #[test] + fn test_is_stun_message() { + // Valid STUN header: type=0x0001, len=0, cookie=0x2112A442, txn_id=zeros + let mut buf = [0u8; 20]; + buf[0] = 0x00; + buf[1] = 0x01; + // length = 0 + buf[4] = 0x21; + buf[5] = 0x12; + buf[6] = 0xA4; + buf[7] = 0x42; + assert!(is_stun_message(&buf)); + + // Wrong cookie + buf[4] = 0x00; + assert!(!is_stun_message(&buf)); + } + + #[test] + fn test_channel_data_roundtrip() { + let cd = ChannelData { + channel_number: 0x4001, + data: vec![1, 2, 3, 4, 5], + }; + let encoded = cd.encode(); + let decoded = ChannelData::decode(&encoded).unwrap(); + assert_eq!(decoded.channel_number, 0x4001); + assert_eq!(decoded.data, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_channel_data_padding() { + let cd = ChannelData { + channel_number: 0x4001, + data: vec![1, 2, 3], // 3 bytes -> padded to 4 + }; + let encoded = cd.encode(); + // 4 header + 4 padded data = 8 bytes + assert_eq!(encoded.len(), 8); + assert_eq!(encoded[7], 0); // padding byte + } + + #[test] + fn test_stun_message_encode_decode_empty() { + let msg = StunMessage { + msg_type: MessageType::new(Method::Binding, Class::Request), + transaction_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + attributes: vec![], + }; + let encoded = msg.encode(); + assert_eq!(encoded.len(), STUN_HEADER_SIZE); + + let decoded = StunMessage::decode(&encoded).unwrap(); + assert_eq!(decoded.msg_type.method, Method::Binding); + assert_eq!(decoded.msg_type.class, Class::Request); + assert_eq!(decoded.transaction_id, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); + assert!(decoded.attributes.is_empty()); + } + + #[test] + fn test_crc32_known_value() { + // CRC32 of "123456789" is 0xCBF43926 + let data = b"123456789"; + let crc = crc32_compute(data); + assert_eq!(crc, 0xCBF43926); + } + + #[test] + fn test_decode_too_short() { + let result = StunMessage::decode(&[0; 10]); + assert!(result.is_err()); + } + + #[test] + fn test_decode_bad_cookie() { + let mut buf = [0u8; 20]; + buf[0] = 0x00; + buf[1] = 0x01; + // Bad cookie + buf[4] = 0xFF; + let result = StunMessage::decode(&buf); + assert!(result.is_err()); + } +} diff --git a/src/turn/tcp_listener.rs b/src/turn/tcp_listener.rs new file mode 100644 index 0000000..9e40ba5 --- /dev/null +++ b/src/turn/tcp_listener.rs @@ -0,0 +1,329 @@ +// TCP TURN listener +// +// Accepts TCP connections from TURN clients and frames STUN/ChannelData +// messages over the stream. Per RFC 5766 §2.1, when TURN is used over +// TCP, the client connects to the TURN server via TCP but the relay +// still uses UDP to communicate with peers. +// +// TCP framing: STUN messages and ChannelData are self-delimiting on +// a TCP stream. The reader peeks at the first two bytes to determine +// the message type: +// - If the first two bits are 00 → STUN message. Read 20-byte header, +// extract message length from bytes 2-3, then read the attribute body. +// - Otherwise → ChannelData. The first 2 bytes are the channel number, +// next 2 bytes are the data length, then read the data (padded to 4 bytes). + +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +use super::allocation::{AllocationManager, TransportProtocol}; +use super::handler::{HandleResult, MessageContext, TurnHandler}; +use super::stun::{is_channel_data, ChannelData, StunMessage}; + +// --------------------------------------------------------------------------- +// TcpTurnListener +// --------------------------------------------------------------------------- + +/// TCP listener for the TURN server. +/// +/// Accepts TCP connections from clients and processes STUN/TURN messages +/// framed on the TCP stream. Each accepted connection is handled in a +/// separate spawned task. +pub struct TcpTurnListener { + /// The bound TCP listener. + listener: TcpListener, + /// The shared TURN message handler. + handler: Arc, + /// The shared allocation manager (for relay receiver spawning). + allocations: Arc, + /// The server's listen address. + server_addr: SocketAddr, + /// The server's public IP (for MessageContext). + server_public_ip: std::net::IpAddr, + /// The primary UDP socket (for spawning relay receiver tasks). + /// TCP clients still use UDP relay sockets for the peer-facing side. + udp_socket: Option>, +} + +impl TcpTurnListener { + /// Bind a TCP listener on the given address. + pub async fn bind( + addr: SocketAddr, + handler: Arc, + allocations: Arc, + server_public_ip: std::net::IpAddr, + ) -> std::io::Result { + let listener = TcpListener::bind(addr).await?; + let server_addr = listener.local_addr()?; + Ok(Self { + listener, + handler, + allocations, + server_addr, + server_public_ip, + udp_socket: None, + }) + } + + /// Set the primary UDP socket (used for relay receiver tasks spawned + /// from TCP-originated allocations). + pub fn set_udp_socket(&mut self, socket: Arc) { + self.udp_socket = Some(socket); + } + + /// Run the TCP listener loop. + /// + /// Accepts connections in a loop and spawns a handler task for each one. + pub async fn run(self) { + let handler = self.handler; + let server_addr = self.server_addr; + let server_public_ip = self.server_public_ip; + let allocations = self.allocations; + let udp_socket = self.udp_socket; + + loop { + match self.listener.accept().await { + Ok((stream, client_addr)) => { + let handler = Arc::clone(&handler); + let allocations = Arc::clone(&allocations); + let udp_socket = udp_socket.clone(); + tokio::spawn(async move { + if let Err(e) = handle_tcp_client( + stream, + client_addr, + server_addr, + server_public_ip, + handler, + allocations, + udp_socket, + ) + .await + { + log::debug!("[TURN-TCP] client {} disconnected: {}", client_addr, e); + } + }); + } + Err(e) => { + log::error!("[TURN-TCP] accept error: {}", e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// Per-client TCP handler +// --------------------------------------------------------------------------- + +/// Handle a single TCP client connection. +/// +/// Reads STUN/ChannelData messages from the TCP stream in a loop, dispatches +/// them to the handler, and writes responses back. +async fn handle_tcp_client( + stream: TcpStream, + client_addr: SocketAddr, + server_addr: SocketAddr, + server_public_ip: std::net::IpAddr, + handler: Arc, + allocations: Arc, + udp_socket: Option>, +) -> Result<(), Box> { + let ctx = MessageContext { + client_addr, + server_addr, + protocol: TransportProtocol::Tcp, + server_public_ip, + }; + + let (mut reader, mut writer) = stream.into_split(); + let mut header_buf = [0u8; 4]; + + loop { + // Read the first 2 bytes to determine message type + match reader.read_exact(&mut header_buf[..2]).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + // Clean disconnect + return Ok(()); + } + Err(e) => return Err(e.into()), + } + + if is_channel_data(&header_buf[..2]) { + // ChannelData: first 2 bytes = channel number, next 2 = data length + reader.read_exact(&mut header_buf[2..4]).await?; + let data_len = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize; + + // TCP pads ChannelData to 4-byte boundary + let padded_len = (data_len + 3) & !3; + let mut data_buf = vec![0u8; padded_len]; + reader.read_exact(&mut data_buf[..padded_len]).await?; + + // Reconstruct the full ChannelData for decoding (header + unpadded data) + let mut full_msg = Vec::with_capacity(4 + data_len); + full_msg.extend_from_slice(&header_buf); + full_msg.extend_from_slice(&data_buf[..data_len]); + + match ChannelData::decode(&full_msg) { + Ok(channel_data) => { + if let Some(result) = + handler.handle_channel_data(&channel_data, &ctx).await + { + execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket) + .await; + } + } + Err(e) => { + log::debug!( + "[TURN-TCP] failed to parse ChannelData from {}: {}", + client_addr, + e + ); + } + } + } else { + // STUN message: first 2 bytes are message type, next 2 are message length + // We need to read the remaining 18 bytes of the 20-byte header + let mut stun_header = [0u8; 20]; + stun_header[0] = header_buf[0]; + stun_header[1] = header_buf[1]; + reader.read_exact(&mut stun_header[2..20]).await?; + + let msg_len = u16::from_be_bytes([stun_header[2], stun_header[3]]) as usize; + + // Read the attribute body + let mut msg_buf = Vec::with_capacity(20 + msg_len); + msg_buf.extend_from_slice(&stun_header); + if msg_len > 0 { + let mut body = vec![0u8; msg_len]; + reader.read_exact(&mut body).await?; + msg_buf.extend_from_slice(&body); + } + + match StunMessage::decode(&msg_buf) { + Ok(msg) => { + let results = handler.handle_message(&msg, &ctx).await; + for result in results { + execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket) + .await; + } + } + Err(e) => { + log::debug!( + "[TURN-TCP] failed to parse STUN message from {}: {}", + client_addr, + e + ); + } + } + } + } +} + +/// Execute a single [`HandleResult`] for a TCP client. +/// +/// Responses are written directly to the TCP stream. Relay operations +/// use the UDP relay socket (relay is always UDP, even for TCP clients). +async fn execute_tcp_result( + writer: &mut tokio::net::tcp::OwnedWriteHalf, + result: HandleResult, + client_addr: SocketAddr, + allocations: &Arc, + udp_socket: &Option>, +) { + match result { + HandleResult::Response(data) => { + if let Err(e) = writer.write_all(&data).await { + log::error!( + "[TURN-TCP] failed to write response to {}: {}", + client_addr, + e + ); + } + } + HandleResult::RelayToPeer { + peer_addr, + data, + relay_socket, + } => { + // Relay is always UDP even for TCP clients + if let Err(e) = relay_socket.send_to(&data, peer_addr).await { + log::error!( + "[TURN-TCP] failed to relay to peer {}: {}", + peer_addr, + e + ); + } + } + HandleResult::ChannelDataToPeer { + peer_addr, + data, + relay_socket, + } => { + if let Err(e) = relay_socket.send_to(&data, peer_addr).await { + log::error!( + "[TURN-TCP] failed to send channel data to peer {}: {}", + peer_addr, + e + ); + } + } + HandleResult::AllocationCreated { + response, + relay_socket, + relay_addr, + five_tuple, + } => { + // Send the success response to the TCP client + if let Err(e) = writer.write_all(&response).await { + log::error!( + "[TURN-TCP] failed to write allocate response to {}: {}", + client_addr, + e + ); + } + + // For TCP clients, we still need a relay receiver task to handle + // peer → relay socket → client. However, for TCP the data needs to + // go back over the TCP stream. For simplicity, if a UDP socket is + // available we use the UDP-based relay receiver. In a full + // implementation, the relay receiver would write to the TCP stream. + // + // NOTE: In the current architecture, TCP allocations still relay + // data via UDP. Peer data arriving on the relay socket will be + // forwarded to the client's address via the main UDP socket (if + // the client also has a UDP path). For pure TCP clients, a more + // sophisticated approach would be needed. + if let Some(main_udp) = udp_socket { + let main_socket = Arc::clone(main_udp); + let allocs = Arc::clone(allocations); + let ft = five_tuple.clone(); + log::debug!( + "[TURN-TCP] spawning relay receiver for {} (relay {})", + client_addr, + relay_addr + ); + tokio::spawn(async move { + super::udp_listener::relay_receiver_task( + relay_socket, + relay_addr, + ft, + main_socket, + allocs, + ) + .await; + }); + } else { + log::warn!( + "[TURN-TCP] no UDP socket available for relay receiver (allocation {})", + relay_addr + ); + } + } + HandleResult::None => {} + } +} diff --git a/src/turn/udp_listener.rs b/src/turn/udp_listener.rs new file mode 100644 index 0000000..ad94794 --- /dev/null +++ b/src/turn/udp_listener.rs @@ -0,0 +1,339 @@ +// UDP TURN listener +// +// Listens on a UDP socket for incoming STUN/TURN messages and ChannelData. +// Dispatches to the TurnHandler for processing and executes the resulting +// I/O actions (send responses, relay data to peers). +// +// Also manages relay receiver tasks — when a new allocation is created, +// a task is spawned to read data arriving on the allocation's relay socket +// from peers. That data is wrapped as either ChannelData (if a channel +// binding exists) or a STUN Data indication, and sent back to the client +// via the main UDP socket. + +use std::net::SocketAddr; +use std::sync::Arc; + +use tokio::net::UdpSocket; + +use super::allocation::{AllocationManager, FiveTuple, TransportProtocol}; +use super::attributes::StunAttribute; +use super::handler::{HandleResult, MessageContext, TurnHandler}; +use super::stun::{ + is_channel_data, is_stun_message, ChannelData, Class, MessageType, Method, StunMessage, + compute_fingerprint, +}; + +// --------------------------------------------------------------------------- +// UdpTurnListener +// --------------------------------------------------------------------------- + +/// UDP listener for the TURN server. +/// +/// Reads datagrams from the primary UDP socket, dispatches them to the +/// [`TurnHandler`], and executes I/O actions. Also spawns relay receiver +/// tasks for each new allocation. +pub struct UdpTurnListener { + /// The primary UDP socket bound to the TURN server port (e.g., 3478). + socket: Arc, + /// The shared TURN message handler. + handler: Arc, + /// The allocation manager (needed for relay receiver lookups). + allocations: Arc, + /// The server's listen address. + server_addr: SocketAddr, + /// The server's public IP (for MessageContext). + server_public_ip: std::net::IpAddr, +} + +impl UdpTurnListener { + /// Create a new UDP TURN listener. + /// + /// - `socket`: The primary UDP socket (already bound to the listen address). + /// - `handler`: The shared TURN message handler. + /// - `allocations`: The shared allocation manager. + /// - `server_addr`: The server's listen address. + /// - `server_public_ip`: The server's public IP for relay addresses. + pub fn new( + socket: Arc, + handler: Arc, + allocations: Arc, + server_addr: SocketAddr, + server_public_ip: std::net::IpAddr, + ) -> Self { + Self { + socket, + handler, + allocations, + server_addr, + server_public_ip, + } + } + + /// Run the UDP listener loop. + /// + /// This reads datagrams from the primary UDP socket in a loop. Each + /// incoming message is dispatched to a spawned task for processing, + /// so the receive loop is never blocked by handler logic. + /// + /// When the handler returns an [`HandleResult::AllocationCreated`], this + /// also spawns a relay receiver task for the new allocation's relay socket. + pub async fn run(self: Arc) { + let mut buf = vec![0u8; 65536]; // 64KB max UDP datagram + + loop { + match self.socket.recv_from(&mut buf).await { + Ok((len, client_addr)) => { + let data = buf[..len].to_vec(); + let this = Arc::clone(&self); + // Spawn a task per message to avoid blocking the receive loop + tokio::spawn(async move { + this.process_message(&data, client_addr).await; + }); + } + Err(e) => { + log::error!("[TURN-UDP] recv error: {}", e); + // Brief sleep to avoid tight error loop + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + } + + /// Process a single incoming UDP message. + /// + /// Determines whether the message is ChannelData or a STUN message, + /// dispatches to the handler, and executes the resulting actions. + async fn process_message(&self, data: &[u8], client_addr: SocketAddr) { + let ctx = MessageContext { + client_addr, + server_addr: self.server_addr, + protocol: TransportProtocol::Udp, + server_public_ip: self.server_public_ip, + }; + + if is_channel_data(data) { + // Parse ChannelData and handle + match ChannelData::decode(data) { + Ok(channel_data) => { + if let Some(result) = self.handler.handle_channel_data(&channel_data, &ctx).await + { + self.execute_result(result, client_addr).await; + } + } + Err(e) => { + log::debug!( + "[TURN-UDP] failed to parse ChannelData from {}: {}", + client_addr, + e + ); + } + } + } else if is_stun_message(data) { + // Parse STUN message and handle + match StunMessage::decode(data) { + Ok(msg) => { + let results = self.handler.handle_message(&msg, &ctx).await; + for result in results { + self.execute_result(result, client_addr).await; + } + } + Err(e) => { + log::debug!( + "[TURN-UDP] failed to parse STUN message from {}: {}", + client_addr, + e + ); + } + } + } + // else: unknown message format, silently ignored + } + + /// Execute a single [`HandleResult`] action. + /// + /// Sends responses to clients, relays data to peers, and spawns relay + /// receiver tasks for new allocations. + async fn execute_result(&self, result: HandleResult, client_addr: SocketAddr) { + match result { + HandleResult::Response(data) => { + if let Err(e) = self.socket.send_to(&data, client_addr).await { + log::error!( + "[TURN-UDP] failed to send response to {}: {}", + client_addr, + e + ); + } + } + HandleResult::RelayToPeer { + peer_addr, + data, + relay_socket, + } => { + if let Err(e) = relay_socket.send_to(&data, peer_addr).await { + log::error!( + "[TURN-UDP] failed to relay to peer {}: {}", + peer_addr, + e + ); + } + } + HandleResult::ChannelDataToPeer { + peer_addr, + data, + relay_socket, + } => { + if let Err(e) = relay_socket.send_to(&data, peer_addr).await { + log::error!( + "[TURN-UDP] failed to send channel data to peer {}: {}", + peer_addr, + e + ); + } + } + HandleResult::AllocationCreated { + response, + relay_socket, + relay_addr, + five_tuple, + } => { + // Send the success response to the client + if let Err(e) = self.socket.send_to(&response, client_addr).await { + log::error!( + "[TURN-UDP] failed to send allocate response to {}: {}", + client_addr, + e + ); + } + + // Spawn a relay receiver task for this allocation + let main_socket = Arc::clone(&self.socket); + let allocations = Arc::clone(&self.allocations); + let ft = five_tuple.clone(); + log::debug!( + "[TURN-UDP] spawning relay receiver for {} (relay {})", + client_addr, + relay_addr + ); + tokio::spawn(async move { + relay_receiver_task( + relay_socket, + relay_addr, + ft, + main_socket, + allocations, + ) + .await; + }); + } + HandleResult::None => {} + } + } +} + +// --------------------------------------------------------------------------- +// Relay receiver task +// --------------------------------------------------------------------------- + +/// Relay receiver task for a single allocation. +/// +/// Reads data arriving on the allocation's relay socket from peers and +/// forwards it to the client. The data is wrapped as: +/// - **ChannelData** if there's a channel binding for the peer address +/// - **STUN Data indication** otherwise (with XOR-PEER-ADDRESS and DATA) +/// +/// The task runs until the relay socket encounters an unrecoverable error +/// or the allocation is cleaned up (at which point recv_from will fail +/// because the socket is dropped). +/// +/// This function is `pub` so it can be reused by the TCP listener for +/// TCP-originated allocations that still use UDP relay sockets. +pub async fn relay_receiver_task( + relay_socket: Arc, + relay_addr: SocketAddr, + five_tuple: FiveTuple, + main_socket: Arc, + allocations: Arc, +) { + let mut buf = vec![0u8; 65536]; + let client_addr = five_tuple.client_addr; + + loop { + match relay_socket.recv_from(&mut buf).await { + Ok((len, peer_addr)) => { + let data = &buf[..len]; + + // Check that a permission exists for this peer's IP + if !allocations + .has_permission(&five_tuple, &peer_addr.ip()) + .await + { + log::debug!( + "[TURN-RELAY] dropping packet from {} to relay {}: no permission", + peer_addr, + relay_addr + ); + continue; + } + + // Check if there's a channel binding for this peer + let wrapped = if let Some(channel_number) = allocations + .get_channel_for_peer(&five_tuple, &peer_addr) + .await + { + // Wrap as ChannelData + let cd = ChannelData { + channel_number, + data: data.to_vec(), + }; + cd.encode() + } else { + // Wrap as a STUN Data indication + build_data_indication(peer_addr, data) + }; + + // Send to the client via the main UDP socket + if let Err(e) = main_socket.send_to(&wrapped, client_addr).await { + log::error!( + "[TURN-RELAY] failed to send to client {}: {}", + client_addr, + e + ); + } + } + Err(e) => { + // Socket error — likely the allocation was cleaned up and the + // socket was dropped. Exit the task. + log::debug!( + "[TURN-RELAY] relay socket {} recv error (allocation likely expired): {}", + relay_addr, + e + ); + break; + } + } + } + + log::debug!( + "[TURN-RELAY] relay receiver task for {} exiting", + relay_addr + ); +} + +/// Build a STUN Data indication message (Method::Data, Class::Indication) +/// with XOR-PEER-ADDRESS and DATA attributes. +/// +/// Data indications are used when there's no channel binding for the peer. +/// Per RFC 5766 §10.3, they don't require MESSAGE-INTEGRITY. +fn build_data_indication(peer_addr: SocketAddr, data: &[u8]) -> Vec { + let mut msg = StunMessage::new_random(MessageType::new(Method::Data, Class::Indication)); + + msg.add_attribute(StunAttribute::XorPeerAddress(peer_addr)); + msg.add_attribute(StunAttribute::Data(data.to_vec())); + + // Add FINGERPRINT for demultiplexing + let fp_bytes = msg.encode_for_fingerprint(); + let fingerprint = compute_fingerprint(&fp_bytes); + msg.add_attribute(StunAttribute::Fingerprint(fingerprint)); + + msg.encode() +}