feat(turn): add embedded TURN server with credential service

Integrate a TURN/STUN server into the relay for NAT traversal of
WebRTC connections. Clients request time-limited HMAC-SHA1 credentials
over a new libp2p request-response protocol and then talk to the TURN
server directly via UDP/TCP.

Key changes:
- Add `turn` module with server, credentials, and configuration
- Register `/dusk/turn-credentials/1.0.0` request-response protocol
  so clients can obtain time-limited TURN credentials (24h TTL)
- Expose TURN signaling (3478/udp+tcp) and relay allocation ports
  (49152-65535/udp) in Dockerfile and docker-compose
- Add TURN-related environment variables for public IP, shared secret,
  realm, port ranges, and allocation limits
- Validate directory display_name (1-64 chars) and return typed errors
- Restrict keypair file permissions to 0600 on Unix
This commit is contained in:
cloudwithax 2026-02-24 20:57:05 -05:00
parent b29039557a
commit ea21aa55b6
14 changed files with 6172 additions and 24 deletions

View File

@ -48,9 +48,16 @@ USER dusk
# set working directory
WORKDIR /data
# expose the default relay port
# expose the default relay port (libp2p)
EXPOSE 4001
# expose TURN server ports (UDP + TCP signaling)
EXPOSE 3478/udp
EXPOSE 3478/tcp
# expose TURN relay allocation port range (UDP)
EXPOSE 49152-65535/udp
# persist keypair and data to the volume-mounted /data directory
# XDG_DATA_HOME tells the directories crate to resolve paths under /data
# so the keypair ends up at /data/dusk-relay/keypair instead of ~/.local/share
@ -61,6 +68,19 @@ VOLUME /data
ENV RUST_LOG=info
ENV DUSK_RELAY_PORT=4001
# TURN server environment variables
ENV DUSK_TURN_ENABLED=true
ENV DUSK_TURN_PUBLIC_IP=""
ENV DUSK_TURN_SECRET=""
ENV DUSK_TURN_UDP_PORT=3478
ENV DUSK_TURN_TCP_PORT=3478
ENV DUSK_TURN_REALM=duskchat.app
ENV DUSK_TURN_PORT_RANGE_START=49152
ENV DUSK_TURN_PORT_RANGE_END=65535
ENV DUSK_TURN_MAX_ALLOCATIONS=1000
ENV DUSK_TURN_MAX_PER_USER=10
ENV DUSK_TURN_PUBLIC_HOST=""
# health check to verify the relay is listening
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD timeout 5 bash -c 'cat < /dev/null > /dev/tcp/0.0.0.0/${DUSK_RELAY_PORT:-4001}' || exit 1

View File

@ -8,10 +8,32 @@ services:
container_name: dusk-relay
restart: unless-stopped
ports:
# libp2p relay port
- "4001:4001"
# TURN server signaling (UDP + TCP)
- "3478:3478/udp"
- "3478:3478/tcp"
# TURN relay allocation ports (UDP)
# NOTE: using a smaller range (49152-50000) for Docker; adjust as needed
- "49152-50000:49152-50000/udp"
environment:
- RUST_LOG=info
# libp2p relay
- DUSK_RELAY_PORT=4001
- DUSK_PEER_RELAYS=
- KLIPY_API_KEY=
# TURN server
- DUSK_TURN_ENABLED=true
- DUSK_TURN_PUBLIC_IP=${DUSK_TURN_PUBLIC_IP:?Set DUSK_TURN_PUBLIC_IP to your server public IP}
- DUSK_TURN_SECRET=${DUSK_TURN_SECRET:?Set DUSK_TURN_SECRET for HMAC credentials}
- DUSK_TURN_UDP_PORT=3478
- DUSK_TURN_TCP_PORT=3478
- DUSK_TURN_REALM=duskchat.app
- DUSK_TURN_PORT_RANGE_START=49152
- DUSK_TURN_PORT_RANGE_END=50000
- DUSK_TURN_MAX_ALLOCATIONS=1000
- DUSK_TURN_MAX_PER_USER=10
- DUSK_TURN_PUBLIC_HOST=${DUSK_TURN_PUBLIC_HOST:-}
volumes:
# persist the relay's keypair so peer id stays stable across restarts
- dusk-relay-data:/data

View File

@ -25,6 +25,8 @@
// t3.large (8GB): 20,000 max connections
// c6i.xlarge: 50,000 max connections (with kernel tuning)
mod turn;
use std::collections::{HashMap, VecDeque};
use std::path::PathBuf;
use std::time::{Duration, Instant};
@ -63,6 +65,8 @@ struct RelayBehaviour {
gif_service: cbor::Behaviour<GifRequest, GifResponse>,
// persistent directory service - clients register/search peer profiles
directory_service: cbor::Behaviour<DirectoryRequest, DirectoryResponse>,
// TURN credential service - clients request time-limited TURN server credentials
turn_credentials: cbor::Behaviour<TurnCredentialRequest, TurnCredentialResponse>,
}
// ---- gif protocol ----
@ -227,6 +231,7 @@ pub enum DirectoryRequest {
pub enum DirectoryResponse {
Ok,
Results(Vec<DirectoryProfileEntry>),
Error(String),
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
@ -237,6 +242,28 @@ pub struct DirectoryProfileEntry {
}
// ---- end directory protocol ----
// ---- TURN credential protocol ----
// clients request time-limited TURN credentials over libp2p request-response.
// the relay generates HMAC-SHA1 credentials that the client then uses when
// talking to the TURN server directly via UDP/TCP (separate from the libp2p swarm).
const TURN_CREDENTIALS_PROTOCOL: StreamProtocol =
StreamProtocol::new("/dusk/turn-credentials/1.0.0");
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TurnCredentialRequest {
pub peer_id: String,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TurnCredentialResponse {
pub username: String,
pub password: String,
pub ttl: u64,
pub uris: Vec<String>, // TURN server URIs like ["turn:relay.duskchat.app:3478", "turn:relay.duskchat.app:3478?transport=tcp"]
}
// ---- end TURN credential protocol ----
// fetch from klipy and normalize into our GifResult format
async fn fetch_klipy(
http: &reqwest::Client,
@ -342,6 +369,18 @@ fn load_or_generate_keypair() -> libp2p::identity::Keypair {
if let Err(e) = std::fs::write(&path, &bytes) {
log::warn!("failed to persist keypair: {}", e);
} else {
// restrict keypair file to owner-only read/write (0600) so other
// users on a shared server cannot read the relay's private key
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Err(e) = std::fs::set_permissions(
&path,
std::fs::Permissions::from_mode(0o600),
) {
log::warn!("failed to set keypair file permissions: {}", e);
}
}
log::info!("saved new keypair to {}", path.display());
}
}
@ -460,12 +499,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.expect("valid gossipsub behaviour");
// read connection limit (max total concurrent connections across all peers)
let max_connections = std::env::var("DUSK_MAX_CONNECTIONS")
.ok()
.and_then(|c| c.parse().ok())
.unwrap_or(10_000);
RelayBehaviour {
relay: relay::Behaviour::new(peer_id, relay::Config::default()),
rendezvous: rendezvous::server::Behaviour::new(
@ -497,6 +530,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
request_response::Config::default()
.with_request_timeout(Duration::from_secs(15)),
),
// TURN credential service - clients request time-limited TURN credentials
turn_credentials: cbor::Behaviour::new(
[(TURN_CREDENTIALS_PROTOCOL, ProtocolSupport::Full)],
request_response::Config::default()
.with_request_timeout(Duration::from_secs(15)),
),
}
})?
.with_swarm_config(|cfg| cfg.with_idle_connection_timeout(Duration::from_secs(300)))
@ -512,6 +551,40 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let canonical_addr = format!("/ip4/0.0.0.0/tcp/{}/p2p/{}", port, local_peer_id);
println!("\n relay address: {}\n", canonical_addr);
// ---- TURN server startup ----
// get the shared secret for credential generation (shared between TURN server and credential protocol)
let turn_shared_secret: Vec<u8> = std::env::var("DUSK_TURN_SECRET")
.unwrap_or_else(|_| {
eprintln!("[TURN] WARNING: DUSK_TURN_SECRET not set, generating random secret");
eprintln!("[TURN] This means credentials won't survive restarts and won't work across multiple relay instances");
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
format!("dusk-turn-random-{}", timestamp)
})
.into_bytes();
// start TURN server if enabled (runs on separate UDP/TCP ports from the libp2p swarm)
let _turn_handle = if turn::TurnServerConfig::is_enabled() {
let turn_config = turn::TurnServerConfig::from_env();
let turn_server = turn::TurnServer::new(turn_config);
match turn_server.run().await {
Ok(handle) => {
println!("[RELAY] TURN server started");
Some(handle)
}
Err(e) => {
eprintln!("[RELAY] Failed to start TURN server: {}", e);
None
}
}
} else {
println!("[RELAY] TURN server disabled");
None
};
// ---- end TURN server startup ----
// subscribe to the relay federation gossip topic
let federation_topic = gossipsub::IdentTopic::new("dusk/relay/federation");
swarm
@ -896,23 +969,36 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)) => {
let response = match request {
DirectoryRequest::Register { display_name } => {
let ts = now_secs();
let result = dir_db.execute(
"INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at)
VALUES (?1, ?2, ?3, ?3)
ON CONFLICT(peer_id) DO UPDATE SET
display_name = excluded.display_name,
last_seen = excluded.last_seen",
params![peer.to_string(), display_name, ts as i64],
);
match result {
Ok(_) => {
log::info!("directory: registered peer {} as '{}'", peer, display_name);
DirectoryResponse::Ok
}
Err(e) => {
log::warn!("directory: failed to register {}: {}", peer, e);
DirectoryResponse::Ok
// validate display_name: reject empty or excessively long names
let trimmed_name = display_name.trim();
if trimmed_name.is_empty() || trimmed_name.len() > 64 {
log::warn!(
"directory: rejected registration from {} (name length {})",
peer,
trimmed_name.len()
);
DirectoryResponse::Error(
"display_name must be 1-64 characters".to_string(),
)
} else {
let ts = now_secs();
let result = dir_db.execute(
"INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at)
VALUES (?1, ?2, ?3, ?3)
ON CONFLICT(peer_id) DO UPDATE SET
display_name = excluded.display_name,
last_seen = excluded.last_seen",
params![peer.to_string(), trimmed_name, ts as i64],
);
match result {
Ok(_) => {
log::info!("directory: registered peer {} as '{}'", peer, trimmed_name);
DirectoryResponse::Ok
}
Err(e) => {
log::warn!("directory: failed to register {}: {}", peer, e);
DirectoryResponse::Error(format!("registration failed: {}", e))
}
}
}
}
@ -974,6 +1060,85 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// ignore outbound and other directory service events
SwarmEvent::Behaviour(RelayBehaviourEvent::DirectoryService(_)) => {}
// ---- TURN credential service ----
// clients request time-limited credentials for the TURN server
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
request_response::Event::Message {
peer,
message:
request_response::Message::Request {
request, channel, ..
},
..
},
)) => {
let peer_id_str = request.peer_id.clone();
// generate time-limited credentials using the shared secret
let (username, password) = turn::credentials::generate_credentials(
&peer_id_str,
&turn_shared_secret,
86400, // 24 hours
);
// build TURN URIs using the relay's public hostname/IP
let turn_host = std::env::var("DUSK_TURN_PUBLIC_HOST")
.unwrap_or_else(|_| {
std::env::var("DUSK_TURN_PUBLIC_IP")
.unwrap_or_else(|_| "relay.duskchat.app".to_string())
});
let turn_port: u16 = std::env::var("DUSK_TURN_UDP_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(3478);
let response = TurnCredentialResponse {
username,
password,
ttl: 86400,
uris: vec![
format!("turn:{}:{}", turn_host, turn_port),
format!("turn:{}:{}?transport=tcp", turn_host, turn_port),
],
};
if swarm
.behaviour_mut()
.turn_credentials
.send_response(channel, response)
.is_err()
{
log::warn!("[TURN-CREDS] failed to send credentials to {:?}", peer);
}
log::info!(
"[TURN-CREDS] generated credentials for peer {} (requested by {:?})",
peer_id_str,
peer
);
}
// we're the server so we don't send requests, but handle all variants
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
request_response::Event::Message {
message: request_response::Message::Response { .. },
..
},
)) => {}
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
request_response::Event::OutboundFailure { .. },
)) => {}
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
request_response::Event::InboundFailure { peer, error, .. },
)) => {
log::warn!(
"[TURN-CREDS] inbound failure from {:?}: {:?}",
peer,
error
);
}
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
request_response::Event::ResponseSent { .. },
)) => {}
_ => {}
}
}

1007
src/turn/allocation.rs Normal file

File diff suppressed because it is too large Load Diff

799
src/turn/attributes.rs Normal file
View File

@ -0,0 +1,799 @@
// STUN/TURN attribute types and TLV encoding/decoding
//
// Implements attributes per RFC 5389 (STUN) and RFC 5766 (TURN).
// Each attribute is a Type-Length-Value (TLV) structure:
// Type: 2 bytes (attribute type code)
// Length: 2 bytes (value length, excluding padding)
// Value: variable (padded to 4-byte boundary with zeros)
//
// XOR-encoded address attributes (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS,
// XOR-RELAYED-ADDRESS) use the STUN magic cookie and transaction ID
// to XOR the address bytes, preventing NAT ALGs from rewriting them.
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::turn::error::TurnError;
use crate::turn::stun::MAGIC_COOKIE;
// ---------------------------------------------------------------------------
// Attribute type codes
// ---------------------------------------------------------------------------
/// RFC 5389 STUN attribute type codes
pub const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
pub const ATTR_USERNAME: u16 = 0x0006;
pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008;
pub const ATTR_ERROR_CODE: u16 = 0x0009;
pub const ATTR_UNKNOWN_ATTRIBUTES: u16 = 0x000A;
pub const ATTR_REALM: u16 = 0x0014;
pub const ATTR_NONCE: u16 = 0x0015;
pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
pub const ATTR_SOFTWARE: u16 = 0x8022;
pub const ATTR_FINGERPRINT: u16 = 0x8028;
/// RFC 5766 TURN attribute type codes
pub const ATTR_CHANNEL_NUMBER: u16 = 0x000C;
pub const ATTR_LIFETIME: u16 = 0x000D;
pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012;
pub const ATTR_DATA: u16 = 0x0013;
pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
pub const ATTR_REQUESTED_ADDRESS_FAMILY: u16 = 0x0017;
pub const ATTR_EVEN_PORT: u16 = 0x0018;
pub const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019;
pub const ATTR_DONT_FRAGMENT: u16 = 0x001A;
// ---------------------------------------------------------------------------
// Address family constants
// ---------------------------------------------------------------------------
const ADDR_FAMILY_IPV4: u8 = 0x01;
const ADDR_FAMILY_IPV6: u8 = 0x02;
// ---------------------------------------------------------------------------
// StunAttribute enum
// ---------------------------------------------------------------------------
/// A parsed STUN or TURN attribute.
///
/// Each variant corresponds to a specific attribute type. Unknown attributes
/// are preserved as raw bytes for forwarding or diagnostic purposes.
#[derive(Debug, Clone)]
pub enum StunAttribute {
// ---- RFC 5389 STUN attributes ----
/// MAPPED-ADDRESS (0x0001): The reflexive transport address of the client
/// as seen by the server. Uses plain (non-XOR) encoding.
MappedAddress(SocketAddr),
/// USERNAME (0x0006): The username for message integrity, encoded as UTF-8.
Username(String),
/// MESSAGE-INTEGRITY (0x0008): 20-byte HMAC-SHA1 over the STUN message
/// (up to but not including this attribute).
MessageIntegrity([u8; 20]),
/// ERROR-CODE (0x0009): Error response code and human-readable reason phrase.
/// Code is in range 300-699 per RFC 5389 §15.6.
ErrorCode { code: u16, reason: String },
/// UNKNOWN-ATTRIBUTES (0x000A): List of attribute types that the server
/// did not understand. Used in 420 Unknown Attribute error responses.
UnknownAttributes(Vec<u16>),
/// REALM (0x0014): The authentication realm, encoded as UTF-8.
/// Used with long-term credential mechanism.
Realm(String),
/// NONCE (0x0015): A server-generated nonce for replay protection.
Nonce(String),
/// XOR-MAPPED-ADDRESS (0x0020): Same as MAPPED-ADDRESS but XOR-encoded
/// with the magic cookie (and transaction ID for IPv6) to prevent
/// NAT ALG interference.
XorMappedAddress(SocketAddr),
/// SOFTWARE (0x8022): Textual description of the software being used.
/// Informational only.
Software(String),
/// FINGERPRINT (0x8028): CRC32 of the STUN message XORed with 0x5354554e.
/// Used to demultiplex STUN from other protocols on the same port.
Fingerprint(u32),
// ---- RFC 5766 TURN attributes ----
/// CHANNEL-NUMBER (0x000C): Channel number for ChannelBind.
/// Must be in range 0x4000-0x7FFF.
ChannelNumber(u16),
/// LIFETIME (0x000D): Requested or granted allocation lifetime in seconds.
Lifetime(u32),
/// XOR-PEER-ADDRESS (0x0012): The peer address for Send/Data indications,
/// CreatePermission, and ChannelBind. XOR-encoded like XOR-MAPPED-ADDRESS.
XorPeerAddress(SocketAddr),
/// DATA (0x0013): The application data payload in Send/Data indications.
Data(Vec<u8>),
/// XOR-RELAYED-ADDRESS (0x0016): The relayed transport address allocated
/// by the server. XOR-encoded.
XorRelayedAddress(SocketAddr),
/// EVEN-PORT (0x0018): Requests an even port number for the relay address.
/// The boolean indicates the R bit (reserve next-higher port).
EvenPort(bool),
/// REQUESTED-TRANSPORT (0x0019): The transport protocol for the relay.
/// Value is an IANA protocol number (17 = UDP, 6 = TCP).
RequestedTransport(u8),
/// DONT-FRAGMENT (0x001A): Requests that the server set the DF bit
/// in outgoing UDP packets. No value (zero-length attribute).
DontFragment,
/// REQUESTED-ADDRESS-FAMILY (0x0017): Requests a specific address family
/// for the relayed address (0x01 = IPv4, 0x02 = IPv6).
RequestedAddressFamily(u8),
/// Unknown/unsupported attribute preserved as raw bytes.
Unknown { attr_type: u16, value: Vec<u8> },
}
// ---------------------------------------------------------------------------
// Decoding
// ---------------------------------------------------------------------------
/// Decode a single attribute from its type code and raw value bytes.
///
/// The `transaction_id` is needed for XOR address decoding.
pub fn decode_attribute(
attr_type: u16,
value: &[u8],
transaction_id: &[u8; 12],
) -> Result<StunAttribute, TurnError> {
match attr_type {
ATTR_MAPPED_ADDRESS => decode_mapped_address(value),
ATTR_USERNAME => decode_utf8_string(value).map(StunAttribute::Username),
ATTR_MESSAGE_INTEGRITY => decode_message_integrity(value),
ATTR_ERROR_CODE => decode_error_code(value),
ATTR_UNKNOWN_ATTRIBUTES => decode_unknown_attributes(value),
ATTR_REALM => decode_utf8_string(value).map(StunAttribute::Realm),
ATTR_NONCE => decode_utf8_string(value).map(StunAttribute::Nonce),
ATTR_XOR_MAPPED_ADDRESS => {
decode_xor_address(value, transaction_id).map(StunAttribute::XorMappedAddress)
}
ATTR_SOFTWARE => decode_utf8_string(value).map(StunAttribute::Software),
ATTR_FINGERPRINT => decode_fingerprint(value),
ATTR_CHANNEL_NUMBER => decode_channel_number(value),
ATTR_LIFETIME => decode_lifetime(value),
ATTR_XOR_PEER_ADDRESS => {
decode_xor_address(value, transaction_id).map(StunAttribute::XorPeerAddress)
}
ATTR_DATA => Ok(StunAttribute::Data(value.to_vec())),
ATTR_XOR_RELAYED_ADDRESS => {
decode_xor_address(value, transaction_id).map(StunAttribute::XorRelayedAddress)
}
ATTR_REQUESTED_ADDRESS_FAMILY => decode_requested_address_family(value),
ATTR_EVEN_PORT => decode_even_port(value),
ATTR_REQUESTED_TRANSPORT => decode_requested_transport(value),
ATTR_DONT_FRAGMENT => Ok(StunAttribute::DontFragment),
_ => Ok(StunAttribute::Unknown {
attr_type,
value: value.to_vec(),
}),
}
}
// ---------------------------------------------------------------------------
// Encoding
// ---------------------------------------------------------------------------
/// Encode a single attribute into TLV wire format with 4-byte padding.
///
/// Returns the complete TLV bytes: type (2) + length (2) + value + padding.
pub fn encode_attribute(attr: &StunAttribute, transaction_id: &[u8; 12]) -> Vec<u8> {
let (attr_type, value) = encode_attribute_value(attr, transaction_id);
let value_len = value.len();
let padded_len = (value_len + 3) & !3;
let mut buf = Vec::with_capacity(4 + padded_len);
buf.extend_from_slice(&attr_type.to_be_bytes());
buf.extend_from_slice(&(value_len as u16).to_be_bytes());
buf.extend_from_slice(&value);
// Pad to 4-byte boundary
let padding = padded_len - value_len;
for _ in 0..padding {
buf.push(0);
}
buf
}
/// Encode an attribute's value bytes (without the TLV header or padding).
/// Returns (attribute_type_code, value_bytes).
fn encode_attribute_value(attr: &StunAttribute, transaction_id: &[u8; 12]) -> (u16, Vec<u8>) {
match attr {
StunAttribute::MappedAddress(addr) => {
(ATTR_MAPPED_ADDRESS, encode_plain_address(addr))
}
StunAttribute::Username(s) => (ATTR_USERNAME, s.as_bytes().to_vec()),
StunAttribute::MessageIntegrity(hmac) => (ATTR_MESSAGE_INTEGRITY, hmac.to_vec()),
StunAttribute::ErrorCode { code, reason } => {
(ATTR_ERROR_CODE, encode_error_code(*code, reason))
}
StunAttribute::UnknownAttributes(types) => {
let mut buf = Vec::with_capacity(types.len() * 2);
for &t in types {
buf.extend_from_slice(&t.to_be_bytes());
}
(ATTR_UNKNOWN_ATTRIBUTES, buf)
}
StunAttribute::Realm(s) => (ATTR_REALM, s.as_bytes().to_vec()),
StunAttribute::Nonce(s) => (ATTR_NONCE, s.as_bytes().to_vec()),
StunAttribute::XorMappedAddress(addr) => {
(ATTR_XOR_MAPPED_ADDRESS, encode_xor_address(addr, transaction_id))
}
StunAttribute::Software(s) => (ATTR_SOFTWARE, s.as_bytes().to_vec()),
StunAttribute::Fingerprint(val) => (ATTR_FINGERPRINT, val.to_be_bytes().to_vec()),
StunAttribute::ChannelNumber(num) => {
// Channel number is 16 bits followed by 16 bits of RFFU (reserved)
let mut buf = vec![0u8; 4];
buf[0..2].copy_from_slice(&num.to_be_bytes());
(ATTR_CHANNEL_NUMBER, buf)
}
StunAttribute::Lifetime(secs) => (ATTR_LIFETIME, secs.to_be_bytes().to_vec()),
StunAttribute::XorPeerAddress(addr) => {
(ATTR_XOR_PEER_ADDRESS, encode_xor_address(addr, transaction_id))
}
StunAttribute::Data(data) => (ATTR_DATA, data.clone()),
StunAttribute::XorRelayedAddress(addr) => {
(ATTR_XOR_RELAYED_ADDRESS, encode_xor_address(addr, transaction_id))
}
StunAttribute::RequestedAddressFamily(family) => {
let mut buf = vec![0u8; 4];
buf[0] = *family;
(ATTR_REQUESTED_ADDRESS_FAMILY, buf)
}
StunAttribute::EvenPort(reserve) => {
let byte = if *reserve { 0x80 } else { 0x00 };
(ATTR_EVEN_PORT, vec![byte])
}
StunAttribute::RequestedTransport(proto) => {
// Protocol number in first byte, followed by 3 bytes RFFU
let mut buf = vec![0u8; 4];
buf[0] = *proto;
(ATTR_REQUESTED_TRANSPORT, buf)
}
StunAttribute::DontFragment => (ATTR_DONT_FRAGMENT, vec![]),
StunAttribute::Unknown { attr_type, value } => (*attr_type, value.clone()),
}
}
// ---------------------------------------------------------------------------
// Address encoding/decoding helpers
// ---------------------------------------------------------------------------
/// Decode a plain (non-XOR) MAPPED-ADDRESS attribute value.
///
/// Format: 1 byte reserved, 1 byte family, 2 bytes port, 4/16 bytes address
fn decode_mapped_address(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() < 4 {
return Err(TurnError::StunParseError(
"MAPPED-ADDRESS too short".into(),
));
}
let family = value[1];
let port = u16::from_be_bytes([value[2], value[3]]);
let addr = match family {
ADDR_FAMILY_IPV4 => {
if value.len() < 8 {
return Err(TurnError::StunParseError(
"MAPPED-ADDRESS IPv4 too short".into(),
));
}
let ip = Ipv4Addr::new(value[4], value[5], value[6], value[7]);
SocketAddr::new(IpAddr::V4(ip), port)
}
ADDR_FAMILY_IPV6 => {
if value.len() < 20 {
return Err(TurnError::StunParseError(
"MAPPED-ADDRESS IPv6 too short".into(),
));
}
let mut octets = [0u8; 16];
octets.copy_from_slice(&value[4..20]);
let ip = Ipv6Addr::from(octets);
SocketAddr::new(IpAddr::V6(ip), port)
}
_ => {
return Err(TurnError::StunParseError(format!(
"unknown address family: 0x{:02x}",
family
)));
}
};
Ok(StunAttribute::MappedAddress(addr))
}
/// Encode a plain (non-XOR) address into wire format.
fn encode_plain_address(addr: &SocketAddr) -> Vec<u8> {
match addr {
SocketAddr::V4(v4) => {
let mut buf = vec![0u8; 8];
buf[0] = 0; // reserved
buf[1] = ADDR_FAMILY_IPV4;
buf[2..4].copy_from_slice(&v4.port().to_be_bytes());
buf[4..8].copy_from_slice(&v4.ip().octets());
buf
}
SocketAddr::V6(v6) => {
let mut buf = vec![0u8; 20];
buf[0] = 0; // reserved
buf[1] = ADDR_FAMILY_IPV6;
buf[2..4].copy_from_slice(&v6.port().to_be_bytes());
buf[4..20].copy_from_slice(&v6.ip().octets());
buf
}
}
}
/// Decode an XOR-encoded address (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS,
/// XOR-RELAYED-ADDRESS) per RFC 5389 §15.2.
///
/// For IPv4: port is XORed with top 16 bits of magic cookie;
/// address is XORed with magic cookie.
/// For IPv6: port is XORed with top 16 bits of magic cookie;
/// address is XORed with magic cookie || transaction ID (16 bytes).
fn decode_xor_address(
value: &[u8],
transaction_id: &[u8; 12],
) -> Result<SocketAddr, TurnError> {
if value.len() < 4 {
return Err(TurnError::StunParseError(
"XOR address attribute too short".into(),
));
}
let family = value[1];
let x_port = u16::from_be_bytes([value[2], value[3]]);
let cookie_bytes = MAGIC_COOKIE.to_be_bytes();
let port = x_port ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]);
match family {
ADDR_FAMILY_IPV4 => {
if value.len() < 8 {
return Err(TurnError::StunParseError(
"XOR-MAPPED-ADDRESS IPv4 too short".into(),
));
}
let x_addr = u32::from_be_bytes([value[4], value[5], value[6], value[7]]);
let addr = x_addr ^ MAGIC_COOKIE;
let ip = Ipv4Addr::from(addr);
Ok(SocketAddr::new(IpAddr::V4(ip), port))
}
ADDR_FAMILY_IPV6 => {
if value.len() < 20 {
return Err(TurnError::StunParseError(
"XOR-MAPPED-ADDRESS IPv6 too short".into(),
));
}
// Build the 16-byte XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes)
let mut xor_mask = [0u8; 16];
xor_mask[0..4].copy_from_slice(&cookie_bytes);
xor_mask[4..16].copy_from_slice(transaction_id);
let mut addr_bytes = [0u8; 16];
for i in 0..16 {
addr_bytes[i] = value[4 + i] ^ xor_mask[i];
}
let ip = Ipv6Addr::from(addr_bytes);
Ok(SocketAddr::new(IpAddr::V6(ip), port))
}
_ => Err(TurnError::StunParseError(format!(
"unknown address family in XOR address: 0x{:02x}",
family
))),
}
}
/// Encode an address using XOR encoding per RFC 5389 §15.2.
fn encode_xor_address(addr: &SocketAddr, transaction_id: &[u8; 12]) -> Vec<u8> {
let cookie_bytes = MAGIC_COOKIE.to_be_bytes();
let x_port = addr.port() ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]);
match addr {
SocketAddr::V4(v4) => {
let mut buf = vec![0u8; 8];
buf[0] = 0; // reserved
buf[1] = ADDR_FAMILY_IPV4;
buf[2..4].copy_from_slice(&x_port.to_be_bytes());
let x_addr = u32::from_be_bytes(v4.ip().octets()) ^ MAGIC_COOKIE;
buf[4..8].copy_from_slice(&x_addr.to_be_bytes());
buf
}
SocketAddr::V6(v6) => {
let mut buf = vec![0u8; 20];
buf[0] = 0; // reserved
buf[1] = ADDR_FAMILY_IPV6;
buf[2..4].copy_from_slice(&x_port.to_be_bytes());
// XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes)
let mut xor_mask = [0u8; 16];
xor_mask[0..4].copy_from_slice(&cookie_bytes);
xor_mask[4..16].copy_from_slice(transaction_id);
let octets = v6.ip().octets();
for i in 0..16 {
buf[4 + i] = octets[i] ^ xor_mask[i];
}
buf
}
}
}
// ---------------------------------------------------------------------------
// Specific attribute decoders
// ---------------------------------------------------------------------------
/// Decode UTF-8 string from raw bytes.
fn decode_utf8_string(value: &[u8]) -> Result<String, TurnError> {
String::from_utf8(value.to_vec()).map_err(|e| {
TurnError::StunParseError(format!("invalid UTF-8 in attribute: {}", e))
})
}
/// Decode MESSAGE-INTEGRITY (exactly 20 bytes HMAC-SHA1).
fn decode_message_integrity(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() != 20 {
return Err(TurnError::StunParseError(format!(
"MESSAGE-INTEGRITY must be 20 bytes, got {}",
value.len()
)));
}
let mut hmac = [0u8; 20];
hmac.copy_from_slice(value);
Ok(StunAttribute::MessageIntegrity(hmac))
}
/// Decode ERROR-CODE attribute per RFC 5389 §15.6.
///
/// Format: 2 bytes reserved, 1 byte with class (hundreds digit) in bits 0-2,
/// 1 byte with number (tens+units), followed by UTF-8 reason phrase.
fn decode_error_code(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() < 4 {
return Err(TurnError::StunParseError(
"ERROR-CODE too short".into(),
));
}
let class = (value[2] & 0x07) as u16;
let number = value[3] as u16;
let code = class * 100 + number;
let reason = if value.len() > 4 {
String::from_utf8_lossy(&value[4..]).into_owned()
} else {
String::new()
};
Ok(StunAttribute::ErrorCode { code, reason })
}
/// Encode ERROR-CODE value bytes per RFC 5389 §15.6.
fn encode_error_code(code: u16, reason: &str) -> Vec<u8> {
let class = (code / 100) as u8;
let number = (code % 100) as u8;
let reason_bytes = reason.as_bytes();
let mut buf = Vec::with_capacity(4 + reason_bytes.len());
buf.push(0); // reserved
buf.push(0); // reserved
buf.push(class & 0x07);
buf.push(number);
buf.extend_from_slice(reason_bytes);
buf
}
/// Decode UNKNOWN-ATTRIBUTES (list of 16-bit attribute types).
fn decode_unknown_attributes(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() % 2 != 0 {
return Err(TurnError::StunParseError(
"UNKNOWN-ATTRIBUTES length must be even".into(),
));
}
let mut types = Vec::with_capacity(value.len() / 2);
for chunk in value.chunks_exact(2) {
types.push(u16::from_be_bytes([chunk[0], chunk[1]]));
}
Ok(StunAttribute::UnknownAttributes(types))
}
/// Decode FINGERPRINT (4-byte CRC32 XOR value).
fn decode_fingerprint(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() != 4 {
return Err(TurnError::StunParseError(format!(
"FINGERPRINT must be 4 bytes, got {}",
value.len()
)));
}
let val = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
Ok(StunAttribute::Fingerprint(val))
}
/// Decode CHANNEL-NUMBER (16-bit number + 16-bit RFFU).
fn decode_channel_number(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() < 4 {
return Err(TurnError::StunParseError(
"CHANNEL-NUMBER too short".into(),
));
}
let num = u16::from_be_bytes([value[0], value[1]]);
Ok(StunAttribute::ChannelNumber(num))
}
/// Decode LIFETIME (32-bit seconds).
fn decode_lifetime(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.len() < 4 {
return Err(TurnError::StunParseError("LIFETIME too short".into()));
}
let secs = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
Ok(StunAttribute::Lifetime(secs))
}
/// Decode REQUESTED-ADDRESS-FAMILY (1 byte family + 3 bytes RFFU).
fn decode_requested_address_family(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.is_empty() {
return Err(TurnError::StunParseError(
"REQUESTED-ADDRESS-FAMILY empty".into(),
));
}
Ok(StunAttribute::RequestedAddressFamily(value[0]))
}
/// Decode EVEN-PORT (1 byte, R bit in most significant bit).
fn decode_even_port(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.is_empty() {
return Err(TurnError::StunParseError("EVEN-PORT empty".into()));
}
let reserve = (value[0] & 0x80) != 0;
Ok(StunAttribute::EvenPort(reserve))
}
/// Decode REQUESTED-TRANSPORT (1 byte protocol + 3 bytes RFFU).
fn decode_requested_transport(value: &[u8]) -> Result<StunAttribute, TurnError> {
if value.is_empty() {
return Err(TurnError::StunParseError(
"REQUESTED-TRANSPORT empty".into(),
));
}
Ok(StunAttribute::RequestedTransport(value[0]))
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
#[test]
fn test_xor_mapped_address_ipv4_roundtrip() {
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345));
let txn_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
let encoded = encode_xor_address(&addr, &txn_id);
let decoded = decode_xor_address(&encoded, &txn_id).unwrap();
assert_eq!(decoded, addr);
}
#[test]
fn test_xor_mapped_address_ipv6_roundtrip() {
let ip = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334);
let addr = SocketAddr::V6(SocketAddrV6::new(ip, 54321, 0, 0));
let txn_id = [0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6];
let encoded = encode_xor_address(&addr, &txn_id);
let decoded = decode_xor_address(&encoded, &txn_id).unwrap();
assert_eq!(decoded, addr);
}
#[test]
fn test_plain_address_ipv4_roundtrip() {
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
let encoded = encode_plain_address(&addr);
let decoded = decode_mapped_address(&encoded).unwrap();
if let StunAttribute::MappedAddress(decoded_addr) = decoded {
assert_eq!(decoded_addr, addr);
} else {
panic!("expected MappedAddress");
}
}
#[test]
fn test_error_code_roundtrip() {
let encoded = encode_error_code(401, "Unauthorized");
let decoded = decode_error_code(&encoded).unwrap();
if let StunAttribute::ErrorCode { code, reason } = decoded {
assert_eq!(code, 401);
assert_eq!(reason, "Unauthorized");
} else {
panic!("expected ErrorCode");
}
}
#[test]
fn test_error_code_438() {
let encoded = encode_error_code(438, "Stale Nonce");
let decoded = decode_error_code(&encoded).unwrap();
if let StunAttribute::ErrorCode { code, reason } = decoded {
assert_eq!(code, 438);
assert_eq!(reason, "Stale Nonce");
} else {
panic!("expected ErrorCode");
}
}
#[test]
fn test_attribute_tlv_encoding() {
let txn_id = [0u8; 12];
let attr = StunAttribute::Username("alice".into());
let encoded = encode_attribute(&attr, &txn_id);
// Type (2) + Length (2) + "alice" (5) + padding (3) = 12
assert_eq!(encoded.len(), 12);
assert_eq!(encoded[0..2], ATTR_USERNAME.to_be_bytes());
assert_eq!(encoded[2..4], 5u16.to_be_bytes());
assert_eq!(&encoded[4..9], b"alice");
assert_eq!(encoded[9], 0); // padding
assert_eq!(encoded[10], 0);
assert_eq!(encoded[11], 0);
}
#[test]
fn test_attribute_tlv_encoding_4_byte_aligned() {
let txn_id = [0u8; 12];
let attr = StunAttribute::Username("test".into());
let encoded = encode_attribute(&attr, &txn_id);
// Type (2) + Length (2) + "test" (4) = 8, already aligned
assert_eq!(encoded.len(), 8);
}
#[test]
fn test_lifetime_roundtrip() {
let attr = StunAttribute::Lifetime(600);
let txn_id = [0u8; 12];
let encoded = encode_attribute(&attr, &txn_id);
// Type (2) + Length (2) + value (4) = 8
assert_eq!(encoded.len(), 8);
let decoded = decode_attribute(ATTR_LIFETIME, &encoded[4..8], &txn_id).unwrap();
if let StunAttribute::Lifetime(secs) = decoded {
assert_eq!(secs, 600);
} else {
panic!("expected Lifetime");
}
}
#[test]
fn test_channel_number_roundtrip() {
let attr = StunAttribute::ChannelNumber(0x4000);
let txn_id = [0u8; 12];
let encoded = encode_attribute(&attr, &txn_id);
// Type (2) + Length (2) + value (4) = 8
assert_eq!(encoded.len(), 8);
let decoded = decode_attribute(ATTR_CHANNEL_NUMBER, &encoded[4..8], &txn_id).unwrap();
if let StunAttribute::ChannelNumber(num) = decoded {
assert_eq!(num, 0x4000);
} else {
panic!("expected ChannelNumber");
}
}
#[test]
fn test_requested_transport_udp() {
let attr = StunAttribute::RequestedTransport(17); // UDP
let txn_id = [0u8; 12];
let encoded = encode_attribute(&attr, &txn_id);
let decoded = decode_attribute(ATTR_REQUESTED_TRANSPORT, &encoded[4..8], &txn_id).unwrap();
if let StunAttribute::RequestedTransport(proto) = decoded {
assert_eq!(proto, 17);
} else {
panic!("expected RequestedTransport");
}
}
#[test]
fn test_dont_fragment() {
let attr = StunAttribute::DontFragment;
let txn_id = [0u8; 12];
let encoded = encode_attribute(&attr, &txn_id);
// Type (2) + Length (2) + no value = 4
assert_eq!(encoded.len(), 4);
assert_eq!(encoded[2..4], 0u16.to_be_bytes()); // length = 0
}
#[test]
fn test_unknown_attribute_preserved() {
let txn_id = [0u8; 12];
let decoded =
decode_attribute(0xFFFF, &[0x01, 0x02, 0x03], &txn_id).unwrap();
if let StunAttribute::Unknown { attr_type, value } = decoded {
assert_eq!(attr_type, 0xFFFF);
assert_eq!(value, vec![0x01, 0x02, 0x03]);
} else {
panic!("expected Unknown");
}
}
#[test]
fn test_message_integrity_decode() {
let hmac = [0xAA; 20];
let decoded = decode_message_integrity(&hmac).unwrap();
if let StunAttribute::MessageIntegrity(h) = decoded {
assert_eq!(h, [0xAA; 20]);
} else {
panic!("expected MessageIntegrity");
}
}
#[test]
fn test_message_integrity_wrong_length() {
let result = decode_message_integrity(&[0; 10]);
assert!(result.is_err());
}
#[test]
fn test_fingerprint_roundtrip() {
let val = 0xDEADBEEF_u32;
let encoded_value = val.to_be_bytes().to_vec();
let decoded = decode_fingerprint(&encoded_value).unwrap();
if let StunAttribute::Fingerprint(v) = decoded {
assert_eq!(v, val);
} else {
panic!("expected Fingerprint");
}
}
#[test]
fn test_even_port_reserve_bit() {
// R bit set
let decoded = decode_even_port(&[0x80]).unwrap();
if let StunAttribute::EvenPort(reserve) = decoded {
assert!(reserve);
} else {
panic!("expected EvenPort");
}
// R bit not set
let decoded = decode_even_port(&[0x00]).unwrap();
if let StunAttribute::EvenPort(reserve) = decoded {
assert!(!reserve);
} else {
panic!("expected EvenPort");
}
}
#[test]
fn test_unknown_attributes_list() {
let value = vec![0x00, 0x01, 0x00, 0x20]; // MAPPED-ADDRESS, XOR-MAPPED-ADDRESS
let decoded = decode_unknown_attributes(&value).unwrap();
if let StunAttribute::UnknownAttributes(types) = decoded {
assert_eq!(types, vec![0x0001, 0x0020]);
} else {
panic!("expected UnknownAttributes");
}
}
}

758
src/turn/credentials.rs Normal file
View File

@ -0,0 +1,758 @@
// TURN credential generation and validation
//
// Implements the HMAC-SHA1 time-limited credential mechanism per
// draft-uberti-behave-turn-rest-00, which is the standard used by
// all major WebRTC implementations (Chrome, Firefox, Safari).
//
// Credential format:
// Username: "{expiry_unix_timestamp}:{peer_id}"
// Password: Base64(HMAC-SHA1(shared_secret, username))
//
// Also implements MESSAGE-INTEGRITY computation per RFC 5389 §15.4
// using long-term credentials (key = MD5(username:realm:password)).
//
// Nonce generation uses HMAC-SHA1 with an embedded timestamp for
// stateless stale-nonce detection.
use std::time::{SystemTime, UNIX_EPOCH};
use crate::turn::error::TurnError;
// ---------------------------------------------------------------------------
// HMAC-SHA1 (RFC 2104)
// ---------------------------------------------------------------------------
/// Compute HMAC-SHA1(key, message) returning a 20-byte digest.
///
/// This is a self-contained implementation of HMAC-SHA1 that doesn't
/// depend on external crates. In production, replace with `hmac` + `sha1`
/// crates for better performance and constant-time comparison.
pub fn hmac_sha1(key: &[u8], message: &[u8]) -> [u8; 20] {
const BLOCK_SIZE: usize = 64;
const IPAD: u8 = 0x36;
const OPAD: u8 = 0x5C;
// If key > block size, hash it first
let key_block = if key.len() > BLOCK_SIZE {
let hashed = sha1_digest(key);
let mut block = [0u8; BLOCK_SIZE];
block[..20].copy_from_slice(&hashed);
block
} else {
let mut block = [0u8; BLOCK_SIZE];
block[..key.len()].copy_from_slice(key);
block
};
// Inner hash: SHA1(key XOR ipad || message)
let mut inner_input = Vec::with_capacity(BLOCK_SIZE + message.len());
for i in 0..BLOCK_SIZE {
inner_input.push(key_block[i] ^ IPAD);
}
inner_input.extend_from_slice(message);
let inner_hash = sha1_digest(&inner_input);
// Outer hash: SHA1(key XOR opad || inner_hash)
let mut outer_input = Vec::with_capacity(BLOCK_SIZE + 20);
for i in 0..BLOCK_SIZE {
outer_input.push(key_block[i] ^ OPAD);
}
outer_input.extend_from_slice(&inner_hash);
sha1_digest(&outer_input)
}
// ---------------------------------------------------------------------------
// SHA-1 (FIPS 180-4)
// ---------------------------------------------------------------------------
/// Compute SHA-1 digest of the input data, returning a 20-byte hash.
///
/// This is a self-contained implementation. In production, use the `sha1` crate.
fn sha1_digest(data: &[u8]) -> [u8; 20] {
let mut h0: u32 = 0x67452301;
let mut h1: u32 = 0xEFCDAB89;
let mut h2: u32 = 0x98BADCFE;
let mut h3: u32 = 0x10325476;
let mut h4: u32 = 0xC3D2E1F0;
// Pre-processing: add padding
let bit_len = (data.len() as u64) * 8;
let mut padded = data.to_vec();
padded.push(0x80);
while padded.len() % 64 != 56 {
padded.push(0x00);
}
padded.extend_from_slice(&bit_len.to_be_bytes());
// Process each 512-bit (64-byte) block
for block in padded.chunks_exact(64) {
let mut w = [0u32; 80];
for i in 0..16 {
w[i] = u32::from_be_bytes([
block[i * 4],
block[i * 4 + 1],
block[i * 4 + 2],
block[i * 4 + 3],
]);
}
for i in 16..80 {
w[i] = (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1);
}
let mut a = h0;
let mut b = h1;
let mut c = h2;
let mut d = h3;
let mut e = h4;
for i in 0..80 {
let (f, k) = match i {
0..=19 => ((b & c) | ((!b) & d), 0x5A827999_u32),
20..=39 => (b ^ c ^ d, 0x6ED9EBA1_u32),
40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC_u32),
60..=79 => (b ^ c ^ d, 0xCA62C1D6_u32),
_ => unreachable!(),
};
let temp = a
.rotate_left(5)
.wrapping_add(f)
.wrapping_add(e)
.wrapping_add(k)
.wrapping_add(w[i]);
e = d;
d = c;
c = b.rotate_left(30);
b = a;
a = temp;
}
h0 = h0.wrapping_add(a);
h1 = h1.wrapping_add(b);
h2 = h2.wrapping_add(c);
h3 = h3.wrapping_add(d);
h4 = h4.wrapping_add(e);
}
let mut result = [0u8; 20];
result[0..4].copy_from_slice(&h0.to_be_bytes());
result[4..8].copy_from_slice(&h1.to_be_bytes());
result[8..12].copy_from_slice(&h2.to_be_bytes());
result[12..16].copy_from_slice(&h3.to_be_bytes());
result[16..20].copy_from_slice(&h4.to_be_bytes());
result
}
// ---------------------------------------------------------------------------
// MD5 (RFC 1321) - needed for long-term credential key derivation
// ---------------------------------------------------------------------------
/// Compute MD5 digest of the input data, returning a 16-byte hash.
///
/// Used for the long-term credential key: `MD5(username:realm:password)`.
/// Self-contained implementation; in production use the `md-5` crate.
fn md5_digest(data: &[u8]) -> [u8; 16] {
// Initial state
let mut a0: u32 = 0x67452301;
let mut b0: u32 = 0xefcdab89;
let mut c0: u32 = 0x98badcfe;
let mut d0: u32 = 0x10325476;
// Per-round shift amounts
const S: [u32; 64] = [
7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20,
4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21,
];
// Pre-computed constants: floor(2^32 * |sin(i + 1)|)
const K: [u32; 64] = [
0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee,
0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501,
0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be,
0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821,
0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa,
0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8,
0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed,
0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a,
0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c,
0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70,
0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05,
0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665,
0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039,
0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1,
0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1,
0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391,
];
// Pre-processing: add padding
let bit_len = (data.len() as u64) * 8;
let mut padded = data.to_vec();
padded.push(0x80);
while padded.len() % 64 != 56 {
padded.push(0x00);
}
padded.extend_from_slice(&bit_len.to_le_bytes());
// Process each 512-bit block
for block in padded.chunks_exact(64) {
let mut m = [0u32; 16];
for i in 0..16 {
m[i] = u32::from_le_bytes([
block[i * 4],
block[i * 4 + 1],
block[i * 4 + 2],
block[i * 4 + 3],
]);
}
let mut a = a0;
let mut b = b0;
let mut c = c0;
let mut d = d0;
for i in 0..64 {
let (f, g) = match i {
0..=15 => ((b & c) | ((!b) & d), i),
16..=31 => ((d & b) | ((!d) & c), (5 * i + 1) % 16),
32..=47 => (b ^ c ^ d, (3 * i + 5) % 16),
48..=63 => (c ^ (b | (!d)), (7 * i) % 16),
_ => unreachable!(),
};
let temp = d;
d = c;
c = b;
b = b.wrapping_add(
(a.wrapping_add(f).wrapping_add(K[i]).wrapping_add(m[g]))
.rotate_left(S[i]),
);
a = temp;
}
a0 = a0.wrapping_add(a);
b0 = b0.wrapping_add(b);
c0 = c0.wrapping_add(c);
d0 = d0.wrapping_add(d);
}
let mut result = [0u8; 16];
result[0..4].copy_from_slice(&a0.to_le_bytes());
result[4..8].copy_from_slice(&b0.to_le_bytes());
result[8..12].copy_from_slice(&c0.to_le_bytes());
result[12..16].copy_from_slice(&d0.to_le_bytes());
result
}
// ---------------------------------------------------------------------------
// Base64 encoding (RFC 4648)
// ---------------------------------------------------------------------------
/// Encode bytes to base64 string using standard alphabet with padding.
fn base64_encode(data: &[u8]) -> String {
const ALPHABET: &[u8; 64] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
let chunks = data.chunks(3);
for chunk in chunks {
let b0 = chunk[0] as u32;
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
let triple = (b0 << 16) | (b1 << 8) | b2;
result.push(ALPHABET[((triple >> 18) & 0x3F) as usize] as char);
result.push(ALPHABET[((triple >> 12) & 0x3F) as usize] as char);
if chunk.len() > 1 {
result.push(ALPHABET[((triple >> 6) & 0x3F) as usize] as char);
} else {
result.push('=');
}
if chunk.len() > 2 {
result.push(ALPHABET[(triple & 0x3F) as usize] as char);
} else {
result.push('=');
}
}
result
}
/// Decode a base64 string to bytes. Returns None on invalid input.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
fn char_val(c: u8) -> Option<u8> {
const ALPHABET: &[u8; 64] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
ALPHABET.iter().position(|&x| x == c).map(|p| p as u8)
}
let input = input.trim_end_matches('=');
let bytes = input.as_bytes();
let mut result = Vec::with_capacity(bytes.len() * 3 / 4);
let chunks = bytes.chunks(4);
for chunk in chunks {
let vals: Vec<u8> = chunk.iter().filter_map(|&b| char_val(b)).collect();
if vals.len() != chunk.len() {
return None;
}
let triple = match vals.len() {
4 => {
((vals[0] as u32) << 18)
| ((vals[1] as u32) << 12)
| ((vals[2] as u32) << 6)
| (vals[3] as u32)
}
3 => {
((vals[0] as u32) << 18) | ((vals[1] as u32) << 12) | ((vals[2] as u32) << 6)
}
2 => ((vals[0] as u32) << 18) | ((vals[1] as u32) << 12),
_ => return None,
};
result.push((triple >> 16) as u8);
if vals.len() > 2 {
result.push((triple >> 8 & 0xFF) as u8);
}
if vals.len() > 3 {
result.push((triple & 0xFF) as u8);
}
}
Some(result)
}
// ---------------------------------------------------------------------------
// Hex encoding (for nonces)
// ---------------------------------------------------------------------------
/// Encode bytes as lowercase hexadecimal string.
fn hex_encode(data: &[u8]) -> String {
let mut s = String::with_capacity(data.len() * 2);
for &b in data {
s.push_str(&format!("{:02x}", b));
}
s
}
/// Decode a hexadecimal string to bytes. Returns None on invalid input.
fn hex_decode(s: &str) -> Option<Vec<u8>> {
if s.len() % 2 != 0 {
return None;
}
let mut result = Vec::with_capacity(s.len() / 2);
for i in (0..s.len()).step_by(2) {
let byte = u8::from_str_radix(&s[i..i + 2], 16).ok()?;
result.push(byte);
}
Some(result)
}
// ---------------------------------------------------------------------------
// Credential generation / validation
// ---------------------------------------------------------------------------
/// Generate time-limited TURN credentials per draft-uberti-behave-turn-rest.
///
/// Returns `(username, password)` where:
/// - username = `"{expiry_timestamp}:{peer_id}"`
/// - password = `Base64(HMAC-SHA1(shared_secret, username))`
///
/// The credentials expire `ttl_secs` seconds from now.
pub fn generate_credentials(
peer_id: &str,
shared_secret: &[u8],
ttl_secs: u64,
) -> (String, String) {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let expiry = now + ttl_secs;
let username = format!("{}:{}", expiry, peer_id);
let hmac = hmac_sha1(shared_secret, username.as_bytes());
let password = base64_encode(&hmac);
(username, password)
}
/// Validate time-limited TURN credentials.
///
/// 1. Parse the expiry timestamp from the username
/// 2. Check that the credentials have not expired
/// 3. Recompute HMAC-SHA1 and compare with the provided password
pub fn validate_credentials(
username: &str,
password: &str,
shared_secret: &[u8],
) -> Result<(), TurnError> {
// Parse expiry timestamp from username (format: "{timestamp}:{peer_id}")
let colon_pos = username
.find(':')
.ok_or(TurnError::Unauthorized)?;
let expiry_str = &username[..colon_pos];
let expiry: u64 = expiry_str
.parse()
.map_err(|_| TurnError::Unauthorized)?;
// Check expiry
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if now > expiry {
return Err(TurnError::Unauthorized);
}
// Recompute HMAC-SHA1 and compare
let expected_hmac = hmac_sha1(shared_secret, username.as_bytes());
let expected_password = base64_encode(&expected_hmac);
if !constant_time_eq(password.as_bytes(), expected_password.as_bytes()) {
return Err(TurnError::Unauthorized);
}
Ok(())
}
// ---------------------------------------------------------------------------
// MESSAGE-INTEGRITY computation (RFC 5389 §15.4)
// ---------------------------------------------------------------------------
/// Compute the long-term credential key per RFC 5389 §15.4:
/// `key = MD5(username ":" realm ":" password)`
pub fn compute_long_term_key(username: &str, realm: &str, password: &str) -> [u8; 16] {
let input = format!("{}:{}:{}", username, realm, password);
md5_digest(input.as_bytes())
}
/// Compute the MESSAGE-INTEGRITY attribute value (HMAC-SHA1).
///
/// `key` is the long-term credential key (output of [`compute_long_term_key`]).
/// `message_bytes` should be the STUN message up to (but not including) the
/// MESSAGE-INTEGRITY attribute, with the message length field adjusted to
/// include the MESSAGE-INTEGRITY attribute.
pub fn compute_message_integrity(key: &[u8], message_bytes: &[u8]) -> [u8; 20] {
hmac_sha1(key, message_bytes)
}
/// Validate the MESSAGE-INTEGRITY attribute of a STUN message.
///
/// Recomputes the HMAC-SHA1 over the message bytes and compares it
/// with the provided MESSAGE-INTEGRITY value in constant time.
pub fn validate_message_integrity(
message_integrity: &[u8; 20],
key: &[u8],
message_bytes: &[u8],
) -> bool {
let expected = compute_message_integrity(key, message_bytes);
constant_time_eq(&expected, message_integrity)
}
// ---------------------------------------------------------------------------
// Nonce generation / validation
// ---------------------------------------------------------------------------
/// Generate a NONCE that embeds a timestamp for stateless staleness detection.
///
/// Format: `{hex_timestamp}-{hex_hmac_of_timestamp}`
///
/// The server can validate the nonce without storing state by recomputing
/// the HMAC. The timestamp allows detecting stale nonces.
pub fn compute_nonce(timestamp: u64, secret: &[u8]) -> String {
let ts_bytes = timestamp.to_be_bytes();
let hmac = hmac_sha1(secret, &ts_bytes);
// Use first 8 bytes of HMAC for a shorter nonce
let hmac_short = &hmac[..8];
format!("{}-{}", hex_encode(&ts_bytes), hex_encode(hmac_short))
}
/// Validate a NONCE and check that it hasn't expired.
///
/// 1. Parse the timestamp from the nonce
/// 2. Recompute HMAC and verify it matches
/// 3. Check that the nonce age doesn't exceed `max_age_secs`
pub fn validate_nonce(
nonce: &str,
secret: &[u8],
max_age_secs: u64,
) -> Result<(), TurnError> {
let parts: Vec<&str> = nonce.splitn(2, '-').collect();
if parts.len() != 2 {
return Err(TurnError::StaleNonce);
}
let ts_hex = parts[0];
let hmac_hex = parts[1];
// Decode timestamp
let ts_bytes = hex_decode(ts_hex).ok_or(TurnError::StaleNonce)?;
if ts_bytes.len() != 8 {
return Err(TurnError::StaleNonce);
}
let timestamp = u64::from_be_bytes([
ts_bytes[0], ts_bytes[1], ts_bytes[2], ts_bytes[3],
ts_bytes[4], ts_bytes[5], ts_bytes[6], ts_bytes[7],
]);
// Verify HMAC
let expected_hmac = hmac_sha1(secret, &ts_bytes);
let expected_hmac_hex = hex_encode(&expected_hmac[..8]);
if !constant_time_eq(hmac_hex.as_bytes(), expected_hmac_hex.as_bytes()) {
return Err(TurnError::StaleNonce);
}
// Check age
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if now > timestamp + max_age_secs {
return Err(TurnError::StaleNonce);
}
Ok(())
}
// ---------------------------------------------------------------------------
// Constant-time comparison
// ---------------------------------------------------------------------------
/// Compare two byte slices in constant time to prevent timing attacks.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
return false;
}
let mut diff = 0u8;
for (x, y) in a.iter().zip(b.iter()) {
diff |= x ^ y;
}
diff == 0
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_sha1_empty() {
// SHA-1("") = da39a3ee5e6b4b0d3255bfef95601890afd80709
let hash = sha1_digest(b"");
let hex = hex_encode(&hash);
assert_eq!(hex, "da39a3ee5e6b4b0d3255bfef95601890afd80709");
}
#[test]
fn test_sha1_abc() {
// SHA-1("abc") = a9993e364706816aba3e25717850c26c9cd0d89d
let hash = sha1_digest(b"abc");
let hex = hex_encode(&hash);
assert_eq!(hex, "a9993e364706816aba3e25717850c26c9cd0d89d");
}
#[test]
fn test_md5_empty() {
// MD5("") = d41d8cd98f00b204e9800998ecf8427e
let hash = md5_digest(b"");
let hex = hex_encode(&hash);
assert_eq!(hex, "d41d8cd98f00b204e9800998ecf8427e");
}
#[test]
fn test_md5_abc() {
// MD5("abc") = 900150983cd24fb0d6963f7d28e17f72
let hash = md5_digest(b"abc");
let hex = hex_encode(&hash);
assert_eq!(hex, "900150983cd24fb0d6963f7d28e17f72");
}
#[test]
fn test_hmac_sha1_rfc2202_test1() {
// RFC 2202 Test Case 1:
// Key = 0x0b repeated 20 times
// Data = "Hi There"
// HMAC = b617318655057264e28bc0b6fb378c8ef146be00
let key = [0x0bu8; 20];
let data = b"Hi There";
let hmac = hmac_sha1(&key, data);
let hex = hex_encode(&hmac);
assert_eq!(hex, "b617318655057264e28bc0b6fb378c8ef146be00");
}
#[test]
fn test_hmac_sha1_rfc2202_test2() {
// RFC 2202 Test Case 2:
// Key = "Jefe"
// Data = "what do ya want for nothing?"
// HMAC = effcdf6ae5eb2fa2d27416d5f184df9c259a7c79
let key = b"Jefe";
let data = b"what do ya want for nothing?";
let hmac = hmac_sha1(key, data);
let hex = hex_encode(&hmac);
assert_eq!(hex, "effcdf6ae5eb2fa2d27416d5f184df9c259a7c79");
}
#[test]
fn test_base64_encode() {
assert_eq!(base64_encode(b""), "");
assert_eq!(base64_encode(b"f"), "Zg==");
assert_eq!(base64_encode(b"fo"), "Zm8=");
assert_eq!(base64_encode(b"foo"), "Zm9v");
assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
}
#[test]
fn test_base64_decode() {
assert_eq!(base64_decode("").unwrap(), b"");
assert_eq!(base64_decode("Zg==").unwrap(), b"f");
assert_eq!(base64_decode("Zm8=").unwrap(), b"fo");
assert_eq!(base64_decode("Zm9v").unwrap(), b"foo");
assert_eq!(base64_decode("Zm9vYmFy").unwrap(), b"foobar");
}
#[test]
fn test_base64_roundtrip() {
let data = b"Hello, TURN server!";
let encoded = base64_encode(data);
let decoded = base64_decode(&encoded).unwrap();
assert_eq!(&decoded, data);
}
#[test]
fn test_generate_validate_credentials() {
let secret = b"supersecretkey";
let peer_id = "12D3KooWTestPeerId";
let ttl = 3600; // 1 hour
let (username, password) = generate_credentials(peer_id, secret, ttl);
// Username should contain the peer_id
assert!(username.contains(peer_id));
assert!(username.contains(':'));
// Password should be valid base64
assert!(base64_decode(&password).is_some());
// Validation should succeed
let result = validate_credentials(&username, &password, secret);
assert!(result.is_ok());
}
#[test]
fn test_validate_credentials_wrong_password() {
let secret = b"supersecretkey";
let (username, _) = generate_credentials("test", secret, 3600);
let result = validate_credentials(&username, "wrongpassword", secret);
assert!(result.is_err());
}
#[test]
fn test_validate_credentials_wrong_secret() {
let secret = b"supersecretkey";
let (username, password) = generate_credentials("test", secret, 3600);
let result = validate_credentials(&username, &password, b"wrongsecret");
assert!(result.is_err());
}
#[test]
fn test_long_term_key() {
// RFC 5389 example-ish: key = MD5("user:realm:pass")
let key = compute_long_term_key("user", "realm", "pass");
assert_eq!(key.len(), 16);
// Should be deterministic
let key2 = compute_long_term_key("user", "realm", "pass");
assert_eq!(key, key2);
}
#[test]
fn test_message_integrity_roundtrip() {
let key = compute_long_term_key("alice", "duskchat.app", "password123");
let message = b"fake stun message bytes for testing";
let integrity = compute_message_integrity(&key, message);
assert!(validate_message_integrity(&integrity, &key, message));
// Tampered message should fail
let tampered = b"tampered stun message bytes for testing";
assert!(!validate_message_integrity(&integrity, &key, tampered));
}
#[test]
fn test_nonce_roundtrip() {
let secret = b"nonce_secret_key";
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let nonce = compute_nonce(now, secret);
assert!(nonce.contains('-'));
// Should validate successfully with generous max age
let result = validate_nonce(&nonce, secret, 3600);
assert!(result.is_ok());
}
#[test]
fn test_nonce_stale() {
let secret = b"nonce_secret_key";
// Use a timestamp from long ago
let old_timestamp = 1000000;
let nonce = compute_nonce(old_timestamp, secret);
let result = validate_nonce(&nonce, secret, 3600);
assert!(result.is_err());
}
#[test]
fn test_nonce_tampered() {
let secret = b"nonce_secret_key";
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let nonce = compute_nonce(now, secret);
// Validate with wrong secret
let result = validate_nonce(&nonce, b"wrong_secret", 3600);
assert!(result.is_err());
}
#[test]
fn test_hex_roundtrip() {
let data = &[0xDE, 0xAD, 0xBE, 0xEF];
let encoded = hex_encode(data);
assert_eq!(encoded, "deadbeef");
let decoded = hex_decode(&encoded).unwrap();
assert_eq!(&decoded, data);
}
#[test]
fn test_constant_time_eq() {
assert!(constant_time_eq(b"hello", b"hello"));
assert!(!constant_time_eq(b"hello", b"world"));
assert!(!constant_time_eq(b"hello", b"hell"));
assert!(constant_time_eq(b"", b""));
}
}

102
src/turn/error.rs Normal file
View File

@ -0,0 +1,102 @@
// TURN server error types
//
// Covers all error conditions that can arise during STUN/TURN message
// processing, credential validation, and allocation management.
use std::fmt;
/// Comprehensive error type for TURN server operations.
///
/// Each variant maps to a specific failure condition in the STUN/TURN
/// protocol stack, from low-level parse errors to high-level allocation
/// policy violations.
#[derive(Debug, Clone)]
pub enum TurnError {
/// Invalid STUN message format: bad header, truncated message,
/// missing magic cookie, or malformed TLV attributes.
StunParseError(String),
/// MESSAGE-INTEGRITY attribute does not match the computed HMAC-SHA1.
/// This means either the password is wrong or the message was tampered with.
InvalidMessageIntegrity,
/// The NONCE has expired. The client should retry with a fresh nonce
/// from the 438 Stale Nonce error response.
StaleNonce,
/// The client already has an allocation on this 5-tuple, or is
/// attempting an operation that conflicts with existing allocation state
/// (RFC 5766 §6.2).
AllocationMismatch,
/// The server has reached its per-user or global allocation quota.
/// Maps to TURN error code 486.
AllocationQuotaReached,
/// The server cannot fulfill the request due to resource constraints
/// (e.g., no relay ports available). Maps to TURN error code 508.
InsufficientCapacity,
/// The request lacks valid credentials, or the credentials have expired.
/// Maps to STUN error code 401.
Unauthorized,
/// The peer address in the request is forbidden by server policy
/// (e.g., loopback or private IP filtering). Maps to TURN error code 403.
ForbiddenIp,
/// The REQUESTED-TRANSPORT attribute specifies a transport protocol
/// that the server does not support (e.g., TCP relay when only UDP
/// is available). Maps to TURN error code 442.
UnsupportedTransport,
/// An I/O error occurred on a socket or file operation.
IoError(String),
}
impl fmt::Display for TurnError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TurnError::StunParseError(msg) => write!(f, "STUN parse error: {}", msg),
TurnError::InvalidMessageIntegrity => write!(f, "invalid MESSAGE-INTEGRITY"),
TurnError::StaleNonce => write!(f, "stale nonce"),
TurnError::AllocationMismatch => write!(f, "allocation mismatch"),
TurnError::AllocationQuotaReached => write!(f, "allocation quota reached"),
TurnError::InsufficientCapacity => write!(f, "insufficient capacity"),
TurnError::Unauthorized => write!(f, "unauthorized"),
TurnError::ForbiddenIp => write!(f, "forbidden IP address"),
TurnError::UnsupportedTransport => write!(f, "unsupported transport protocol"),
TurnError::IoError(msg) => write!(f, "I/O error: {}", msg),
}
}
}
impl std::error::Error for TurnError {}
impl From<std::io::Error> for TurnError {
fn from(err: std::io::Error) -> Self {
TurnError::IoError(err.to_string())
}
}
/// Maps a [`TurnError`] to the corresponding STUN/TURN error code
/// for use in error response messages.
///
/// Error codes follow RFC 5389 §15.6 and RFC 5766 §6.
impl TurnError {
/// Returns the STUN/TURN error code and default reason phrase for this error.
pub fn to_error_code(&self) -> (u16, &'static str) {
match self {
TurnError::StunParseError(_) => (400, "Bad Request"),
TurnError::InvalidMessageIntegrity => (401, "Unauthorized"),
TurnError::StaleNonce => (438, "Stale Nonce"),
TurnError::AllocationMismatch => (437, "Allocation Mismatch"),
TurnError::AllocationQuotaReached => (486, "Allocation Quota Reached"),
TurnError::InsufficientCapacity => (508, "Insufficient Capacity"),
TurnError::Unauthorized => (401, "Unauthorized"),
TurnError::ForbiddenIp => (403, "Forbidden"),
TurnError::UnsupportedTransport => (442, "Unsupported Transport Protocol"),
TurnError::IoError(_) => (500, "Server Error"),
}
}
}

1147
src/turn/handler.rs Normal file

File diff suppressed because it is too large Load Diff

30
src/turn/mod.rs Normal file
View File

@ -0,0 +1,30 @@
// TURN server module for Dusk relay
//
// Implements RFC 5389 (STUN) and RFC 5766 (TURN) protocol types,
// message parsing/serialization, credential management, port allocation,
// and the full server with UDP/TCP listeners.
//
// This module is organized as follows:
// - stun: STUN message format, parsing, serialization
// - attributes: STUN/TURN attribute types and encoding
// - credentials: HMAC-SHA1 time-limited credential generation/validation
// - allocation: TURN allocation state machine
// - port_pool: Relay port allocation pool
// - handler: TURN message handler (request dispatch + response building)
// - udp_listener: UDP listener task (receives datagrams, spawns relay receivers)
// - tcp_listener: TCP listener task (accepts connections, frames STUN/ChannelData)
// - server: Top-level TURN server orchestration (config, startup, handle)
// - error: Error types
pub mod stun;
pub mod attributes;
pub mod credentials;
pub mod error;
pub mod port_pool;
pub mod allocation;
pub mod handler;
pub mod udp_listener;
pub mod tcp_listener;
pub mod server;
pub use server::{TurnServer, TurnServerConfig, TurnServerHandle};

254
src/turn/port_pool.rs Normal file
View File

@ -0,0 +1,254 @@
// Relay port allocation pool for TURN allocations
//
// Manages a pool of UDP port numbers available for TURN relay transport
// addresses. Each TURN allocation requires a dedicated relay port.
//
// The pool pre-shuffles available ports on creation to avoid predictable
// allocation patterns. Ports are allocated from the front of the queue
// and returned to the back on release.
//
// Default range: 49152-65535 (IANA dynamic/private port range)
// This gives approximately 16,383 relay ports.
use std::collections::HashSet;
/// Manages a pool of available relay port numbers.
///
/// Ports are pre-shuffled on construction to avoid predictable patterns.
/// Allocation is O(1) from a Vec (pop), release is O(1) (push).
///
/// # Example
///
/// ```
/// use relay::turn::port_pool::PortPool;
///
/// let mut pool = PortPool::new(49152, 49160);
/// assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive
///
/// let port = pool.allocate().unwrap();
/// assert!(port >= 49152 && port <= 49160);
/// assert_eq!(pool.available_count(), 8);
///
/// pool.release(port);
/// assert_eq!(pool.available_count(), 9);
/// ```
#[derive(Debug, Clone)]
pub struct PortPool {
/// Ports available for allocation (shuffled on construction).
available: Vec<u16>,
/// Currently allocated ports (for tracking and preventing double-release).
allocated: HashSet<u16>,
}
impl PortPool {
/// Create a new port pool with ports in the range `[range_start, range_end]` inclusive.
///
/// The available ports are shuffled using a simple Fisher-Yates-like shuffle
/// seeded from the system clock. For cryptographic randomness, use the
/// `rand` crate version below.
///
/// # Panics
///
/// Panics if `range_start > range_end`.
pub fn new(range_start: u16, range_end: u16) -> Self {
assert!(
range_start <= range_end,
"range_start ({}) must be <= range_end ({})",
range_start,
range_end
);
let mut available: Vec<u16> = (range_start..=range_end).collect();
// Shuffle using a simple PRNG seeded from system time.
// This is sufficient for port randomization (not security-critical).
let seed = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos();
let mut state = seed as u64;
let len = available.len();
if len > 1 {
for i in (1..len).rev() {
// Simple xorshift64 PRNG
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
let j = (state as usize) % (i + 1);
available.swap(i, j);
}
}
PortPool {
available,
allocated: HashSet::new(),
}
}
/// Allocate a port from the pool.
///
/// Returns `None` if no ports are available.
/// Allocated ports are tracked to prevent double-allocation.
pub fn allocate(&mut self) -> Option<u16> {
let port = self.available.pop()?;
self.allocated.insert(port);
Some(port)
}
/// Release a previously allocated port back to the pool.
///
/// If the port was not currently allocated (double-release or unknown port),
/// this is a no-op to prevent pool corruption.
pub fn release(&mut self, port: u16) {
if self.allocated.remove(&port) {
self.available.push(port);
}
}
/// Returns the number of ports currently available for allocation.
pub fn available_count(&self) -> usize {
self.available.len()
}
/// Returns the number of ports currently allocated.
pub fn allocated_count(&self) -> usize {
self.allocated.len()
}
/// Returns the total capacity of the pool (available + allocated).
pub fn total_capacity(&self) -> usize {
self.available.len() + self.allocated.len()
}
/// Returns `true` if the given port is currently allocated.
pub fn is_allocated(&self, port: u16) -> bool {
self.allocated.contains(&port)
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_pool_correct_count() {
let pool = PortPool::new(49152, 49160);
assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive
assert_eq!(pool.allocated_count(), 0);
assert_eq!(pool.total_capacity(), 9);
}
#[test]
fn test_single_port_pool() {
let mut pool = PortPool::new(5000, 5000);
assert_eq!(pool.available_count(), 1);
let port = pool.allocate().unwrap();
assert_eq!(port, 5000);
assert_eq!(pool.available_count(), 0);
assert!(pool.is_allocated(5000));
assert!(pool.allocate().is_none());
pool.release(5000);
assert_eq!(pool.available_count(), 1);
assert!(!pool.is_allocated(5000));
}
#[test]
fn test_allocate_all_ports() {
let mut pool = PortPool::new(10000, 10004);
let mut allocated = Vec::new();
for _ in 0..5 {
let port = pool.allocate().unwrap();
assert!((10000..=10004).contains(&port));
allocated.push(port);
}
assert_eq!(pool.available_count(), 0);
assert_eq!(pool.allocated_count(), 5);
assert!(pool.allocate().is_none());
// All ports should be unique
let unique: HashSet<u16> = allocated.iter().copied().collect();
assert_eq!(unique.len(), 5);
}
#[test]
fn test_release_and_reallocate() {
let mut pool = PortPool::new(8000, 8002);
let p1 = pool.allocate().unwrap();
let p2 = pool.allocate().unwrap();
let p3 = pool.allocate().unwrap();
assert!(pool.allocate().is_none());
pool.release(p2);
assert_eq!(pool.available_count(), 1);
let p4 = pool.allocate().unwrap();
assert_eq!(p4, p2); // released port gets reused
assert!(pool.allocate().is_none());
pool.release(p1);
pool.release(p3);
pool.release(p4);
assert_eq!(pool.available_count(), 3);
}
#[test]
fn test_double_release_is_noop() {
let mut pool = PortPool::new(7000, 7002);
let port = pool.allocate().unwrap();
pool.release(port);
assert_eq!(pool.available_count(), 3);
// Double release should not add a duplicate
pool.release(port);
assert_eq!(pool.available_count(), 3);
}
#[test]
fn test_release_unknown_port_is_noop() {
let mut pool = PortPool::new(7000, 7002);
// Release a port that was never allocated
pool.release(9999);
assert_eq!(pool.available_count(), 3);
assert_eq!(pool.allocated_count(), 0);
}
#[test]
#[should_panic(expected = "range_start")]
fn test_invalid_range_panics() {
let _ = PortPool::new(100, 50);
}
#[test]
fn test_large_pool() {
let pool = PortPool::new(49152, 65535);
assert_eq!(pool.available_count(), 16384);
assert_eq!(pool.total_capacity(), 16384);
}
#[test]
fn test_is_allocated() {
let mut pool = PortPool::new(6000, 6005);
let port = pool.allocate().unwrap();
assert!(pool.is_allocated(port));
pool.release(port);
assert!(!pool.is_allocated(port));
// Never-allocated port
assert!(!pool.is_allocated(9999));
}
}

377
src/turn/server.rs Normal file
View File

@ -0,0 +1,377 @@
// TURN server entry point
//
// Ties together the TurnHandler, AllocationManager, PortPool, and
// UDP/TCP listeners into a clean, configurable server.
//
// Configuration can be loaded from environment variables via
// `TurnServerConfig::from_env()` or constructed programmatically.
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use super::allocation::{AllocationConfig, AllocationManager};
use super::handler::TurnHandler;
use super::port_pool::PortPool;
use super::tcp_listener::TcpTurnListener;
use super::udp_listener::UdpTurnListener;
// ---------------------------------------------------------------------------
// TurnServerConfig
// ---------------------------------------------------------------------------
/// Configuration for the TURN server.
///
/// All fields have sensible defaults. The only required field for production
/// use is `public_ip` — without it, relay addresses will use 0.0.0.0 which
/// won't work for external peers.
#[derive(Debug, Clone)]
pub struct TurnServerConfig {
/// UDP listen address (default: 0.0.0.0:3478).
pub udp_addr: SocketAddr,
/// TCP listen address (default: 0.0.0.0:3478).
pub tcp_addr: SocketAddr,
/// Public IP address for XOR-RELAYED-ADDRESS (required for production).
pub public_ip: IpAddr,
/// Shared secret for HMAC credential generation.
pub shared_secret: Vec<u8>,
/// Authentication realm (default: "duskchat.app").
pub realm: String,
/// Relay port range start (default: 49152).
pub relay_port_start: u16,
/// Relay port range end (default: 65535).
pub relay_port_end: u16,
/// Allocation configuration (lifetimes, quotas, etc.).
pub allocation_config: AllocationConfig,
}
impl Default for TurnServerConfig {
fn default() -> Self {
Self {
udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478),
tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478),
public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
shared_secret: Vec::new(),
realm: "duskchat.app".to_string(),
relay_port_start: 49152,
relay_port_end: 65535,
allocation_config: AllocationConfig::default(),
}
}
}
impl TurnServerConfig {
/// Create a configuration from environment variables with sensible defaults.
///
/// Supported environment variables:
/// - `DUSK_TURN_UDP_PORT` — UDP listen port (default: 3478)
/// - `DUSK_TURN_TCP_PORT` — TCP listen port (default: 3478)
/// - `DUSK_TURN_PUBLIC_IP` — Public IP for relay addresses (required in prod)
/// - `DUSK_TURN_SECRET` — Shared secret for credentials (auto-generated if unset)
/// - `DUSK_TURN_REALM` — Authentication realm (default: "duskchat.app")
/// - `DUSK_TURN_PORT_RANGE_START` — Relay port range start (default: 49152)
/// - `DUSK_TURN_PORT_RANGE_END` — Relay port range end (default: 65535)
/// - `DUSK_TURN_MAX_ALLOCATIONS` — Global allocation limit (default: 1000)
/// - `DUSK_TURN_MAX_PER_USER` — Per-user allocation limit (default: 10)
pub fn from_env() -> Self {
let udp_port: u16 = std::env::var("DUSK_TURN_UDP_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(3478);
let tcp_port: u16 = std::env::var("DUSK_TURN_TCP_PORT")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(3478);
let public_ip: IpAddr = std::env::var("DUSK_TURN_PUBLIC_IP")
.ok()
.and_then(|ip| ip.parse().ok())
.unwrap_or_else(|| {
log::warn!(
"[TURN] DUSK_TURN_PUBLIC_IP not set — relay addresses will use 0.0.0.0. \
Set this to the server's public IP for production use."
);
IpAddr::V4(Ipv4Addr::UNSPECIFIED)
});
let shared_secret = std::env::var("DUSK_TURN_SECRET")
.ok()
.filter(|s| !s.is_empty())
.map(|s| s.into_bytes())
.unwrap_or_else(|| {
let secret = generate_random_bytes(32);
log::warn!(
"[TURN] DUSK_TURN_SECRET not set — using random secret. \
This won't work across multiple relay instances."
);
secret
});
let realm = std::env::var("DUSK_TURN_REALM")
.ok()
.filter(|r| !r.is_empty())
.unwrap_or_else(|| "duskchat.app".to_string());
let port_range_start: u16 = std::env::var("DUSK_TURN_PORT_RANGE_START")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(49152);
let port_range_end: u16 = std::env::var("DUSK_TURN_PORT_RANGE_END")
.ok()
.and_then(|p| p.parse().ok())
.unwrap_or(65535);
let max_allocations: usize = std::env::var("DUSK_TURN_MAX_ALLOCATIONS")
.ok()
.and_then(|n| n.parse().ok())
.unwrap_or(1000);
let max_per_user: usize = std::env::var("DUSK_TURN_MAX_PER_USER")
.ok()
.and_then(|n| n.parse().ok())
.unwrap_or(10);
let allocation_config = AllocationConfig {
max_allocations,
max_per_user,
realm: realm.clone(),
..Default::default()
};
Self {
udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), udp_port),
tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), tcp_port),
public_ip,
shared_secret,
realm,
relay_port_start: port_range_start,
relay_port_end: port_range_end,
allocation_config,
}
}
/// Check if the TURN server is enabled via environment variable.
///
/// Returns `false` only if `DUSK_TURN_ENABLED=false`. Defaults to `true`.
pub fn is_enabled() -> bool {
std::env::var("DUSK_TURN_ENABLED")
.map(|v| v != "false" && v != "0")
.unwrap_or(true)
}
}
// ---------------------------------------------------------------------------
// TurnServer
// ---------------------------------------------------------------------------
/// The TURN server. Owns the configuration and provides a [`run`](TurnServer::run)
/// method that starts all listener tasks and returns a [`TurnServerHandle`].
pub struct TurnServer {
config: TurnServerConfig,
}
impl TurnServer {
/// Create a new TURN server with the given configuration.
pub fn new(config: TurnServerConfig) -> Self {
Self { config }
}
/// Start the TURN server.
///
/// Binds UDP and TCP sockets, creates the handler and allocation manager,
/// spawns listener and cleanup tasks, and returns a handle for monitoring.
pub async fn run(self) -> Result<TurnServerHandle, Box<dyn std::error::Error>> {
let port_pool = Arc::new(tokio::sync::Mutex::new(PortPool::new(
self.config.relay_port_start,
self.config.relay_port_end,
)));
let alloc_mgr = Arc::new(AllocationManager::new(
self.config.allocation_config.clone(),
));
// Generate nonce secret (separate from shared secret)
let nonce_secret = generate_random_bytes(32);
let handler = Arc::new(TurnHandler::new(
Arc::clone(&alloc_mgr),
Arc::clone(&port_pool),
self.config.shared_secret.clone(),
self.config.realm.clone(),
nonce_secret,
self.config.public_ip,
));
// Bind UDP socket
let udp_socket = Arc::new(
tokio::net::UdpSocket::bind(self.config.udp_addr).await?,
);
log::info!(
"[TURN] UDP listening on {}",
udp_socket.local_addr().unwrap_or(self.config.udp_addr)
);
let udp_listener = Arc::new(UdpTurnListener::new(
Arc::clone(&udp_socket),
Arc::clone(&handler),
Arc::clone(&alloc_mgr),
self.config.udp_addr,
self.config.public_ip,
));
// Spawn UDP listener
let udp_handle = tokio::spawn({
let listener = Arc::clone(&udp_listener);
async move {
listener.run().await;
}
});
// Bind TCP listener
let mut tcp_listener = TcpTurnListener::bind(
self.config.tcp_addr,
Arc::clone(&handler),
Arc::clone(&alloc_mgr),
self.config.public_ip,
)
.await?;
// Give the TCP listener access to the UDP socket for relay receiver tasks
tcp_listener.set_udp_socket(Arc::clone(&udp_socket));
log::info!(
"[TURN] TCP listening on {}",
self.config.tcp_addr
);
// Spawn TCP listener
let tcp_handle = tokio::spawn(async move {
tcp_listener.run().await;
});
// Spawn cleanup task (runs every 30 seconds)
let cleanup_alloc_mgr = Arc::clone(&alloc_mgr);
let cleanup_port_pool = Arc::clone(&port_pool);
let cleanup_handle = tokio::spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
let freed_ports = cleanup_alloc_mgr.cleanup_expired().await;
if !freed_ports.is_empty() {
let mut pool = cleanup_port_pool.lock().await;
for port in &freed_ports {
pool.release(*port);
}
log::info!(
"[TURN] cleaned up {} expired allocations",
freed_ports.len()
);
}
}
});
log::info!(
"[TURN] server started (public_ip={}, realm={}, relay_ports={}-{})",
self.config.public_ip,
self.config.realm,
self.config.relay_port_start,
self.config.relay_port_end,
);
Ok(TurnServerHandle {
udp_handle,
tcp_handle,
cleanup_handle,
handler,
alloc_mgr,
port_pool,
shared_secret: self.config.shared_secret,
})
}
/// Get the shared secret (for credential generation by the libp2p protocol).
pub fn shared_secret(&self) -> &[u8] {
&self.config.shared_secret
}
}
// ---------------------------------------------------------------------------
// TurnServerHandle
// ---------------------------------------------------------------------------
/// Handle to a running TURN server.
///
/// Provides access to the server's shared state and credential generation.
/// The server runs in the background via spawned tasks; dropping the handle
/// does NOT stop the server (the tasks keep running).
pub struct TurnServerHandle {
/// UDP listener task handle.
pub udp_handle: tokio::task::JoinHandle<()>,
/// TCP listener task handle.
pub tcp_handle: tokio::task::JoinHandle<()>,
/// Cleanup task handle.
pub cleanup_handle: tokio::task::JoinHandle<()>,
/// The shared TURN message handler.
pub handler: Arc<TurnHandler>,
/// The shared allocation manager.
pub alloc_mgr: Arc<AllocationManager>,
/// The shared port pool.
pub port_pool: Arc<tokio::sync::Mutex<PortPool>>,
/// The shared secret (for credential generation).
shared_secret: Vec<u8>,
}
impl TurnServerHandle {
/// Generate time-limited TURN credentials for a peer.
///
/// Returns `(username, password)` suitable for use with WebRTC ICE servers.
/// The credentials expire after `ttl_secs` seconds.
pub fn generate_credentials(&self, peer_id: &str, ttl_secs: u64) -> (String, String) {
super::credentials::generate_credentials(peer_id, &self.shared_secret, ttl_secs)
}
/// Get the current number of active allocations.
pub async fn allocation_count(&self) -> usize {
self.alloc_mgr.allocation_count().await
}
/// Get the shared secret bytes (for external credential generation).
pub fn shared_secret(&self) -> &[u8] {
&self.shared_secret
}
/// Abort all server tasks. The server will stop after current in-flight
/// operations complete.
pub fn shutdown(&self) {
self.udp_handle.abort();
self.tcp_handle.abort();
self.cleanup_handle.abort();
}
}
// ---------------------------------------------------------------------------
// Utility: random byte generation
// ---------------------------------------------------------------------------
/// Generate `len` pseudo-random bytes using a simple xorshift64 PRNG
/// seeded from system time.
///
/// This is NOT cryptographically secure. It's used for nonce secrets
/// and auto-generated shared secrets when none is configured. In production,
/// configure `DUSK_TURN_SECRET` explicitly.
fn generate_random_bytes(len: usize) -> Vec<u8> {
use std::time::{SystemTime, UNIX_EPOCH};
let seed = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64;
let mut state = seed;
let mut bytes = Vec::with_capacity(len);
for _ in 0..len {
state ^= state << 13;
state ^= state >> 7;
state ^= state << 17;
bytes.push(state as u8);
}
bytes
}

799
src/turn/stun.rs Normal file
View File

@ -0,0 +1,799 @@
// STUN message parser and serializer per RFC 5389
//
// Implements the STUN message format with support for TURN methods (RFC 5766).
// Also handles ChannelData framing for TURN channel bindings.
//
// STUN message header layout (20 bytes):
// Bytes 0-1: Message Type (method + class encoding)
// Bytes 2-3: Message Length (excludes 20-byte header)
// Bytes 4-7: Magic Cookie (0x2112A442)
// Bytes 8-19: Transaction ID (96 bits)
//
// Message Type encoding (RFC 5389 §6):
// Bits: 13 12 11 10 9 8 7 6 5 4 3 2 1 0
// M11 M10 M9 M8 M7 C1 M6 M5 M4 C0 M3 M2 M1 M0
// Where M0-M11 = method bits, C0-C1 = class bits.
use crate::turn::attributes::StunAttribute;
use crate::turn::error::TurnError;
/// STUN magic cookie value (RFC 5389 §6).
pub const MAGIC_COOKIE: u32 = 0x2112A442;
/// STUN header size in bytes.
pub const STUN_HEADER_SIZE: usize = 20;
/// XOR value for FINGERPRINT CRC32 (RFC 5389 §15.5).
pub const FINGERPRINT_XOR: u32 = 0x5354554e;
// ---------------------------------------------------------------------------
// Method
// ---------------------------------------------------------------------------
/// STUN/TURN method identifiers.
///
/// Methods 0x0001 (Binding) are defined in RFC 5389; the rest are
/// TURN extensions from RFC 5766.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Method {
/// STUN Binding (0x0001) — NAT discovery
Binding,
/// TURN Allocate (0x0003) — create relay allocation
Allocate,
/// TURN Refresh (0x0004) — refresh allocation lifetime
Refresh,
/// TURN Send (0x0006) — send data indication to peer
Send,
/// TURN Data (0x0007) — data indication from peer
Data,
/// TURN CreatePermission (0x0008) — install relay permission
CreatePermission,
/// TURN ChannelBind (0x0009) — bind channel number to peer
ChannelBind,
}
impl Method {
/// Returns the 12-bit method number (M0-M11).
pub fn as_u16(self) -> u16 {
match self {
Method::Binding => 0x0001,
Method::Allocate => 0x0003,
Method::Refresh => 0x0004,
Method::Send => 0x0006,
Method::Data => 0x0007,
Method::CreatePermission => 0x0008,
Method::ChannelBind => 0x0009,
}
}
/// Parses a 12-bit method number into a [`Method`].
pub fn from_u16(val: u16) -> Result<Self, TurnError> {
match val {
0x0001 => Ok(Method::Binding),
0x0003 => Ok(Method::Allocate),
0x0004 => Ok(Method::Refresh),
0x0006 => Ok(Method::Send),
0x0007 => Ok(Method::Data),
0x0008 => Ok(Method::CreatePermission),
0x0009 => Ok(Method::ChannelBind),
_ => Err(TurnError::StunParseError(format!(
"unknown STUN method: 0x{:04x}",
val
))),
}
}
}
// ---------------------------------------------------------------------------
// Class
// ---------------------------------------------------------------------------
/// STUN message class (RFC 5389 §6).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Class {
/// Request (C0=0, C1=0)
Request,
/// Indication (C0=1, C1=0)
Indication,
/// Success Response (C0=0, C1=1)
SuccessResponse,
/// Error Response (C0=1, C1=1)
ErrorResponse,
}
impl Class {
/// Returns the 2-bit class value (C0 in bit 0, C1 in bit 1).
pub fn as_u8(self) -> u8 {
match self {
Class::Request => 0b00,
Class::Indication => 0b01,
Class::SuccessResponse => 0b10,
Class::ErrorResponse => 0b11,
}
}
/// Parses a 2-bit class value.
pub fn from_u8(val: u8) -> Result<Self, TurnError> {
match val & 0x03 {
0b00 => Ok(Class::Request),
0b01 => Ok(Class::Indication),
0b10 => Ok(Class::SuccessResponse),
0b11 => Ok(Class::ErrorResponse),
_ => unreachable!(),
}
}
}
// ---------------------------------------------------------------------------
// MessageType
// ---------------------------------------------------------------------------
/// Combined STUN message type (method + class).
///
/// The 14-bit message type field encodes both the method and class with
/// an interleaved bit layout per RFC 5389 §6.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MessageType {
pub method: Method,
pub class: Class,
}
impl MessageType {
pub fn new(method: Method, class: Class) -> Self {
Self { method, class }
}
/// Encode method + class into the 14-bit STUN message type field.
///
/// Bit layout of the 16-bit type field (top 2 bits always 0 for STUN):
/// ```text
/// 15 14 | 13 12 11 10 9 | 8 | 7 6 5 | 4 | 3 2 1 0
/// 0 0 | M11 M10 M9 M8 M7 | C1 | M6 M5 M4 | C0 | M3 M2 M1 M0
/// ```
pub fn to_u16(self) -> u16 {
let m = self.method.as_u16();
let c = self.class.as_u8() as u16;
// Extract method bit groups
let m0_3 = m & 0x000F; // bits 0-3
let m4_6 = (m >> 4) & 0x0007; // bits 4-6
let m7_11 = (m >> 7) & 0x001F; // bits 7-11
// Extract class bits
let c0 = c & 0x01;
let c1 = (c >> 1) & 0x01;
// Assemble: M0-M3 in bits 0-3, C0 in bit 4, M4-M6 in bits 5-7,
// C1 in bit 8, M7-M11 in bits 9-13
m0_3 | (c0 << 4) | (m4_6 << 5) | (c1 << 8) | (m7_11 << 9)
}
/// Decode the 14-bit STUN message type field into method + class.
pub fn from_u16(val: u16) -> Result<Self, TurnError> {
// Top 2 bits must be 00 for STUN messages
if val & 0xC000 != 0 {
return Err(TurnError::StunParseError(
"top 2 bits of message type must be 00 for STUN".into(),
));
}
// Extract class bits
let c0 = (val >> 4) & 0x01;
let c1 = (val >> 8) & 0x01;
let class_bits = (c0 | (c1 << 1)) as u8;
// Extract method bits
let m0_3 = val & 0x000F;
let m4_6 = (val >> 5) & 0x0007;
let m7_11 = (val >> 9) & 0x001F;
let method_bits = m0_3 | (m4_6 << 4) | (m7_11 << 7);
Ok(MessageType {
method: Method::from_u16(method_bits)?,
class: Class::from_u8(class_bits)?,
})
}
}
// ---------------------------------------------------------------------------
// StunMessage
// ---------------------------------------------------------------------------
/// A parsed STUN message with header fields and attributes.
///
/// The wire format is a 20-byte header followed by zero or more TLV-encoded
/// attributes. The message length field in the header covers only the
/// attribute portion (not the 20-byte header itself).
#[derive(Debug, Clone)]
pub struct StunMessage {
pub msg_type: MessageType,
pub transaction_id: [u8; 12],
pub attributes: Vec<StunAttribute>,
}
impl StunMessage {
/// Create a new STUN message with the given type and transaction ID.
pub fn new(msg_type: MessageType, transaction_id: [u8; 12]) -> Self {
Self {
msg_type,
transaction_id,
attributes: Vec::new(),
}
}
/// Create a new STUN message with a random transaction ID.
pub fn new_random(msg_type: MessageType) -> Self {
let mut transaction_id = [0u8; 12];
// Use simple random fill; callers with `rand` can use proper RNG
#[cfg(feature = "rand")]
{
use rand::RngCore;
rand::thread_rng().fill_bytes(&mut transaction_id);
}
#[cfg(not(feature = "rand"))]
{
// Fallback: use std time-based entropy (not cryptographically secure,
// but sufficient for transaction IDs in development/testing)
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default();
let seed = now.as_nanos();
for (i, byte) in transaction_id.iter_mut().enumerate() {
*byte = ((seed >> (i * 5)) & 0xFF) as u8;
}
}
Self {
msg_type,
transaction_id,
attributes: Vec::new(),
}
}
/// Add an attribute to the message.
pub fn add_attribute(&mut self, attr: StunAttribute) {
self.attributes.push(attr);
}
/// Decode a STUN message from raw bytes.
///
/// Validates the magic cookie and parses the header and all TLV attributes.
/// Returns an error if the message is truncated, has an invalid cookie,
/// or contains malformed attributes.
pub fn decode(bytes: &[u8]) -> Result<Self, TurnError> {
if bytes.len() < STUN_HEADER_SIZE {
return Err(TurnError::StunParseError(format!(
"message too short: {} bytes, need at least {}",
bytes.len(),
STUN_HEADER_SIZE
)));
}
// Parse message type (first 2 bytes)
let type_val = u16::from_be_bytes([bytes[0], bytes[1]]);
let msg_type = MessageType::from_u16(type_val)?;
// Parse message length (bytes 2-3) — length of attributes only
let msg_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
// Validate magic cookie (bytes 4-7)
let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
if cookie != MAGIC_COOKIE {
return Err(TurnError::StunParseError(format!(
"invalid magic cookie: 0x{:08x}, expected 0x{:08x}",
cookie, MAGIC_COOKIE
)));
}
// Extract transaction ID (bytes 8-19)
let mut transaction_id = [0u8; 12];
transaction_id.copy_from_slice(&bytes[8..20]);
// Validate total length
let total_expected = STUN_HEADER_SIZE + msg_length;
if bytes.len() < total_expected {
return Err(TurnError::StunParseError(format!(
"message truncated: have {} bytes, header says {}",
bytes.len(),
total_expected
)));
}
// Parse attributes from the body
let attr_bytes = &bytes[STUN_HEADER_SIZE..total_expected];
let attributes = Self::decode_attributes(attr_bytes, &transaction_id)?;
Ok(StunMessage {
msg_type,
transaction_id,
attributes,
})
}
/// Parse the TLV attribute list from the message body.
fn decode_attributes(
mut data: &[u8],
transaction_id: &[u8; 12],
) -> Result<Vec<StunAttribute>, TurnError> {
let mut attributes = Vec::new();
while data.len() >= 4 {
// Each attribute: type (2 bytes) + length (2 bytes) + value + padding
let attr_type = u16::from_be_bytes([data[0], data[1]]);
let attr_len = u16::from_be_bytes([data[2], data[3]]) as usize;
if data.len() < 4 + attr_len {
return Err(TurnError::StunParseError(format!(
"attribute 0x{:04x} truncated: need {} bytes, have {}",
attr_type,
attr_len,
data.len() - 4
)));
}
let attr_value = &data[4..4 + attr_len];
let attr = crate::turn::attributes::decode_attribute(
attr_type,
attr_value,
transaction_id,
)?;
attributes.push(attr);
// Advance past value + padding to 4-byte boundary
let padded_len = (attr_len + 3) & !3;
let total_attr_size = 4 + padded_len;
if total_attr_size > data.len() {
break;
}
data = &data[total_attr_size..];
}
Ok(attributes)
}
/// Encode this STUN message into wire format bytes.
///
/// The message length field is computed from the encoded attributes.
/// Attributes are TLV-encoded with 4-byte padding.
pub fn encode(&self) -> Vec<u8> {
// Encode all attributes first to determine total length
let mut attr_bytes = Vec::new();
for attr in &self.attributes {
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
attr_bytes.extend_from_slice(&encoded);
}
let msg_length = attr_bytes.len() as u16;
// Build the 20-byte header
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
// Message type (2 bytes)
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
// Message length (2 bytes)
buf.extend_from_slice(&msg_length.to_be_bytes());
// Magic cookie (4 bytes)
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
// Transaction ID (12 bytes)
buf.extend_from_slice(&self.transaction_id);
// Attributes
buf.extend_from_slice(&attr_bytes);
buf
}
/// Encode the message, computing the correct length field as if a
/// MESSAGE-INTEGRITY attribute of 24 bytes (4-byte TLV header + 20-byte HMAC)
/// will be appended. This is used to generate the bytes over which
/// MESSAGE-INTEGRITY is computed per RFC 5389 §15.4.
///
/// The returned bytes do NOT include the MESSAGE-INTEGRITY attribute itself.
pub fn encode_for_integrity(&self) -> Vec<u8> {
// Encode attributes up to (but not including) MESSAGE-INTEGRITY
let mut attr_bytes = Vec::new();
for attr in &self.attributes {
if matches!(attr, StunAttribute::MessageIntegrity(_)) {
break;
}
if matches!(attr, StunAttribute::Fingerprint(_)) {
break;
}
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
attr_bytes.extend_from_slice(&encoded);
}
// The length field must include the MESSAGE-INTEGRITY attribute that will follow
// (4 bytes TLV header + 20 bytes HMAC = 24 bytes)
let msg_length = (attr_bytes.len() + 24) as u16;
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
buf.extend_from_slice(&msg_length.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
buf.extend_from_slice(&self.transaction_id);
buf.extend_from_slice(&attr_bytes);
buf
}
/// Encode the message for FINGERPRINT computation per RFC 5389 §15.5.
///
/// Returns all bytes up to (but not including) the FINGERPRINT attribute,
/// with the length field adjusted to include the FINGERPRINT (8 bytes).
pub fn encode_for_fingerprint(&self) -> Vec<u8> {
// Encode all attributes except FINGERPRINT
let mut attr_bytes = Vec::new();
for attr in &self.attributes {
if matches!(attr, StunAttribute::Fingerprint(_)) {
break;
}
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
attr_bytes.extend_from_slice(&encoded);
}
// Length includes the FINGERPRINT attribute (4 header + 4 value = 8 bytes)
let msg_length = (attr_bytes.len() + 8) as u16;
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
buf.extend_from_slice(&msg_length.to_be_bytes());
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
buf.extend_from_slice(&self.transaction_id);
buf.extend_from_slice(&attr_bytes);
buf
}
/// Find the first attribute of a given type in the message.
pub fn get_attribute(&self, predicate: impl Fn(&StunAttribute) -> bool) -> Option<&StunAttribute> {
self.attributes.iter().find(|a| predicate(a))
}
/// Get the USERNAME attribute value, if present.
pub fn get_username(&self) -> Option<&str> {
for attr in &self.attributes {
if let StunAttribute::Username(ref u) = attr {
return Some(u.as_str());
}
}
None
}
/// Get the REALM attribute value, if present.
pub fn get_realm(&self) -> Option<&str> {
for attr in &self.attributes {
if let StunAttribute::Realm(ref r) = attr {
return Some(r.as_str());
}
}
None
}
/// Get the NONCE attribute value, if present.
pub fn get_nonce(&self) -> Option<&str> {
for attr in &self.attributes {
if let StunAttribute::Nonce(ref n) = attr {
return Some(n.as_str());
}
}
None
}
/// Get the MESSAGE-INTEGRITY attribute value, if present.
pub fn get_message_integrity(&self) -> Option<&[u8; 20]> {
for attr in &self.attributes {
if let StunAttribute::MessageIntegrity(ref mi) = attr {
return Some(mi);
}
}
None
}
}
// ---------------------------------------------------------------------------
// ChannelData
// ---------------------------------------------------------------------------
/// ChannelData message for TURN channel bindings (RFC 5766 §11.4).
///
/// ChannelData uses a compact 4-byte header instead of the STUN format:
/// Bytes 0-1: Channel Number (0x4000-0x7FFF)
/// Bytes 2-3: Data Length
/// Bytes 4+: Application Data (padded to 4 bytes over UDP)
#[derive(Debug, Clone)]
pub struct ChannelData {
pub channel_number: u16,
pub data: Vec<u8>,
}
impl ChannelData {
/// Decode a ChannelData message from raw bytes.
pub fn decode(bytes: &[u8]) -> Result<Self, TurnError> {
if bytes.len() < 4 {
return Err(TurnError::StunParseError(
"ChannelData message too short".into(),
));
}
let channel_number = u16::from_be_bytes([bytes[0], bytes[1]]);
let data_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
// Channel numbers must be in range 0x4000-0x7FFF
if !(0x4000..=0x7FFF).contains(&channel_number) {
return Err(TurnError::StunParseError(format!(
"invalid channel number: 0x{:04x}, must be in 0x4000-0x7FFF",
channel_number
)));
}
if bytes.len() < 4 + data_length {
return Err(TurnError::StunParseError(format!(
"ChannelData truncated: need {} data bytes, have {}",
data_length,
bytes.len() - 4
)));
}
let data = bytes[4..4 + data_length].to_vec();
Ok(ChannelData {
channel_number,
data,
})
}
/// Encode this ChannelData message into wire format.
///
/// The data portion is padded to a 4-byte boundary (for UDP transport).
pub fn encode(&self) -> Vec<u8> {
let data_len = self.data.len();
let padded_len = (data_len + 3) & !3;
let mut buf = Vec::with_capacity(4 + padded_len);
buf.extend_from_slice(&self.channel_number.to_be_bytes());
buf.extend_from_slice(&(data_len as u16).to_be_bytes());
buf.extend_from_slice(&self.data);
// Pad to 4-byte boundary with zeros
let padding = padded_len - data_len;
for _ in 0..padding {
buf.push(0);
}
buf
}
}
// ---------------------------------------------------------------------------
// Detection helpers
// ---------------------------------------------------------------------------
/// Returns `true` if the buffer begins with a ChannelData message
/// (first two bits are NOT `00`).
///
/// STUN messages always have the first two bits as `00` (message type field).
/// ChannelData starts with channel numbers 0x4000-0x7FFF, which have the
/// first two bits as `01`.
pub fn is_channel_data(bytes: &[u8]) -> bool {
if bytes.is_empty() {
return false;
}
// First two bits: 00 = STUN, 01 = ChannelData
(bytes[0] & 0xC0) != 0x00
}
/// Returns `true` if the buffer looks like a valid STUN message
/// (first two bits are `00` and magic cookie is present).
pub fn is_stun_message(bytes: &[u8]) -> bool {
if bytes.len() < STUN_HEADER_SIZE {
return false;
}
// First two bits must be 00
if (bytes[0] & 0xC0) != 0x00 {
return false;
}
// Check magic cookie
let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
cookie == MAGIC_COOKIE
}
/// Compute the CRC32 fingerprint value for a STUN message.
///
/// The fingerprint is `CRC32(message_bytes) XOR 0x5354554e` per RFC 5389 §15.5.
/// `message_bytes` should be the entire message up to (but not including)
/// the FINGERPRINT attribute, with the length field adjusted to include it.
pub fn compute_fingerprint(message_bytes: &[u8]) -> u32 {
let crc = crc32_compute(message_bytes);
crc ^ FINGERPRINT_XOR
}
/// Simple CRC32 (ISO 3309 / ITU-T V.42) implementation.
///
/// This uses the standard polynomial 0xEDB88320 (reflected form).
/// In production, the `crc32fast` crate should be used instead for
/// SIMD-accelerated performance.
fn crc32_compute(data: &[u8]) -> u32 {
let mut crc: u32 = 0xFFFFFFFF;
for &byte in data {
crc ^= byte as u32;
for _ in 0..8 {
if crc & 1 != 0 {
crc = (crc >> 1) ^ 0xEDB88320;
} else {
crc >>= 1;
}
}
}
!crc
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_message_type_encoding_binding_request() {
let mt = MessageType::new(Method::Binding, Class::Request);
let encoded = mt.to_u16();
assert_eq!(encoded, 0x0001);
let decoded = MessageType::from_u16(encoded).unwrap();
assert_eq!(decoded.method, Method::Binding);
assert_eq!(decoded.class, Class::Request);
}
#[test]
fn test_message_type_encoding_binding_success() {
let mt = MessageType::new(Method::Binding, Class::SuccessResponse);
let encoded = mt.to_u16();
assert_eq!(encoded, 0x0101);
let decoded = MessageType::from_u16(encoded).unwrap();
assert_eq!(decoded.method, Method::Binding);
assert_eq!(decoded.class, Class::SuccessResponse);
}
#[test]
fn test_message_type_encoding_allocate_request() {
let mt = MessageType::new(Method::Allocate, Class::Request);
let encoded = mt.to_u16();
assert_eq!(encoded, 0x0003);
let decoded = MessageType::from_u16(encoded).unwrap();
assert_eq!(decoded.method, Method::Allocate);
assert_eq!(decoded.class, Class::Request);
}
#[test]
fn test_message_type_encoding_allocate_error() {
let mt = MessageType::new(Method::Allocate, Class::ErrorResponse);
let encoded = mt.to_u16();
assert_eq!(encoded, 0x0113);
let decoded = MessageType::from_u16(encoded).unwrap();
assert_eq!(decoded.method, Method::Allocate);
assert_eq!(decoded.class, Class::ErrorResponse);
}
#[test]
fn test_message_type_roundtrip_all_methods() {
let methods = [
Method::Binding,
Method::Allocate,
Method::Refresh,
Method::Send,
Method::Data,
Method::CreatePermission,
Method::ChannelBind,
];
let classes = [
Class::Request,
Class::Indication,
Class::SuccessResponse,
Class::ErrorResponse,
];
for method in &methods {
for class in &classes {
let mt = MessageType::new(*method, *class);
let encoded = mt.to_u16();
let decoded = MessageType::from_u16(encoded).unwrap();
assert_eq!(decoded.method, *method, "method mismatch for {:?}/{:?}", method, class);
assert_eq!(decoded.class, *class, "class mismatch for {:?}/{:?}", method, class);
}
}
}
#[test]
fn test_is_channel_data() {
// STUN message: first byte has top 2 bits = 00
assert!(!is_channel_data(&[0x00, 0x01]));
// ChannelData: channel 0x4000 -> first byte = 0x40 -> top 2 bits = 01
assert!(is_channel_data(&[0x40, 0x00]));
// Empty
assert!(!is_channel_data(&[]));
}
#[test]
fn test_is_stun_message() {
// Valid STUN header: type=0x0001, len=0, cookie=0x2112A442, txn_id=zeros
let mut buf = [0u8; 20];
buf[0] = 0x00;
buf[1] = 0x01;
// length = 0
buf[4] = 0x21;
buf[5] = 0x12;
buf[6] = 0xA4;
buf[7] = 0x42;
assert!(is_stun_message(&buf));
// Wrong cookie
buf[4] = 0x00;
assert!(!is_stun_message(&buf));
}
#[test]
fn test_channel_data_roundtrip() {
let cd = ChannelData {
channel_number: 0x4001,
data: vec![1, 2, 3, 4, 5],
};
let encoded = cd.encode();
let decoded = ChannelData::decode(&encoded).unwrap();
assert_eq!(decoded.channel_number, 0x4001);
assert_eq!(decoded.data, vec![1, 2, 3, 4, 5]);
}
#[test]
fn test_channel_data_padding() {
let cd = ChannelData {
channel_number: 0x4001,
data: vec![1, 2, 3], // 3 bytes -> padded to 4
};
let encoded = cd.encode();
// 4 header + 4 padded data = 8 bytes
assert_eq!(encoded.len(), 8);
assert_eq!(encoded[7], 0); // padding byte
}
#[test]
fn test_stun_message_encode_decode_empty() {
let msg = StunMessage {
msg_type: MessageType::new(Method::Binding, Class::Request),
transaction_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
attributes: vec![],
};
let encoded = msg.encode();
assert_eq!(encoded.len(), STUN_HEADER_SIZE);
let decoded = StunMessage::decode(&encoded).unwrap();
assert_eq!(decoded.msg_type.method, Method::Binding);
assert_eq!(decoded.msg_type.class, Class::Request);
assert_eq!(decoded.transaction_id, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
assert!(decoded.attributes.is_empty());
}
#[test]
fn test_crc32_known_value() {
// CRC32 of "123456789" is 0xCBF43926
let data = b"123456789";
let crc = crc32_compute(data);
assert_eq!(crc, 0xCBF43926);
}
#[test]
fn test_decode_too_short() {
let result = StunMessage::decode(&[0; 10]);
assert!(result.is_err());
}
#[test]
fn test_decode_bad_cookie() {
let mut buf = [0u8; 20];
buf[0] = 0x00;
buf[1] = 0x01;
// Bad cookie
buf[4] = 0xFF;
let result = StunMessage::decode(&buf);
assert!(result.is_err());
}
}

329
src/turn/tcp_listener.rs Normal file
View File

@ -0,0 +1,329 @@
// TCP TURN listener
//
// Accepts TCP connections from TURN clients and frames STUN/ChannelData
// messages over the stream. Per RFC 5766 §2.1, when TURN is used over
// TCP, the client connects to the TURN server via TCP but the relay
// still uses UDP to communicate with peers.
//
// TCP framing: STUN messages and ChannelData are self-delimiting on
// a TCP stream. The reader peeks at the first two bytes to determine
// the message type:
// - If the first two bits are 00 → STUN message. Read 20-byte header,
// extract message length from bytes 2-3, then read the attribute body.
// - Otherwise → ChannelData. The first 2 bytes are the channel number,
// next 2 bytes are the data length, then read the data (padded to 4 bytes).
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use super::allocation::{AllocationManager, TransportProtocol};
use super::handler::{HandleResult, MessageContext, TurnHandler};
use super::stun::{is_channel_data, ChannelData, StunMessage};
// ---------------------------------------------------------------------------
// TcpTurnListener
// ---------------------------------------------------------------------------
/// TCP listener for the TURN server.
///
/// Accepts TCP connections from clients and processes STUN/TURN messages
/// framed on the TCP stream. Each accepted connection is handled in a
/// separate spawned task.
pub struct TcpTurnListener {
/// The bound TCP listener.
listener: TcpListener,
/// The shared TURN message handler.
handler: Arc<TurnHandler>,
/// The shared allocation manager (for relay receiver spawning).
allocations: Arc<AllocationManager>,
/// The server's listen address.
server_addr: SocketAddr,
/// The server's public IP (for MessageContext).
server_public_ip: std::net::IpAddr,
/// The primary UDP socket (for spawning relay receiver tasks).
/// TCP clients still use UDP relay sockets for the peer-facing side.
udp_socket: Option<Arc<tokio::net::UdpSocket>>,
}
impl TcpTurnListener {
/// Bind a TCP listener on the given address.
pub async fn bind(
addr: SocketAddr,
handler: Arc<TurnHandler>,
allocations: Arc<AllocationManager>,
server_public_ip: std::net::IpAddr,
) -> std::io::Result<Self> {
let listener = TcpListener::bind(addr).await?;
let server_addr = listener.local_addr()?;
Ok(Self {
listener,
handler,
allocations,
server_addr,
server_public_ip,
udp_socket: None,
})
}
/// Set the primary UDP socket (used for relay receiver tasks spawned
/// from TCP-originated allocations).
pub fn set_udp_socket(&mut self, socket: Arc<tokio::net::UdpSocket>) {
self.udp_socket = Some(socket);
}
/// Run the TCP listener loop.
///
/// Accepts connections in a loop and spawns a handler task for each one.
pub async fn run(self) {
let handler = self.handler;
let server_addr = self.server_addr;
let server_public_ip = self.server_public_ip;
let allocations = self.allocations;
let udp_socket = self.udp_socket;
loop {
match self.listener.accept().await {
Ok((stream, client_addr)) => {
let handler = Arc::clone(&handler);
let allocations = Arc::clone(&allocations);
let udp_socket = udp_socket.clone();
tokio::spawn(async move {
if let Err(e) = handle_tcp_client(
stream,
client_addr,
server_addr,
server_public_ip,
handler,
allocations,
udp_socket,
)
.await
{
log::debug!("[TURN-TCP] client {} disconnected: {}", client_addr, e);
}
});
}
Err(e) => {
log::error!("[TURN-TCP] accept error: {}", e);
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
}
}
// ---------------------------------------------------------------------------
// Per-client TCP handler
// ---------------------------------------------------------------------------
/// Handle a single TCP client connection.
///
/// Reads STUN/ChannelData messages from the TCP stream in a loop, dispatches
/// them to the handler, and writes responses back.
async fn handle_tcp_client(
stream: TcpStream,
client_addr: SocketAddr,
server_addr: SocketAddr,
server_public_ip: std::net::IpAddr,
handler: Arc<TurnHandler>,
allocations: Arc<AllocationManager>,
udp_socket: Option<Arc<tokio::net::UdpSocket>>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let ctx = MessageContext {
client_addr,
server_addr,
protocol: TransportProtocol::Tcp,
server_public_ip,
};
let (mut reader, mut writer) = stream.into_split();
let mut header_buf = [0u8; 4];
loop {
// Read the first 2 bytes to determine message type
match reader.read_exact(&mut header_buf[..2]).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
// Clean disconnect
return Ok(());
}
Err(e) => return Err(e.into()),
}
if is_channel_data(&header_buf[..2]) {
// ChannelData: first 2 bytes = channel number, next 2 = data length
reader.read_exact(&mut header_buf[2..4]).await?;
let data_len = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
// TCP pads ChannelData to 4-byte boundary
let padded_len = (data_len + 3) & !3;
let mut data_buf = vec![0u8; padded_len];
reader.read_exact(&mut data_buf[..padded_len]).await?;
// Reconstruct the full ChannelData for decoding (header + unpadded data)
let mut full_msg = Vec::with_capacity(4 + data_len);
full_msg.extend_from_slice(&header_buf);
full_msg.extend_from_slice(&data_buf[..data_len]);
match ChannelData::decode(&full_msg) {
Ok(channel_data) => {
if let Some(result) =
handler.handle_channel_data(&channel_data, &ctx).await
{
execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket)
.await;
}
}
Err(e) => {
log::debug!(
"[TURN-TCP] failed to parse ChannelData from {}: {}",
client_addr,
e
);
}
}
} else {
// STUN message: first 2 bytes are message type, next 2 are message length
// We need to read the remaining 18 bytes of the 20-byte header
let mut stun_header = [0u8; 20];
stun_header[0] = header_buf[0];
stun_header[1] = header_buf[1];
reader.read_exact(&mut stun_header[2..20]).await?;
let msg_len = u16::from_be_bytes([stun_header[2], stun_header[3]]) as usize;
// Read the attribute body
let mut msg_buf = Vec::with_capacity(20 + msg_len);
msg_buf.extend_from_slice(&stun_header);
if msg_len > 0 {
let mut body = vec![0u8; msg_len];
reader.read_exact(&mut body).await?;
msg_buf.extend_from_slice(&body);
}
match StunMessage::decode(&msg_buf) {
Ok(msg) => {
let results = handler.handle_message(&msg, &ctx).await;
for result in results {
execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket)
.await;
}
}
Err(e) => {
log::debug!(
"[TURN-TCP] failed to parse STUN message from {}: {}",
client_addr,
e
);
}
}
}
}
}
/// Execute a single [`HandleResult`] for a TCP client.
///
/// Responses are written directly to the TCP stream. Relay operations
/// use the UDP relay socket (relay is always UDP, even for TCP clients).
async fn execute_tcp_result(
writer: &mut tokio::net::tcp::OwnedWriteHalf,
result: HandleResult,
client_addr: SocketAddr,
allocations: &Arc<AllocationManager>,
udp_socket: &Option<Arc<tokio::net::UdpSocket>>,
) {
match result {
HandleResult::Response(data) => {
if let Err(e) = writer.write_all(&data).await {
log::error!(
"[TURN-TCP] failed to write response to {}: {}",
client_addr,
e
);
}
}
HandleResult::RelayToPeer {
peer_addr,
data,
relay_socket,
} => {
// Relay is always UDP even for TCP clients
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
log::error!(
"[TURN-TCP] failed to relay to peer {}: {}",
peer_addr,
e
);
}
}
HandleResult::ChannelDataToPeer {
peer_addr,
data,
relay_socket,
} => {
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
log::error!(
"[TURN-TCP] failed to send channel data to peer {}: {}",
peer_addr,
e
);
}
}
HandleResult::AllocationCreated {
response,
relay_socket,
relay_addr,
five_tuple,
} => {
// Send the success response to the TCP client
if let Err(e) = writer.write_all(&response).await {
log::error!(
"[TURN-TCP] failed to write allocate response to {}: {}",
client_addr,
e
);
}
// For TCP clients, we still need a relay receiver task to handle
// peer → relay socket → client. However, for TCP the data needs to
// go back over the TCP stream. For simplicity, if a UDP socket is
// available we use the UDP-based relay receiver. In a full
// implementation, the relay receiver would write to the TCP stream.
//
// NOTE: In the current architecture, TCP allocations still relay
// data via UDP. Peer data arriving on the relay socket will be
// forwarded to the client's address via the main UDP socket (if
// the client also has a UDP path). For pure TCP clients, a more
// sophisticated approach would be needed.
if let Some(main_udp) = udp_socket {
let main_socket = Arc::clone(main_udp);
let allocs = Arc::clone(allocations);
let ft = five_tuple.clone();
log::debug!(
"[TURN-TCP] spawning relay receiver for {} (relay {})",
client_addr,
relay_addr
);
tokio::spawn(async move {
super::udp_listener::relay_receiver_task(
relay_socket,
relay_addr,
ft,
main_socket,
allocs,
)
.await;
});
} else {
log::warn!(
"[TURN-TCP] no UDP socket available for relay receiver (allocation {})",
relay_addr
);
}
}
HandleResult::None => {}
}
}

339
src/turn/udp_listener.rs Normal file
View File

@ -0,0 +1,339 @@
// UDP TURN listener
//
// Listens on a UDP socket for incoming STUN/TURN messages and ChannelData.
// Dispatches to the TurnHandler for processing and executes the resulting
// I/O actions (send responses, relay data to peers).
//
// Also manages relay receiver tasks — when a new allocation is created,
// a task is spawned to read data arriving on the allocation's relay socket
// from peers. That data is wrapped as either ChannelData (if a channel
// binding exists) or a STUN Data indication, and sent back to the client
// via the main UDP socket.
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket;
use super::allocation::{AllocationManager, FiveTuple, TransportProtocol};
use super::attributes::StunAttribute;
use super::handler::{HandleResult, MessageContext, TurnHandler};
use super::stun::{
is_channel_data, is_stun_message, ChannelData, Class, MessageType, Method, StunMessage,
compute_fingerprint,
};
// ---------------------------------------------------------------------------
// UdpTurnListener
// ---------------------------------------------------------------------------
/// UDP listener for the TURN server.
///
/// Reads datagrams from the primary UDP socket, dispatches them to the
/// [`TurnHandler`], and executes I/O actions. Also spawns relay receiver
/// tasks for each new allocation.
pub struct UdpTurnListener {
/// The primary UDP socket bound to the TURN server port (e.g., 3478).
socket: Arc<UdpSocket>,
/// The shared TURN message handler.
handler: Arc<TurnHandler>,
/// The allocation manager (needed for relay receiver lookups).
allocations: Arc<AllocationManager>,
/// The server's listen address.
server_addr: SocketAddr,
/// The server's public IP (for MessageContext).
server_public_ip: std::net::IpAddr,
}
impl UdpTurnListener {
/// Create a new UDP TURN listener.
///
/// - `socket`: The primary UDP socket (already bound to the listen address).
/// - `handler`: The shared TURN message handler.
/// - `allocations`: The shared allocation manager.
/// - `server_addr`: The server's listen address.
/// - `server_public_ip`: The server's public IP for relay addresses.
pub fn new(
socket: Arc<UdpSocket>,
handler: Arc<TurnHandler>,
allocations: Arc<AllocationManager>,
server_addr: SocketAddr,
server_public_ip: std::net::IpAddr,
) -> Self {
Self {
socket,
handler,
allocations,
server_addr,
server_public_ip,
}
}
/// Run the UDP listener loop.
///
/// This reads datagrams from the primary UDP socket in a loop. Each
/// incoming message is dispatched to a spawned task for processing,
/// so the receive loop is never blocked by handler logic.
///
/// When the handler returns an [`HandleResult::AllocationCreated`], this
/// also spawns a relay receiver task for the new allocation's relay socket.
pub async fn run(self: Arc<Self>) {
let mut buf = vec![0u8; 65536]; // 64KB max UDP datagram
loop {
match self.socket.recv_from(&mut buf).await {
Ok((len, client_addr)) => {
let data = buf[..len].to_vec();
let this = Arc::clone(&self);
// Spawn a task per message to avoid blocking the receive loop
tokio::spawn(async move {
this.process_message(&data, client_addr).await;
});
}
Err(e) => {
log::error!("[TURN-UDP] recv error: {}", e);
// Brief sleep to avoid tight error loop
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
}
}
}
}
/// Process a single incoming UDP message.
///
/// Determines whether the message is ChannelData or a STUN message,
/// dispatches to the handler, and executes the resulting actions.
async fn process_message(&self, data: &[u8], client_addr: SocketAddr) {
let ctx = MessageContext {
client_addr,
server_addr: self.server_addr,
protocol: TransportProtocol::Udp,
server_public_ip: self.server_public_ip,
};
if is_channel_data(data) {
// Parse ChannelData and handle
match ChannelData::decode(data) {
Ok(channel_data) => {
if let Some(result) = self.handler.handle_channel_data(&channel_data, &ctx).await
{
self.execute_result(result, client_addr).await;
}
}
Err(e) => {
log::debug!(
"[TURN-UDP] failed to parse ChannelData from {}: {}",
client_addr,
e
);
}
}
} else if is_stun_message(data) {
// Parse STUN message and handle
match StunMessage::decode(data) {
Ok(msg) => {
let results = self.handler.handle_message(&msg, &ctx).await;
for result in results {
self.execute_result(result, client_addr).await;
}
}
Err(e) => {
log::debug!(
"[TURN-UDP] failed to parse STUN message from {}: {}",
client_addr,
e
);
}
}
}
// else: unknown message format, silently ignored
}
/// Execute a single [`HandleResult`] action.
///
/// Sends responses to clients, relays data to peers, and spawns relay
/// receiver tasks for new allocations.
async fn execute_result(&self, result: HandleResult, client_addr: SocketAddr) {
match result {
HandleResult::Response(data) => {
if let Err(e) = self.socket.send_to(&data, client_addr).await {
log::error!(
"[TURN-UDP] failed to send response to {}: {}",
client_addr,
e
);
}
}
HandleResult::RelayToPeer {
peer_addr,
data,
relay_socket,
} => {
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
log::error!(
"[TURN-UDP] failed to relay to peer {}: {}",
peer_addr,
e
);
}
}
HandleResult::ChannelDataToPeer {
peer_addr,
data,
relay_socket,
} => {
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
log::error!(
"[TURN-UDP] failed to send channel data to peer {}: {}",
peer_addr,
e
);
}
}
HandleResult::AllocationCreated {
response,
relay_socket,
relay_addr,
five_tuple,
} => {
// Send the success response to the client
if let Err(e) = self.socket.send_to(&response, client_addr).await {
log::error!(
"[TURN-UDP] failed to send allocate response to {}: {}",
client_addr,
e
);
}
// Spawn a relay receiver task for this allocation
let main_socket = Arc::clone(&self.socket);
let allocations = Arc::clone(&self.allocations);
let ft = five_tuple.clone();
log::debug!(
"[TURN-UDP] spawning relay receiver for {} (relay {})",
client_addr,
relay_addr
);
tokio::spawn(async move {
relay_receiver_task(
relay_socket,
relay_addr,
ft,
main_socket,
allocations,
)
.await;
});
}
HandleResult::None => {}
}
}
}
// ---------------------------------------------------------------------------
// Relay receiver task
// ---------------------------------------------------------------------------
/// Relay receiver task for a single allocation.
///
/// Reads data arriving on the allocation's relay socket from peers and
/// forwards it to the client. The data is wrapped as:
/// - **ChannelData** if there's a channel binding for the peer address
/// - **STUN Data indication** otherwise (with XOR-PEER-ADDRESS and DATA)
///
/// The task runs until the relay socket encounters an unrecoverable error
/// or the allocation is cleaned up (at which point recv_from will fail
/// because the socket is dropped).
///
/// This function is `pub` so it can be reused by the TCP listener for
/// TCP-originated allocations that still use UDP relay sockets.
pub async fn relay_receiver_task(
relay_socket: Arc<UdpSocket>,
relay_addr: SocketAddr,
five_tuple: FiveTuple,
main_socket: Arc<UdpSocket>,
allocations: Arc<AllocationManager>,
) {
let mut buf = vec![0u8; 65536];
let client_addr = five_tuple.client_addr;
loop {
match relay_socket.recv_from(&mut buf).await {
Ok((len, peer_addr)) => {
let data = &buf[..len];
// Check that a permission exists for this peer's IP
if !allocations
.has_permission(&five_tuple, &peer_addr.ip())
.await
{
log::debug!(
"[TURN-RELAY] dropping packet from {} to relay {}: no permission",
peer_addr,
relay_addr
);
continue;
}
// Check if there's a channel binding for this peer
let wrapped = if let Some(channel_number) = allocations
.get_channel_for_peer(&five_tuple, &peer_addr)
.await
{
// Wrap as ChannelData
let cd = ChannelData {
channel_number,
data: data.to_vec(),
};
cd.encode()
} else {
// Wrap as a STUN Data indication
build_data_indication(peer_addr, data)
};
// Send to the client via the main UDP socket
if let Err(e) = main_socket.send_to(&wrapped, client_addr).await {
log::error!(
"[TURN-RELAY] failed to send to client {}: {}",
client_addr,
e
);
}
}
Err(e) => {
// Socket error — likely the allocation was cleaned up and the
// socket was dropped. Exit the task.
log::debug!(
"[TURN-RELAY] relay socket {} recv error (allocation likely expired): {}",
relay_addr,
e
);
break;
}
}
}
log::debug!(
"[TURN-RELAY] relay receiver task for {} exiting",
relay_addr
);
}
/// Build a STUN Data indication message (Method::Data, Class::Indication)
/// with XOR-PEER-ADDRESS and DATA attributes.
///
/// Data indications are used when there's no channel binding for the peer.
/// Per RFC 5766 §10.3, they don't require MESSAGE-INTEGRITY.
fn build_data_indication(peer_addr: SocketAddr, data: &[u8]) -> Vec<u8> {
let mut msg = StunMessage::new_random(MessageType::new(Method::Data, Class::Indication));
msg.add_attribute(StunAttribute::XorPeerAddress(peer_addr));
msg.add_attribute(StunAttribute::Data(data.to_vec()));
// Add FINGERPRINT for demultiplexing
let fp_bytes = msg.encode_for_fingerprint();
let fingerprint = compute_fingerprint(&fp_bytes);
msg.add_attribute(StunAttribute::Fingerprint(fingerprint));
msg.encode()
}