feat(turn): add embedded TURN server with credential service
Integrate a TURN/STUN server into the relay for NAT traversal of WebRTC connections. Clients request time-limited HMAC-SHA1 credentials over a new libp2p request-response protocol and then talk to the TURN server directly via UDP/TCP. Key changes: - Add `turn` module with server, credentials, and configuration - Register `/dusk/turn-credentials/1.0.0` request-response protocol so clients can obtain time-limited TURN credentials (24h TTL) - Expose TURN signaling (3478/udp+tcp) and relay allocation ports (49152-65535/udp) in Dockerfile and docker-compose - Add TURN-related environment variables for public IP, shared secret, realm, port ranges, and allocation limits - Validate directory display_name (1-64 chars) and return typed errors - Restrict keypair file permissions to 0600 on Unix
This commit is contained in:
parent
b29039557a
commit
ea21aa55b6
22
Dockerfile
22
Dockerfile
|
|
@ -48,9 +48,16 @@ USER dusk
|
|||
# set working directory
|
||||
WORKDIR /data
|
||||
|
||||
# expose the default relay port
|
||||
# expose the default relay port (libp2p)
|
||||
EXPOSE 4001
|
||||
|
||||
# expose TURN server ports (UDP + TCP signaling)
|
||||
EXPOSE 3478/udp
|
||||
EXPOSE 3478/tcp
|
||||
|
||||
# expose TURN relay allocation port range (UDP)
|
||||
EXPOSE 49152-65535/udp
|
||||
|
||||
# persist keypair and data to the volume-mounted /data directory
|
||||
# XDG_DATA_HOME tells the directories crate to resolve paths under /data
|
||||
# so the keypair ends up at /data/dusk-relay/keypair instead of ~/.local/share
|
||||
|
|
@ -61,6 +68,19 @@ VOLUME /data
|
|||
ENV RUST_LOG=info
|
||||
ENV DUSK_RELAY_PORT=4001
|
||||
|
||||
# TURN server environment variables
|
||||
ENV DUSK_TURN_ENABLED=true
|
||||
ENV DUSK_TURN_PUBLIC_IP=""
|
||||
ENV DUSK_TURN_SECRET=""
|
||||
ENV DUSK_TURN_UDP_PORT=3478
|
||||
ENV DUSK_TURN_TCP_PORT=3478
|
||||
ENV DUSK_TURN_REALM=duskchat.app
|
||||
ENV DUSK_TURN_PORT_RANGE_START=49152
|
||||
ENV DUSK_TURN_PORT_RANGE_END=65535
|
||||
ENV DUSK_TURN_MAX_ALLOCATIONS=1000
|
||||
ENV DUSK_TURN_MAX_PER_USER=10
|
||||
ENV DUSK_TURN_PUBLIC_HOST=""
|
||||
|
||||
# health check to verify the relay is listening
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
||||
CMD timeout 5 bash -c 'cat < /dev/null > /dev/tcp/0.0.0.0/${DUSK_RELAY_PORT:-4001}' || exit 1
|
||||
|
|
|
|||
|
|
@ -8,10 +8,32 @@ services:
|
|||
container_name: dusk-relay
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
# libp2p relay port
|
||||
- "4001:4001"
|
||||
# TURN server signaling (UDP + TCP)
|
||||
- "3478:3478/udp"
|
||||
- "3478:3478/tcp"
|
||||
# TURN relay allocation ports (UDP)
|
||||
# NOTE: using a smaller range (49152-50000) for Docker; adjust as needed
|
||||
- "49152-50000:49152-50000/udp"
|
||||
environment:
|
||||
- RUST_LOG=info
|
||||
# libp2p relay
|
||||
- DUSK_RELAY_PORT=4001
|
||||
- DUSK_PEER_RELAYS=
|
||||
- KLIPY_API_KEY=
|
||||
# TURN server
|
||||
- DUSK_TURN_ENABLED=true
|
||||
- DUSK_TURN_PUBLIC_IP=${DUSK_TURN_PUBLIC_IP:?Set DUSK_TURN_PUBLIC_IP to your server public IP}
|
||||
- DUSK_TURN_SECRET=${DUSK_TURN_SECRET:?Set DUSK_TURN_SECRET for HMAC credentials}
|
||||
- DUSK_TURN_UDP_PORT=3478
|
||||
- DUSK_TURN_TCP_PORT=3478
|
||||
- DUSK_TURN_REALM=duskchat.app
|
||||
- DUSK_TURN_PORT_RANGE_START=49152
|
||||
- DUSK_TURN_PORT_RANGE_END=50000
|
||||
- DUSK_TURN_MAX_ALLOCATIONS=1000
|
||||
- DUSK_TURN_MAX_PER_USER=10
|
||||
- DUSK_TURN_PUBLIC_HOST=${DUSK_TURN_PUBLIC_HOST:-}
|
||||
volumes:
|
||||
# persist the relay's keypair so peer id stays stable across restarts
|
||||
- dusk-relay-data:/data
|
||||
|
|
|
|||
211
src/main.rs
211
src/main.rs
|
|
@ -25,6 +25,8 @@
|
|||
// t3.large (8GB): 20,000 max connections
|
||||
// c6i.xlarge: 50,000 max connections (with kernel tuning)
|
||||
|
||||
mod turn;
|
||||
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::PathBuf;
|
||||
use std::time::{Duration, Instant};
|
||||
|
|
@ -63,6 +65,8 @@ struct RelayBehaviour {
|
|||
gif_service: cbor::Behaviour<GifRequest, GifResponse>,
|
||||
// persistent directory service - clients register/search peer profiles
|
||||
directory_service: cbor::Behaviour<DirectoryRequest, DirectoryResponse>,
|
||||
// TURN credential service - clients request time-limited TURN server credentials
|
||||
turn_credentials: cbor::Behaviour<TurnCredentialRequest, TurnCredentialResponse>,
|
||||
}
|
||||
|
||||
// ---- gif protocol ----
|
||||
|
|
@ -227,6 +231,7 @@ pub enum DirectoryRequest {
|
|||
pub enum DirectoryResponse {
|
||||
Ok,
|
||||
Results(Vec<DirectoryProfileEntry>),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
|
|
@ -237,6 +242,28 @@ pub struct DirectoryProfileEntry {
|
|||
}
|
||||
// ---- end directory protocol ----
|
||||
|
||||
// ---- TURN credential protocol ----
|
||||
// clients request time-limited TURN credentials over libp2p request-response.
|
||||
// the relay generates HMAC-SHA1 credentials that the client then uses when
|
||||
// talking to the TURN server directly via UDP/TCP (separate from the libp2p swarm).
|
||||
|
||||
const TURN_CREDENTIALS_PROTOCOL: StreamProtocol =
|
||||
StreamProtocol::new("/dusk/turn-credentials/1.0.0");
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TurnCredentialRequest {
|
||||
pub peer_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TurnCredentialResponse {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
pub ttl: u64,
|
||||
pub uris: Vec<String>, // TURN server URIs like ["turn:relay.duskchat.app:3478", "turn:relay.duskchat.app:3478?transport=tcp"]
|
||||
}
|
||||
// ---- end TURN credential protocol ----
|
||||
|
||||
// fetch from klipy and normalize into our GifResult format
|
||||
async fn fetch_klipy(
|
||||
http: &reqwest::Client,
|
||||
|
|
@ -342,6 +369,18 @@ fn load_or_generate_keypair() -> libp2p::identity::Keypair {
|
|||
if let Err(e) = std::fs::write(&path, &bytes) {
|
||||
log::warn!("failed to persist keypair: {}", e);
|
||||
} else {
|
||||
// restrict keypair file to owner-only read/write (0600) so other
|
||||
// users on a shared server cannot read the relay's private key
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
if let Err(e) = std::fs::set_permissions(
|
||||
&path,
|
||||
std::fs::Permissions::from_mode(0o600),
|
||||
) {
|
||||
log::warn!("failed to set keypair file permissions: {}", e);
|
||||
}
|
||||
}
|
||||
log::info!("saved new keypair to {}", path.display());
|
||||
}
|
||||
}
|
||||
|
|
@ -460,12 +499,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
)
|
||||
.expect("valid gossipsub behaviour");
|
||||
|
||||
// read connection limit (max total concurrent connections across all peers)
|
||||
let max_connections = std::env::var("DUSK_MAX_CONNECTIONS")
|
||||
.ok()
|
||||
.and_then(|c| c.parse().ok())
|
||||
.unwrap_or(10_000);
|
||||
|
||||
RelayBehaviour {
|
||||
relay: relay::Behaviour::new(peer_id, relay::Config::default()),
|
||||
rendezvous: rendezvous::server::Behaviour::new(
|
||||
|
|
@ -497,6 +530,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
request_response::Config::default()
|
||||
.with_request_timeout(Duration::from_secs(15)),
|
||||
),
|
||||
// TURN credential service - clients request time-limited TURN credentials
|
||||
turn_credentials: cbor::Behaviour::new(
|
||||
[(TURN_CREDENTIALS_PROTOCOL, ProtocolSupport::Full)],
|
||||
request_response::Config::default()
|
||||
.with_request_timeout(Duration::from_secs(15)),
|
||||
),
|
||||
}
|
||||
})?
|
||||
.with_swarm_config(|cfg| cfg.with_idle_connection_timeout(Duration::from_secs(300)))
|
||||
|
|
@ -512,6 +551,40 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
let canonical_addr = format!("/ip4/0.0.0.0/tcp/{}/p2p/{}", port, local_peer_id);
|
||||
println!("\n relay address: {}\n", canonical_addr);
|
||||
|
||||
// ---- TURN server startup ----
|
||||
// get the shared secret for credential generation (shared between TURN server and credential protocol)
|
||||
let turn_shared_secret: Vec<u8> = std::env::var("DUSK_TURN_SECRET")
|
||||
.unwrap_or_else(|_| {
|
||||
eprintln!("[TURN] WARNING: DUSK_TURN_SECRET not set, generating random secret");
|
||||
eprintln!("[TURN] This means credentials won't survive restarts and won't work across multiple relay instances");
|
||||
let timestamp = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_nanos();
|
||||
format!("dusk-turn-random-{}", timestamp)
|
||||
})
|
||||
.into_bytes();
|
||||
|
||||
// start TURN server if enabled (runs on separate UDP/TCP ports from the libp2p swarm)
|
||||
let _turn_handle = if turn::TurnServerConfig::is_enabled() {
|
||||
let turn_config = turn::TurnServerConfig::from_env();
|
||||
let turn_server = turn::TurnServer::new(turn_config);
|
||||
match turn_server.run().await {
|
||||
Ok(handle) => {
|
||||
println!("[RELAY] TURN server started");
|
||||
Some(handle)
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("[RELAY] Failed to start TURN server: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!("[RELAY] TURN server disabled");
|
||||
None
|
||||
};
|
||||
// ---- end TURN server startup ----
|
||||
|
||||
// subscribe to the relay federation gossip topic
|
||||
let federation_topic = gossipsub::IdentTopic::new("dusk/relay/federation");
|
||||
swarm
|
||||
|
|
@ -896,23 +969,36 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
)) => {
|
||||
let response = match request {
|
||||
DirectoryRequest::Register { display_name } => {
|
||||
let ts = now_secs();
|
||||
let result = dir_db.execute(
|
||||
"INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at)
|
||||
VALUES (?1, ?2, ?3, ?3)
|
||||
ON CONFLICT(peer_id) DO UPDATE SET
|
||||
display_name = excluded.display_name,
|
||||
last_seen = excluded.last_seen",
|
||||
params![peer.to_string(), display_name, ts as i64],
|
||||
);
|
||||
match result {
|
||||
Ok(_) => {
|
||||
log::info!("directory: registered peer {} as '{}'", peer, display_name);
|
||||
DirectoryResponse::Ok
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("directory: failed to register {}: {}", peer, e);
|
||||
DirectoryResponse::Ok
|
||||
// validate display_name: reject empty or excessively long names
|
||||
let trimmed_name = display_name.trim();
|
||||
if trimmed_name.is_empty() || trimmed_name.len() > 64 {
|
||||
log::warn!(
|
||||
"directory: rejected registration from {} (name length {})",
|
||||
peer,
|
||||
trimmed_name.len()
|
||||
);
|
||||
DirectoryResponse::Error(
|
||||
"display_name must be 1-64 characters".to_string(),
|
||||
)
|
||||
} else {
|
||||
let ts = now_secs();
|
||||
let result = dir_db.execute(
|
||||
"INSERT INTO peer_profiles (peer_id, display_name, last_seen, registered_at)
|
||||
VALUES (?1, ?2, ?3, ?3)
|
||||
ON CONFLICT(peer_id) DO UPDATE SET
|
||||
display_name = excluded.display_name,
|
||||
last_seen = excluded.last_seen",
|
||||
params![peer.to_string(), trimmed_name, ts as i64],
|
||||
);
|
||||
match result {
|
||||
Ok(_) => {
|
||||
log::info!("directory: registered peer {} as '{}'", peer, trimmed_name);
|
||||
DirectoryResponse::Ok
|
||||
}
|
||||
Err(e) => {
|
||||
log::warn!("directory: failed to register {}: {}", peer, e);
|
||||
DirectoryResponse::Error(format!("registration failed: {}", e))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -974,6 +1060,85 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
// ignore outbound and other directory service events
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::DirectoryService(_)) => {}
|
||||
|
||||
// ---- TURN credential service ----
|
||||
// clients request time-limited credentials for the TURN server
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
|
||||
request_response::Event::Message {
|
||||
peer,
|
||||
message:
|
||||
request_response::Message::Request {
|
||||
request, channel, ..
|
||||
},
|
||||
..
|
||||
},
|
||||
)) => {
|
||||
let peer_id_str = request.peer_id.clone();
|
||||
|
||||
// generate time-limited credentials using the shared secret
|
||||
let (username, password) = turn::credentials::generate_credentials(
|
||||
&peer_id_str,
|
||||
&turn_shared_secret,
|
||||
86400, // 24 hours
|
||||
);
|
||||
|
||||
// build TURN URIs using the relay's public hostname/IP
|
||||
let turn_host = std::env::var("DUSK_TURN_PUBLIC_HOST")
|
||||
.unwrap_or_else(|_| {
|
||||
std::env::var("DUSK_TURN_PUBLIC_IP")
|
||||
.unwrap_or_else(|_| "relay.duskchat.app".to_string())
|
||||
});
|
||||
let turn_port: u16 = std::env::var("DUSK_TURN_UDP_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(3478);
|
||||
|
||||
let response = TurnCredentialResponse {
|
||||
username,
|
||||
password,
|
||||
ttl: 86400,
|
||||
uris: vec![
|
||||
format!("turn:{}:{}", turn_host, turn_port),
|
||||
format!("turn:{}:{}?transport=tcp", turn_host, turn_port),
|
||||
],
|
||||
};
|
||||
|
||||
if swarm
|
||||
.behaviour_mut()
|
||||
.turn_credentials
|
||||
.send_response(channel, response)
|
||||
.is_err()
|
||||
{
|
||||
log::warn!("[TURN-CREDS] failed to send credentials to {:?}", peer);
|
||||
}
|
||||
log::info!(
|
||||
"[TURN-CREDS] generated credentials for peer {} (requested by {:?})",
|
||||
peer_id_str,
|
||||
peer
|
||||
);
|
||||
}
|
||||
// we're the server so we don't send requests, but handle all variants
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
|
||||
request_response::Event::Message {
|
||||
message: request_response::Message::Response { .. },
|
||||
..
|
||||
},
|
||||
)) => {}
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
|
||||
request_response::Event::OutboundFailure { .. },
|
||||
)) => {}
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
|
||||
request_response::Event::InboundFailure { peer, error, .. },
|
||||
)) => {
|
||||
log::warn!(
|
||||
"[TURN-CREDS] inbound failure from {:?}: {:?}",
|
||||
peer,
|
||||
error
|
||||
);
|
||||
}
|
||||
SwarmEvent::Behaviour(RelayBehaviourEvent::TurnCredentials(
|
||||
request_response::Event::ResponseSent { .. },
|
||||
)) => {}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,799 @@
|
|||
// STUN/TURN attribute types and TLV encoding/decoding
|
||||
//
|
||||
// Implements attributes per RFC 5389 (STUN) and RFC 5766 (TURN).
|
||||
// Each attribute is a Type-Length-Value (TLV) structure:
|
||||
// Type: 2 bytes (attribute type code)
|
||||
// Length: 2 bytes (value length, excluding padding)
|
||||
// Value: variable (padded to 4-byte boundary with zeros)
|
||||
//
|
||||
// XOR-encoded address attributes (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS,
|
||||
// XOR-RELAYED-ADDRESS) use the STUN magic cookie and transaction ID
|
||||
// to XOR the address bytes, preventing NAT ALGs from rewriting them.
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
|
||||
|
||||
use crate::turn::error::TurnError;
|
||||
use crate::turn::stun::MAGIC_COOKIE;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Attribute type codes
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// RFC 5389 STUN attribute type codes
|
||||
pub const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
|
||||
pub const ATTR_USERNAME: u16 = 0x0006;
|
||||
pub const ATTR_MESSAGE_INTEGRITY: u16 = 0x0008;
|
||||
pub const ATTR_ERROR_CODE: u16 = 0x0009;
|
||||
pub const ATTR_UNKNOWN_ATTRIBUTES: u16 = 0x000A;
|
||||
pub const ATTR_REALM: u16 = 0x0014;
|
||||
pub const ATTR_NONCE: u16 = 0x0015;
|
||||
pub const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;
|
||||
pub const ATTR_SOFTWARE: u16 = 0x8022;
|
||||
pub const ATTR_FINGERPRINT: u16 = 0x8028;
|
||||
|
||||
/// RFC 5766 TURN attribute type codes
|
||||
pub const ATTR_CHANNEL_NUMBER: u16 = 0x000C;
|
||||
pub const ATTR_LIFETIME: u16 = 0x000D;
|
||||
pub const ATTR_XOR_PEER_ADDRESS: u16 = 0x0012;
|
||||
pub const ATTR_DATA: u16 = 0x0013;
|
||||
pub const ATTR_XOR_RELAYED_ADDRESS: u16 = 0x0016;
|
||||
pub const ATTR_REQUESTED_ADDRESS_FAMILY: u16 = 0x0017;
|
||||
pub const ATTR_EVEN_PORT: u16 = 0x0018;
|
||||
pub const ATTR_REQUESTED_TRANSPORT: u16 = 0x0019;
|
||||
pub const ATTR_DONT_FRAGMENT: u16 = 0x001A;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Address family constants
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const ADDR_FAMILY_IPV4: u8 = 0x01;
|
||||
const ADDR_FAMILY_IPV6: u8 = 0x02;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// StunAttribute enum
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A parsed STUN or TURN attribute.
|
||||
///
|
||||
/// Each variant corresponds to a specific attribute type. Unknown attributes
|
||||
/// are preserved as raw bytes for forwarding or diagnostic purposes.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StunAttribute {
|
||||
// ---- RFC 5389 STUN attributes ----
|
||||
|
||||
/// MAPPED-ADDRESS (0x0001): The reflexive transport address of the client
|
||||
/// as seen by the server. Uses plain (non-XOR) encoding.
|
||||
MappedAddress(SocketAddr),
|
||||
|
||||
/// USERNAME (0x0006): The username for message integrity, encoded as UTF-8.
|
||||
Username(String),
|
||||
|
||||
/// MESSAGE-INTEGRITY (0x0008): 20-byte HMAC-SHA1 over the STUN message
|
||||
/// (up to but not including this attribute).
|
||||
MessageIntegrity([u8; 20]),
|
||||
|
||||
/// ERROR-CODE (0x0009): Error response code and human-readable reason phrase.
|
||||
/// Code is in range 300-699 per RFC 5389 §15.6.
|
||||
ErrorCode { code: u16, reason: String },
|
||||
|
||||
/// UNKNOWN-ATTRIBUTES (0x000A): List of attribute types that the server
|
||||
/// did not understand. Used in 420 Unknown Attribute error responses.
|
||||
UnknownAttributes(Vec<u16>),
|
||||
|
||||
/// REALM (0x0014): The authentication realm, encoded as UTF-8.
|
||||
/// Used with long-term credential mechanism.
|
||||
Realm(String),
|
||||
|
||||
/// NONCE (0x0015): A server-generated nonce for replay protection.
|
||||
Nonce(String),
|
||||
|
||||
/// XOR-MAPPED-ADDRESS (0x0020): Same as MAPPED-ADDRESS but XOR-encoded
|
||||
/// with the magic cookie (and transaction ID for IPv6) to prevent
|
||||
/// NAT ALG interference.
|
||||
XorMappedAddress(SocketAddr),
|
||||
|
||||
/// SOFTWARE (0x8022): Textual description of the software being used.
|
||||
/// Informational only.
|
||||
Software(String),
|
||||
|
||||
/// FINGERPRINT (0x8028): CRC32 of the STUN message XORed with 0x5354554e.
|
||||
/// Used to demultiplex STUN from other protocols on the same port.
|
||||
Fingerprint(u32),
|
||||
|
||||
// ---- RFC 5766 TURN attributes ----
|
||||
|
||||
/// CHANNEL-NUMBER (0x000C): Channel number for ChannelBind.
|
||||
/// Must be in range 0x4000-0x7FFF.
|
||||
ChannelNumber(u16),
|
||||
|
||||
/// LIFETIME (0x000D): Requested or granted allocation lifetime in seconds.
|
||||
Lifetime(u32),
|
||||
|
||||
/// XOR-PEER-ADDRESS (0x0012): The peer address for Send/Data indications,
|
||||
/// CreatePermission, and ChannelBind. XOR-encoded like XOR-MAPPED-ADDRESS.
|
||||
XorPeerAddress(SocketAddr),
|
||||
|
||||
/// DATA (0x0013): The application data payload in Send/Data indications.
|
||||
Data(Vec<u8>),
|
||||
|
||||
/// XOR-RELAYED-ADDRESS (0x0016): The relayed transport address allocated
|
||||
/// by the server. XOR-encoded.
|
||||
XorRelayedAddress(SocketAddr),
|
||||
|
||||
/// EVEN-PORT (0x0018): Requests an even port number for the relay address.
|
||||
/// The boolean indicates the R bit (reserve next-higher port).
|
||||
EvenPort(bool),
|
||||
|
||||
/// REQUESTED-TRANSPORT (0x0019): The transport protocol for the relay.
|
||||
/// Value is an IANA protocol number (17 = UDP, 6 = TCP).
|
||||
RequestedTransport(u8),
|
||||
|
||||
/// DONT-FRAGMENT (0x001A): Requests that the server set the DF bit
|
||||
/// in outgoing UDP packets. No value (zero-length attribute).
|
||||
DontFragment,
|
||||
|
||||
/// REQUESTED-ADDRESS-FAMILY (0x0017): Requests a specific address family
|
||||
/// for the relayed address (0x01 = IPv4, 0x02 = IPv6).
|
||||
RequestedAddressFamily(u8),
|
||||
|
||||
/// Unknown/unsupported attribute preserved as raw bytes.
|
||||
Unknown { attr_type: u16, value: Vec<u8> },
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Decoding
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decode a single attribute from its type code and raw value bytes.
|
||||
///
|
||||
/// The `transaction_id` is needed for XOR address decoding.
|
||||
pub fn decode_attribute(
|
||||
attr_type: u16,
|
||||
value: &[u8],
|
||||
transaction_id: &[u8; 12],
|
||||
) -> Result<StunAttribute, TurnError> {
|
||||
match attr_type {
|
||||
ATTR_MAPPED_ADDRESS => decode_mapped_address(value),
|
||||
ATTR_USERNAME => decode_utf8_string(value).map(StunAttribute::Username),
|
||||
ATTR_MESSAGE_INTEGRITY => decode_message_integrity(value),
|
||||
ATTR_ERROR_CODE => decode_error_code(value),
|
||||
ATTR_UNKNOWN_ATTRIBUTES => decode_unknown_attributes(value),
|
||||
ATTR_REALM => decode_utf8_string(value).map(StunAttribute::Realm),
|
||||
ATTR_NONCE => decode_utf8_string(value).map(StunAttribute::Nonce),
|
||||
ATTR_XOR_MAPPED_ADDRESS => {
|
||||
decode_xor_address(value, transaction_id).map(StunAttribute::XorMappedAddress)
|
||||
}
|
||||
ATTR_SOFTWARE => decode_utf8_string(value).map(StunAttribute::Software),
|
||||
ATTR_FINGERPRINT => decode_fingerprint(value),
|
||||
ATTR_CHANNEL_NUMBER => decode_channel_number(value),
|
||||
ATTR_LIFETIME => decode_lifetime(value),
|
||||
ATTR_XOR_PEER_ADDRESS => {
|
||||
decode_xor_address(value, transaction_id).map(StunAttribute::XorPeerAddress)
|
||||
}
|
||||
ATTR_DATA => Ok(StunAttribute::Data(value.to_vec())),
|
||||
ATTR_XOR_RELAYED_ADDRESS => {
|
||||
decode_xor_address(value, transaction_id).map(StunAttribute::XorRelayedAddress)
|
||||
}
|
||||
ATTR_REQUESTED_ADDRESS_FAMILY => decode_requested_address_family(value),
|
||||
ATTR_EVEN_PORT => decode_even_port(value),
|
||||
ATTR_REQUESTED_TRANSPORT => decode_requested_transport(value),
|
||||
ATTR_DONT_FRAGMENT => Ok(StunAttribute::DontFragment),
|
||||
_ => Ok(StunAttribute::Unknown {
|
||||
attr_type,
|
||||
value: value.to_vec(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Encoding
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Encode a single attribute into TLV wire format with 4-byte padding.
|
||||
///
|
||||
/// Returns the complete TLV bytes: type (2) + length (2) + value + padding.
|
||||
pub fn encode_attribute(attr: &StunAttribute, transaction_id: &[u8; 12]) -> Vec<u8> {
|
||||
let (attr_type, value) = encode_attribute_value(attr, transaction_id);
|
||||
let value_len = value.len();
|
||||
let padded_len = (value_len + 3) & !3;
|
||||
|
||||
let mut buf = Vec::with_capacity(4 + padded_len);
|
||||
buf.extend_from_slice(&attr_type.to_be_bytes());
|
||||
buf.extend_from_slice(&(value_len as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&value);
|
||||
|
||||
// Pad to 4-byte boundary
|
||||
let padding = padded_len - value_len;
|
||||
for _ in 0..padding {
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode an attribute's value bytes (without the TLV header or padding).
|
||||
/// Returns (attribute_type_code, value_bytes).
|
||||
fn encode_attribute_value(attr: &StunAttribute, transaction_id: &[u8; 12]) -> (u16, Vec<u8>) {
|
||||
match attr {
|
||||
StunAttribute::MappedAddress(addr) => {
|
||||
(ATTR_MAPPED_ADDRESS, encode_plain_address(addr))
|
||||
}
|
||||
StunAttribute::Username(s) => (ATTR_USERNAME, s.as_bytes().to_vec()),
|
||||
StunAttribute::MessageIntegrity(hmac) => (ATTR_MESSAGE_INTEGRITY, hmac.to_vec()),
|
||||
StunAttribute::ErrorCode { code, reason } => {
|
||||
(ATTR_ERROR_CODE, encode_error_code(*code, reason))
|
||||
}
|
||||
StunAttribute::UnknownAttributes(types) => {
|
||||
let mut buf = Vec::with_capacity(types.len() * 2);
|
||||
for &t in types {
|
||||
buf.extend_from_slice(&t.to_be_bytes());
|
||||
}
|
||||
(ATTR_UNKNOWN_ATTRIBUTES, buf)
|
||||
}
|
||||
StunAttribute::Realm(s) => (ATTR_REALM, s.as_bytes().to_vec()),
|
||||
StunAttribute::Nonce(s) => (ATTR_NONCE, s.as_bytes().to_vec()),
|
||||
StunAttribute::XorMappedAddress(addr) => {
|
||||
(ATTR_XOR_MAPPED_ADDRESS, encode_xor_address(addr, transaction_id))
|
||||
}
|
||||
StunAttribute::Software(s) => (ATTR_SOFTWARE, s.as_bytes().to_vec()),
|
||||
StunAttribute::Fingerprint(val) => (ATTR_FINGERPRINT, val.to_be_bytes().to_vec()),
|
||||
StunAttribute::ChannelNumber(num) => {
|
||||
// Channel number is 16 bits followed by 16 bits of RFFU (reserved)
|
||||
let mut buf = vec![0u8; 4];
|
||||
buf[0..2].copy_from_slice(&num.to_be_bytes());
|
||||
(ATTR_CHANNEL_NUMBER, buf)
|
||||
}
|
||||
StunAttribute::Lifetime(secs) => (ATTR_LIFETIME, secs.to_be_bytes().to_vec()),
|
||||
StunAttribute::XorPeerAddress(addr) => {
|
||||
(ATTR_XOR_PEER_ADDRESS, encode_xor_address(addr, transaction_id))
|
||||
}
|
||||
StunAttribute::Data(data) => (ATTR_DATA, data.clone()),
|
||||
StunAttribute::XorRelayedAddress(addr) => {
|
||||
(ATTR_XOR_RELAYED_ADDRESS, encode_xor_address(addr, transaction_id))
|
||||
}
|
||||
StunAttribute::RequestedAddressFamily(family) => {
|
||||
let mut buf = vec![0u8; 4];
|
||||
buf[0] = *family;
|
||||
(ATTR_REQUESTED_ADDRESS_FAMILY, buf)
|
||||
}
|
||||
StunAttribute::EvenPort(reserve) => {
|
||||
let byte = if *reserve { 0x80 } else { 0x00 };
|
||||
(ATTR_EVEN_PORT, vec![byte])
|
||||
}
|
||||
StunAttribute::RequestedTransport(proto) => {
|
||||
// Protocol number in first byte, followed by 3 bytes RFFU
|
||||
let mut buf = vec![0u8; 4];
|
||||
buf[0] = *proto;
|
||||
(ATTR_REQUESTED_TRANSPORT, buf)
|
||||
}
|
||||
StunAttribute::DontFragment => (ATTR_DONT_FRAGMENT, vec![]),
|
||||
StunAttribute::Unknown { attr_type, value } => (*attr_type, value.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Address encoding/decoding helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decode a plain (non-XOR) MAPPED-ADDRESS attribute value.
|
||||
///
|
||||
/// Format: 1 byte reserved, 1 byte family, 2 bytes port, 4/16 bytes address
|
||||
fn decode_mapped_address(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() < 4 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"MAPPED-ADDRESS too short".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let family = value[1];
|
||||
let port = u16::from_be_bytes([value[2], value[3]]);
|
||||
|
||||
let addr = match family {
|
||||
ADDR_FAMILY_IPV4 => {
|
||||
if value.len() < 8 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"MAPPED-ADDRESS IPv4 too short".into(),
|
||||
));
|
||||
}
|
||||
let ip = Ipv4Addr::new(value[4], value[5], value[6], value[7]);
|
||||
SocketAddr::new(IpAddr::V4(ip), port)
|
||||
}
|
||||
ADDR_FAMILY_IPV6 => {
|
||||
if value.len() < 20 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"MAPPED-ADDRESS IPv6 too short".into(),
|
||||
));
|
||||
}
|
||||
let mut octets = [0u8; 16];
|
||||
octets.copy_from_slice(&value[4..20]);
|
||||
let ip = Ipv6Addr::from(octets);
|
||||
SocketAddr::new(IpAddr::V6(ip), port)
|
||||
}
|
||||
_ => {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"unknown address family: 0x{:02x}",
|
||||
family
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(StunAttribute::MappedAddress(addr))
|
||||
}
|
||||
|
||||
/// Encode a plain (non-XOR) address into wire format.
|
||||
fn encode_plain_address(addr: &SocketAddr) -> Vec<u8> {
|
||||
match addr {
|
||||
SocketAddr::V4(v4) => {
|
||||
let mut buf = vec![0u8; 8];
|
||||
buf[0] = 0; // reserved
|
||||
buf[1] = ADDR_FAMILY_IPV4;
|
||||
buf[2..4].copy_from_slice(&v4.port().to_be_bytes());
|
||||
buf[4..8].copy_from_slice(&v4.ip().octets());
|
||||
buf
|
||||
}
|
||||
SocketAddr::V6(v6) => {
|
||||
let mut buf = vec![0u8; 20];
|
||||
buf[0] = 0; // reserved
|
||||
buf[1] = ADDR_FAMILY_IPV6;
|
||||
buf[2..4].copy_from_slice(&v6.port().to_be_bytes());
|
||||
buf[4..20].copy_from_slice(&v6.ip().octets());
|
||||
buf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode an XOR-encoded address (XOR-MAPPED-ADDRESS, XOR-PEER-ADDRESS,
|
||||
/// XOR-RELAYED-ADDRESS) per RFC 5389 §15.2.
|
||||
///
|
||||
/// For IPv4: port is XORed with top 16 bits of magic cookie;
|
||||
/// address is XORed with magic cookie.
|
||||
/// For IPv6: port is XORed with top 16 bits of magic cookie;
|
||||
/// address is XORed with magic cookie || transaction ID (16 bytes).
|
||||
fn decode_xor_address(
|
||||
value: &[u8],
|
||||
transaction_id: &[u8; 12],
|
||||
) -> Result<SocketAddr, TurnError> {
|
||||
if value.len() < 4 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"XOR address attribute too short".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let family = value[1];
|
||||
let x_port = u16::from_be_bytes([value[2], value[3]]);
|
||||
let cookie_bytes = MAGIC_COOKIE.to_be_bytes();
|
||||
let port = x_port ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]);
|
||||
|
||||
match family {
|
||||
ADDR_FAMILY_IPV4 => {
|
||||
if value.len() < 8 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"XOR-MAPPED-ADDRESS IPv4 too short".into(),
|
||||
));
|
||||
}
|
||||
let x_addr = u32::from_be_bytes([value[4], value[5], value[6], value[7]]);
|
||||
let addr = x_addr ^ MAGIC_COOKIE;
|
||||
let ip = Ipv4Addr::from(addr);
|
||||
Ok(SocketAddr::new(IpAddr::V4(ip), port))
|
||||
}
|
||||
ADDR_FAMILY_IPV6 => {
|
||||
if value.len() < 20 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"XOR-MAPPED-ADDRESS IPv6 too short".into(),
|
||||
));
|
||||
}
|
||||
// Build the 16-byte XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes)
|
||||
let mut xor_mask = [0u8; 16];
|
||||
xor_mask[0..4].copy_from_slice(&cookie_bytes);
|
||||
xor_mask[4..16].copy_from_slice(transaction_id);
|
||||
|
||||
let mut addr_bytes = [0u8; 16];
|
||||
for i in 0..16 {
|
||||
addr_bytes[i] = value[4 + i] ^ xor_mask[i];
|
||||
}
|
||||
let ip = Ipv6Addr::from(addr_bytes);
|
||||
Ok(SocketAddr::new(IpAddr::V6(ip), port))
|
||||
}
|
||||
_ => Err(TurnError::StunParseError(format!(
|
||||
"unknown address family in XOR address: 0x{:02x}",
|
||||
family
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode an address using XOR encoding per RFC 5389 §15.2.
|
||||
fn encode_xor_address(addr: &SocketAddr, transaction_id: &[u8; 12]) -> Vec<u8> {
|
||||
let cookie_bytes = MAGIC_COOKIE.to_be_bytes();
|
||||
let x_port = addr.port() ^ u16::from_be_bytes([cookie_bytes[0], cookie_bytes[1]]);
|
||||
|
||||
match addr {
|
||||
SocketAddr::V4(v4) => {
|
||||
let mut buf = vec![0u8; 8];
|
||||
buf[0] = 0; // reserved
|
||||
buf[1] = ADDR_FAMILY_IPV4;
|
||||
buf[2..4].copy_from_slice(&x_port.to_be_bytes());
|
||||
let x_addr = u32::from_be_bytes(v4.ip().octets()) ^ MAGIC_COOKIE;
|
||||
buf[4..8].copy_from_slice(&x_addr.to_be_bytes());
|
||||
buf
|
||||
}
|
||||
SocketAddr::V6(v6) => {
|
||||
let mut buf = vec![0u8; 20];
|
||||
buf[0] = 0; // reserved
|
||||
buf[1] = ADDR_FAMILY_IPV6;
|
||||
buf[2..4].copy_from_slice(&x_port.to_be_bytes());
|
||||
|
||||
// XOR mask: magic cookie (4 bytes) + transaction ID (12 bytes)
|
||||
let mut xor_mask = [0u8; 16];
|
||||
xor_mask[0..4].copy_from_slice(&cookie_bytes);
|
||||
xor_mask[4..16].copy_from_slice(transaction_id);
|
||||
|
||||
let octets = v6.ip().octets();
|
||||
for i in 0..16 {
|
||||
buf[4 + i] = octets[i] ^ xor_mask[i];
|
||||
}
|
||||
buf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Specific attribute decoders
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Decode UTF-8 string from raw bytes.
|
||||
fn decode_utf8_string(value: &[u8]) -> Result<String, TurnError> {
|
||||
String::from_utf8(value.to_vec()).map_err(|e| {
|
||||
TurnError::StunParseError(format!("invalid UTF-8 in attribute: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
/// Decode MESSAGE-INTEGRITY (exactly 20 bytes HMAC-SHA1).
|
||||
fn decode_message_integrity(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() != 20 {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"MESSAGE-INTEGRITY must be 20 bytes, got {}",
|
||||
value.len()
|
||||
)));
|
||||
}
|
||||
let mut hmac = [0u8; 20];
|
||||
hmac.copy_from_slice(value);
|
||||
Ok(StunAttribute::MessageIntegrity(hmac))
|
||||
}
|
||||
|
||||
/// Decode ERROR-CODE attribute per RFC 5389 §15.6.
|
||||
///
|
||||
/// Format: 2 bytes reserved, 1 byte with class (hundreds digit) in bits 0-2,
|
||||
/// 1 byte with number (tens+units), followed by UTF-8 reason phrase.
|
||||
fn decode_error_code(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() < 4 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"ERROR-CODE too short".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let class = (value[2] & 0x07) as u16;
|
||||
let number = value[3] as u16;
|
||||
let code = class * 100 + number;
|
||||
|
||||
let reason = if value.len() > 4 {
|
||||
String::from_utf8_lossy(&value[4..]).into_owned()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
Ok(StunAttribute::ErrorCode { code, reason })
|
||||
}
|
||||
|
||||
/// Encode ERROR-CODE value bytes per RFC 5389 §15.6.
|
||||
fn encode_error_code(code: u16, reason: &str) -> Vec<u8> {
|
||||
let class = (code / 100) as u8;
|
||||
let number = (code % 100) as u8;
|
||||
|
||||
let reason_bytes = reason.as_bytes();
|
||||
let mut buf = Vec::with_capacity(4 + reason_bytes.len());
|
||||
buf.push(0); // reserved
|
||||
buf.push(0); // reserved
|
||||
buf.push(class & 0x07);
|
||||
buf.push(number);
|
||||
buf.extend_from_slice(reason_bytes);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Decode UNKNOWN-ATTRIBUTES (list of 16-bit attribute types).
|
||||
fn decode_unknown_attributes(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() % 2 != 0 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"UNKNOWN-ATTRIBUTES length must be even".into(),
|
||||
));
|
||||
}
|
||||
let mut types = Vec::with_capacity(value.len() / 2);
|
||||
for chunk in value.chunks_exact(2) {
|
||||
types.push(u16::from_be_bytes([chunk[0], chunk[1]]));
|
||||
}
|
||||
Ok(StunAttribute::UnknownAttributes(types))
|
||||
}
|
||||
|
||||
/// Decode FINGERPRINT (4-byte CRC32 XOR value).
|
||||
fn decode_fingerprint(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() != 4 {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"FINGERPRINT must be 4 bytes, got {}",
|
||||
value.len()
|
||||
)));
|
||||
}
|
||||
let val = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
|
||||
Ok(StunAttribute::Fingerprint(val))
|
||||
}
|
||||
|
||||
/// Decode CHANNEL-NUMBER (16-bit number + 16-bit RFFU).
|
||||
fn decode_channel_number(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() < 4 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"CHANNEL-NUMBER too short".into(),
|
||||
));
|
||||
}
|
||||
let num = u16::from_be_bytes([value[0], value[1]]);
|
||||
Ok(StunAttribute::ChannelNumber(num))
|
||||
}
|
||||
|
||||
/// Decode LIFETIME (32-bit seconds).
|
||||
fn decode_lifetime(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.len() < 4 {
|
||||
return Err(TurnError::StunParseError("LIFETIME too short".into()));
|
||||
}
|
||||
let secs = u32::from_be_bytes([value[0], value[1], value[2], value[3]]);
|
||||
Ok(StunAttribute::Lifetime(secs))
|
||||
}
|
||||
|
||||
/// Decode REQUESTED-ADDRESS-FAMILY (1 byte family + 3 bytes RFFU).
|
||||
fn decode_requested_address_family(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.is_empty() {
|
||||
return Err(TurnError::StunParseError(
|
||||
"REQUESTED-ADDRESS-FAMILY empty".into(),
|
||||
));
|
||||
}
|
||||
Ok(StunAttribute::RequestedAddressFamily(value[0]))
|
||||
}
|
||||
|
||||
/// Decode EVEN-PORT (1 byte, R bit in most significant bit).
|
||||
fn decode_even_port(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.is_empty() {
|
||||
return Err(TurnError::StunParseError("EVEN-PORT empty".into()));
|
||||
}
|
||||
let reserve = (value[0] & 0x80) != 0;
|
||||
Ok(StunAttribute::EvenPort(reserve))
|
||||
}
|
||||
|
||||
/// Decode REQUESTED-TRANSPORT (1 byte protocol + 3 bytes RFFU).
|
||||
fn decode_requested_transport(value: &[u8]) -> Result<StunAttribute, TurnError> {
|
||||
if value.is_empty() {
|
||||
return Err(TurnError::StunParseError(
|
||||
"REQUESTED-TRANSPORT empty".into(),
|
||||
));
|
||||
}
|
||||
Ok(StunAttribute::RequestedTransport(value[0]))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
|
||||
|
||||
#[test]
|
||||
fn test_xor_mapped_address_ipv4_roundtrip() {
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(192, 168, 1, 100), 12345));
|
||||
let txn_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
|
||||
|
||||
let encoded = encode_xor_address(&addr, &txn_id);
|
||||
let decoded = decode_xor_address(&encoded, &txn_id).unwrap();
|
||||
assert_eq!(decoded, addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xor_mapped_address_ipv6_roundtrip() {
|
||||
let ip = Ipv6Addr::new(0x2001, 0x0db8, 0x85a3, 0, 0, 0x8a2e, 0x0370, 0x7334);
|
||||
let addr = SocketAddr::V6(SocketAddrV6::new(ip, 54321, 0, 0));
|
||||
let txn_id = [0xA, 0xB, 0xC, 0xD, 0xE, 0xF, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6];
|
||||
|
||||
let encoded = encode_xor_address(&addr, &txn_id);
|
||||
let decoded = decode_xor_address(&encoded, &txn_id).unwrap();
|
||||
assert_eq!(decoded, addr);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_address_ipv4_roundtrip() {
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(10, 0, 0, 1), 8080));
|
||||
let encoded = encode_plain_address(&addr);
|
||||
let decoded = decode_mapped_address(&encoded).unwrap();
|
||||
if let StunAttribute::MappedAddress(decoded_addr) = decoded {
|
||||
assert_eq!(decoded_addr, addr);
|
||||
} else {
|
||||
panic!("expected MappedAddress");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_code_roundtrip() {
|
||||
let encoded = encode_error_code(401, "Unauthorized");
|
||||
let decoded = decode_error_code(&encoded).unwrap();
|
||||
if let StunAttribute::ErrorCode { code, reason } = decoded {
|
||||
assert_eq!(code, 401);
|
||||
assert_eq!(reason, "Unauthorized");
|
||||
} else {
|
||||
panic!("expected ErrorCode");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_code_438() {
|
||||
let encoded = encode_error_code(438, "Stale Nonce");
|
||||
let decoded = decode_error_code(&encoded).unwrap();
|
||||
if let StunAttribute::ErrorCode { code, reason } = decoded {
|
||||
assert_eq!(code, 438);
|
||||
assert_eq!(reason, "Stale Nonce");
|
||||
} else {
|
||||
panic!("expected ErrorCode");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attribute_tlv_encoding() {
|
||||
let txn_id = [0u8; 12];
|
||||
let attr = StunAttribute::Username("alice".into());
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
// Type (2) + Length (2) + "alice" (5) + padding (3) = 12
|
||||
assert_eq!(encoded.len(), 12);
|
||||
assert_eq!(encoded[0..2], ATTR_USERNAME.to_be_bytes());
|
||||
assert_eq!(encoded[2..4], 5u16.to_be_bytes());
|
||||
assert_eq!(&encoded[4..9], b"alice");
|
||||
assert_eq!(encoded[9], 0); // padding
|
||||
assert_eq!(encoded[10], 0);
|
||||
assert_eq!(encoded[11], 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_attribute_tlv_encoding_4_byte_aligned() {
|
||||
let txn_id = [0u8; 12];
|
||||
let attr = StunAttribute::Username("test".into());
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
// Type (2) + Length (2) + "test" (4) = 8, already aligned
|
||||
assert_eq!(encoded.len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lifetime_roundtrip() {
|
||||
let attr = StunAttribute::Lifetime(600);
|
||||
let txn_id = [0u8; 12];
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
// Type (2) + Length (2) + value (4) = 8
|
||||
assert_eq!(encoded.len(), 8);
|
||||
|
||||
let decoded = decode_attribute(ATTR_LIFETIME, &encoded[4..8], &txn_id).unwrap();
|
||||
if let StunAttribute::Lifetime(secs) = decoded {
|
||||
assert_eq!(secs, 600);
|
||||
} else {
|
||||
panic!("expected Lifetime");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_number_roundtrip() {
|
||||
let attr = StunAttribute::ChannelNumber(0x4000);
|
||||
let txn_id = [0u8; 12];
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
// Type (2) + Length (2) + value (4) = 8
|
||||
assert_eq!(encoded.len(), 8);
|
||||
|
||||
let decoded = decode_attribute(ATTR_CHANNEL_NUMBER, &encoded[4..8], &txn_id).unwrap();
|
||||
if let StunAttribute::ChannelNumber(num) = decoded {
|
||||
assert_eq!(num, 0x4000);
|
||||
} else {
|
||||
panic!("expected ChannelNumber");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_requested_transport_udp() {
|
||||
let attr = StunAttribute::RequestedTransport(17); // UDP
|
||||
let txn_id = [0u8; 12];
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
let decoded = decode_attribute(ATTR_REQUESTED_TRANSPORT, &encoded[4..8], &txn_id).unwrap();
|
||||
if let StunAttribute::RequestedTransport(proto) = decoded {
|
||||
assert_eq!(proto, 17);
|
||||
} else {
|
||||
panic!("expected RequestedTransport");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dont_fragment() {
|
||||
let attr = StunAttribute::DontFragment;
|
||||
let txn_id = [0u8; 12];
|
||||
let encoded = encode_attribute(&attr, &txn_id);
|
||||
|
||||
// Type (2) + Length (2) + no value = 4
|
||||
assert_eq!(encoded.len(), 4);
|
||||
assert_eq!(encoded[2..4], 0u16.to_be_bytes()); // length = 0
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_attribute_preserved() {
|
||||
let txn_id = [0u8; 12];
|
||||
let decoded =
|
||||
decode_attribute(0xFFFF, &[0x01, 0x02, 0x03], &txn_id).unwrap();
|
||||
if let StunAttribute::Unknown { attr_type, value } = decoded {
|
||||
assert_eq!(attr_type, 0xFFFF);
|
||||
assert_eq!(value, vec![0x01, 0x02, 0x03]);
|
||||
} else {
|
||||
panic!("expected Unknown");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_integrity_decode() {
|
||||
let hmac = [0xAA; 20];
|
||||
let decoded = decode_message_integrity(&hmac).unwrap();
|
||||
if let StunAttribute::MessageIntegrity(h) = decoded {
|
||||
assert_eq!(h, [0xAA; 20]);
|
||||
} else {
|
||||
panic!("expected MessageIntegrity");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_integrity_wrong_length() {
|
||||
let result = decode_message_integrity(&[0; 10]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fingerprint_roundtrip() {
|
||||
let val = 0xDEADBEEF_u32;
|
||||
let encoded_value = val.to_be_bytes().to_vec();
|
||||
let decoded = decode_fingerprint(&encoded_value).unwrap();
|
||||
if let StunAttribute::Fingerprint(v) = decoded {
|
||||
assert_eq!(v, val);
|
||||
} else {
|
||||
panic!("expected Fingerprint");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_even_port_reserve_bit() {
|
||||
// R bit set
|
||||
let decoded = decode_even_port(&[0x80]).unwrap();
|
||||
if let StunAttribute::EvenPort(reserve) = decoded {
|
||||
assert!(reserve);
|
||||
} else {
|
||||
panic!("expected EvenPort");
|
||||
}
|
||||
|
||||
// R bit not set
|
||||
let decoded = decode_even_port(&[0x00]).unwrap();
|
||||
if let StunAttribute::EvenPort(reserve) = decoded {
|
||||
assert!(!reserve);
|
||||
} else {
|
||||
panic!("expected EvenPort");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_attributes_list() {
|
||||
let value = vec![0x00, 0x01, 0x00, 0x20]; // MAPPED-ADDRESS, XOR-MAPPED-ADDRESS
|
||||
let decoded = decode_unknown_attributes(&value).unwrap();
|
||||
if let StunAttribute::UnknownAttributes(types) = decoded {
|
||||
assert_eq!(types, vec![0x0001, 0x0020]);
|
||||
} else {
|
||||
panic!("expected UnknownAttributes");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,758 @@
|
|||
// TURN credential generation and validation
|
||||
//
|
||||
// Implements the HMAC-SHA1 time-limited credential mechanism per
|
||||
// draft-uberti-behave-turn-rest-00, which is the standard used by
|
||||
// all major WebRTC implementations (Chrome, Firefox, Safari).
|
||||
//
|
||||
// Credential format:
|
||||
// Username: "{expiry_unix_timestamp}:{peer_id}"
|
||||
// Password: Base64(HMAC-SHA1(shared_secret, username))
|
||||
//
|
||||
// Also implements MESSAGE-INTEGRITY computation per RFC 5389 §15.4
|
||||
// using long-term credentials (key = MD5(username:realm:password)).
|
||||
//
|
||||
// Nonce generation uses HMAC-SHA1 with an embedded timestamp for
|
||||
// stateless stale-nonce detection.
|
||||
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use crate::turn::error::TurnError;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// HMAC-SHA1 (RFC 2104)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute HMAC-SHA1(key, message) returning a 20-byte digest.
|
||||
///
|
||||
/// This is a self-contained implementation of HMAC-SHA1 that doesn't
|
||||
/// depend on external crates. In production, replace with `hmac` + `sha1`
|
||||
/// crates for better performance and constant-time comparison.
|
||||
pub fn hmac_sha1(key: &[u8], message: &[u8]) -> [u8; 20] {
|
||||
const BLOCK_SIZE: usize = 64;
|
||||
const IPAD: u8 = 0x36;
|
||||
const OPAD: u8 = 0x5C;
|
||||
|
||||
// If key > block size, hash it first
|
||||
let key_block = if key.len() > BLOCK_SIZE {
|
||||
let hashed = sha1_digest(key);
|
||||
let mut block = [0u8; BLOCK_SIZE];
|
||||
block[..20].copy_from_slice(&hashed);
|
||||
block
|
||||
} else {
|
||||
let mut block = [0u8; BLOCK_SIZE];
|
||||
block[..key.len()].copy_from_slice(key);
|
||||
block
|
||||
};
|
||||
|
||||
// Inner hash: SHA1(key XOR ipad || message)
|
||||
let mut inner_input = Vec::with_capacity(BLOCK_SIZE + message.len());
|
||||
for i in 0..BLOCK_SIZE {
|
||||
inner_input.push(key_block[i] ^ IPAD);
|
||||
}
|
||||
inner_input.extend_from_slice(message);
|
||||
let inner_hash = sha1_digest(&inner_input);
|
||||
|
||||
// Outer hash: SHA1(key XOR opad || inner_hash)
|
||||
let mut outer_input = Vec::with_capacity(BLOCK_SIZE + 20);
|
||||
for i in 0..BLOCK_SIZE {
|
||||
outer_input.push(key_block[i] ^ OPAD);
|
||||
}
|
||||
outer_input.extend_from_slice(&inner_hash);
|
||||
sha1_digest(&outer_input)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SHA-1 (FIPS 180-4)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute SHA-1 digest of the input data, returning a 20-byte hash.
|
||||
///
|
||||
/// This is a self-contained implementation. In production, use the `sha1` crate.
|
||||
fn sha1_digest(data: &[u8]) -> [u8; 20] {
|
||||
let mut h0: u32 = 0x67452301;
|
||||
let mut h1: u32 = 0xEFCDAB89;
|
||||
let mut h2: u32 = 0x98BADCFE;
|
||||
let mut h3: u32 = 0x10325476;
|
||||
let mut h4: u32 = 0xC3D2E1F0;
|
||||
|
||||
// Pre-processing: add padding
|
||||
let bit_len = (data.len() as u64) * 8;
|
||||
let mut padded = data.to_vec();
|
||||
padded.push(0x80);
|
||||
while padded.len() % 64 != 56 {
|
||||
padded.push(0x00);
|
||||
}
|
||||
padded.extend_from_slice(&bit_len.to_be_bytes());
|
||||
|
||||
// Process each 512-bit (64-byte) block
|
||||
for block in padded.chunks_exact(64) {
|
||||
let mut w = [0u32; 80];
|
||||
for i in 0..16 {
|
||||
w[i] = u32::from_be_bytes([
|
||||
block[i * 4],
|
||||
block[i * 4 + 1],
|
||||
block[i * 4 + 2],
|
||||
block[i * 4 + 3],
|
||||
]);
|
||||
}
|
||||
for i in 16..80 {
|
||||
w[i] = (w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16]).rotate_left(1);
|
||||
}
|
||||
|
||||
let mut a = h0;
|
||||
let mut b = h1;
|
||||
let mut c = h2;
|
||||
let mut d = h3;
|
||||
let mut e = h4;
|
||||
|
||||
for i in 0..80 {
|
||||
let (f, k) = match i {
|
||||
0..=19 => ((b & c) | ((!b) & d), 0x5A827999_u32),
|
||||
20..=39 => (b ^ c ^ d, 0x6ED9EBA1_u32),
|
||||
40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC_u32),
|
||||
60..=79 => (b ^ c ^ d, 0xCA62C1D6_u32),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let temp = a
|
||||
.rotate_left(5)
|
||||
.wrapping_add(f)
|
||||
.wrapping_add(e)
|
||||
.wrapping_add(k)
|
||||
.wrapping_add(w[i]);
|
||||
e = d;
|
||||
d = c;
|
||||
c = b.rotate_left(30);
|
||||
b = a;
|
||||
a = temp;
|
||||
}
|
||||
|
||||
h0 = h0.wrapping_add(a);
|
||||
h1 = h1.wrapping_add(b);
|
||||
h2 = h2.wrapping_add(c);
|
||||
h3 = h3.wrapping_add(d);
|
||||
h4 = h4.wrapping_add(e);
|
||||
}
|
||||
|
||||
let mut result = [0u8; 20];
|
||||
result[0..4].copy_from_slice(&h0.to_be_bytes());
|
||||
result[4..8].copy_from_slice(&h1.to_be_bytes());
|
||||
result[8..12].copy_from_slice(&h2.to_be_bytes());
|
||||
result[12..16].copy_from_slice(&h3.to_be_bytes());
|
||||
result[16..20].copy_from_slice(&h4.to_be_bytes());
|
||||
result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MD5 (RFC 1321) - needed for long-term credential key derivation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute MD5 digest of the input data, returning a 16-byte hash.
|
||||
///
|
||||
/// Used for the long-term credential key: `MD5(username:realm:password)`.
|
||||
/// Self-contained implementation; in production use the `md-5` crate.
|
||||
fn md5_digest(data: &[u8]) -> [u8; 16] {
|
||||
// Initial state
|
||||
let mut a0: u32 = 0x67452301;
|
||||
let mut b0: u32 = 0xefcdab89;
|
||||
let mut c0: u32 = 0x98badcfe;
|
||||
let mut d0: u32 = 0x10325476;
|
||||
|
||||
// Per-round shift amounts
|
||||
const S: [u32; 64] = [
|
||||
7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22, 7, 12, 17, 22,
|
||||
5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20, 5, 9, 14, 20,
|
||||
4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23, 4, 11, 16, 23,
|
||||
6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21, 6, 10, 15, 21,
|
||||
];
|
||||
|
||||
// Pre-computed constants: floor(2^32 * |sin(i + 1)|)
|
||||
const K: [u32; 64] = [
|
||||
0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee,
|
||||
0xf57c0faf, 0x4787c62a, 0xa8304613, 0xfd469501,
|
||||
0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be,
|
||||
0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821,
|
||||
0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa,
|
||||
0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8,
|
||||
0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed,
|
||||
0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a,
|
||||
0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c,
|
||||
0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70,
|
||||
0x289b7ec6, 0xeaa127fa, 0xd4ef3085, 0x04881d05,
|
||||
0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665,
|
||||
0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039,
|
||||
0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1,
|
||||
0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1,
|
||||
0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391,
|
||||
];
|
||||
|
||||
// Pre-processing: add padding
|
||||
let bit_len = (data.len() as u64) * 8;
|
||||
let mut padded = data.to_vec();
|
||||
padded.push(0x80);
|
||||
while padded.len() % 64 != 56 {
|
||||
padded.push(0x00);
|
||||
}
|
||||
padded.extend_from_slice(&bit_len.to_le_bytes());
|
||||
|
||||
// Process each 512-bit block
|
||||
for block in padded.chunks_exact(64) {
|
||||
let mut m = [0u32; 16];
|
||||
for i in 0..16 {
|
||||
m[i] = u32::from_le_bytes([
|
||||
block[i * 4],
|
||||
block[i * 4 + 1],
|
||||
block[i * 4 + 2],
|
||||
block[i * 4 + 3],
|
||||
]);
|
||||
}
|
||||
|
||||
let mut a = a0;
|
||||
let mut b = b0;
|
||||
let mut c = c0;
|
||||
let mut d = d0;
|
||||
|
||||
for i in 0..64 {
|
||||
let (f, g) = match i {
|
||||
0..=15 => ((b & c) | ((!b) & d), i),
|
||||
16..=31 => ((d & b) | ((!d) & c), (5 * i + 1) % 16),
|
||||
32..=47 => (b ^ c ^ d, (3 * i + 5) % 16),
|
||||
48..=63 => (c ^ (b | (!d)), (7 * i) % 16),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let temp = d;
|
||||
d = c;
|
||||
c = b;
|
||||
b = b.wrapping_add(
|
||||
(a.wrapping_add(f).wrapping_add(K[i]).wrapping_add(m[g]))
|
||||
.rotate_left(S[i]),
|
||||
);
|
||||
a = temp;
|
||||
}
|
||||
|
||||
a0 = a0.wrapping_add(a);
|
||||
b0 = b0.wrapping_add(b);
|
||||
c0 = c0.wrapping_add(c);
|
||||
d0 = d0.wrapping_add(d);
|
||||
}
|
||||
|
||||
let mut result = [0u8; 16];
|
||||
result[0..4].copy_from_slice(&a0.to_le_bytes());
|
||||
result[4..8].copy_from_slice(&b0.to_le_bytes());
|
||||
result[8..12].copy_from_slice(&c0.to_le_bytes());
|
||||
result[12..16].copy_from_slice(&d0.to_le_bytes());
|
||||
result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Base64 encoding (RFC 4648)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Encode bytes to base64 string using standard alphabet with padding.
|
||||
fn base64_encode(data: &[u8]) -> String {
|
||||
const ALPHABET: &[u8; 64] =
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
|
||||
let mut result = String::with_capacity((data.len() + 2) / 3 * 4);
|
||||
let chunks = data.chunks(3);
|
||||
|
||||
for chunk in chunks {
|
||||
let b0 = chunk[0] as u32;
|
||||
let b1 = if chunk.len() > 1 { chunk[1] as u32 } else { 0 };
|
||||
let b2 = if chunk.len() > 2 { chunk[2] as u32 } else { 0 };
|
||||
|
||||
let triple = (b0 << 16) | (b1 << 8) | b2;
|
||||
|
||||
result.push(ALPHABET[((triple >> 18) & 0x3F) as usize] as char);
|
||||
result.push(ALPHABET[((triple >> 12) & 0x3F) as usize] as char);
|
||||
|
||||
if chunk.len() > 1 {
|
||||
result.push(ALPHABET[((triple >> 6) & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
|
||||
if chunk.len() > 2 {
|
||||
result.push(ALPHABET[(triple & 0x3F) as usize] as char);
|
||||
} else {
|
||||
result.push('=');
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Decode a base64 string to bytes. Returns None on invalid input.
|
||||
fn base64_decode(input: &str) -> Option<Vec<u8>> {
|
||||
fn char_val(c: u8) -> Option<u8> {
|
||||
const ALPHABET: &[u8; 64] =
|
||||
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
|
||||
ALPHABET.iter().position(|&x| x == c).map(|p| p as u8)
|
||||
}
|
||||
|
||||
let input = input.trim_end_matches('=');
|
||||
let bytes = input.as_bytes();
|
||||
|
||||
let mut result = Vec::with_capacity(bytes.len() * 3 / 4);
|
||||
let chunks = bytes.chunks(4);
|
||||
|
||||
for chunk in chunks {
|
||||
let vals: Vec<u8> = chunk.iter().filter_map(|&b| char_val(b)).collect();
|
||||
if vals.len() != chunk.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let triple = match vals.len() {
|
||||
4 => {
|
||||
((vals[0] as u32) << 18)
|
||||
| ((vals[1] as u32) << 12)
|
||||
| ((vals[2] as u32) << 6)
|
||||
| (vals[3] as u32)
|
||||
}
|
||||
3 => {
|
||||
((vals[0] as u32) << 18) | ((vals[1] as u32) << 12) | ((vals[2] as u32) << 6)
|
||||
}
|
||||
2 => ((vals[0] as u32) << 18) | ((vals[1] as u32) << 12),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
result.push((triple >> 16) as u8);
|
||||
if vals.len() > 2 {
|
||||
result.push((triple >> 8 & 0xFF) as u8);
|
||||
}
|
||||
if vals.len() > 3 {
|
||||
result.push((triple & 0xFF) as u8);
|
||||
}
|
||||
}
|
||||
|
||||
Some(result)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Hex encoding (for nonces)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Encode bytes as lowercase hexadecimal string.
|
||||
fn hex_encode(data: &[u8]) -> String {
|
||||
let mut s = String::with_capacity(data.len() * 2);
|
||||
for &b in data {
|
||||
s.push_str(&format!("{:02x}", b));
|
||||
}
|
||||
s
|
||||
}
|
||||
|
||||
/// Decode a hexadecimal string to bytes. Returns None on invalid input.
|
||||
fn hex_decode(s: &str) -> Option<Vec<u8>> {
|
||||
if s.len() % 2 != 0 {
|
||||
return None;
|
||||
}
|
||||
let mut result = Vec::with_capacity(s.len() / 2);
|
||||
for i in (0..s.len()).step_by(2) {
|
||||
let byte = u8::from_str_radix(&s[i..i + 2], 16).ok()?;
|
||||
result.push(byte);
|
||||
}
|
||||
Some(result)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Credential generation / validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate time-limited TURN credentials per draft-uberti-behave-turn-rest.
|
||||
///
|
||||
/// Returns `(username, password)` where:
|
||||
/// - username = `"{expiry_timestamp}:{peer_id}"`
|
||||
/// - password = `Base64(HMAC-SHA1(shared_secret, username))`
|
||||
///
|
||||
/// The credentials expire `ttl_secs` seconds from now.
|
||||
pub fn generate_credentials(
|
||||
peer_id: &str,
|
||||
shared_secret: &[u8],
|
||||
ttl_secs: u64,
|
||||
) -> (String, String) {
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
let expiry = now + ttl_secs;
|
||||
|
||||
let username = format!("{}:{}", expiry, peer_id);
|
||||
let hmac = hmac_sha1(shared_secret, username.as_bytes());
|
||||
let password = base64_encode(&hmac);
|
||||
|
||||
(username, password)
|
||||
}
|
||||
|
||||
/// Validate time-limited TURN credentials.
|
||||
///
|
||||
/// 1. Parse the expiry timestamp from the username
|
||||
/// 2. Check that the credentials have not expired
|
||||
/// 3. Recompute HMAC-SHA1 and compare with the provided password
|
||||
pub fn validate_credentials(
|
||||
username: &str,
|
||||
password: &str,
|
||||
shared_secret: &[u8],
|
||||
) -> Result<(), TurnError> {
|
||||
// Parse expiry timestamp from username (format: "{timestamp}:{peer_id}")
|
||||
let colon_pos = username
|
||||
.find(':')
|
||||
.ok_or(TurnError::Unauthorized)?;
|
||||
|
||||
let expiry_str = &username[..colon_pos];
|
||||
let expiry: u64 = expiry_str
|
||||
.parse()
|
||||
.map_err(|_| TurnError::Unauthorized)?;
|
||||
|
||||
// Check expiry
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
if now > expiry {
|
||||
return Err(TurnError::Unauthorized);
|
||||
}
|
||||
|
||||
// Recompute HMAC-SHA1 and compare
|
||||
let expected_hmac = hmac_sha1(shared_secret, username.as_bytes());
|
||||
let expected_password = base64_encode(&expected_hmac);
|
||||
|
||||
if !constant_time_eq(password.as_bytes(), expected_password.as_bytes()) {
|
||||
return Err(TurnError::Unauthorized);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MESSAGE-INTEGRITY computation (RFC 5389 §15.4)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compute the long-term credential key per RFC 5389 §15.4:
|
||||
/// `key = MD5(username ":" realm ":" password)`
|
||||
pub fn compute_long_term_key(username: &str, realm: &str, password: &str) -> [u8; 16] {
|
||||
let input = format!("{}:{}:{}", username, realm, password);
|
||||
md5_digest(input.as_bytes())
|
||||
}
|
||||
|
||||
/// Compute the MESSAGE-INTEGRITY attribute value (HMAC-SHA1).
|
||||
///
|
||||
/// `key` is the long-term credential key (output of [`compute_long_term_key`]).
|
||||
/// `message_bytes` should be the STUN message up to (but not including) the
|
||||
/// MESSAGE-INTEGRITY attribute, with the message length field adjusted to
|
||||
/// include the MESSAGE-INTEGRITY attribute.
|
||||
pub fn compute_message_integrity(key: &[u8], message_bytes: &[u8]) -> [u8; 20] {
|
||||
hmac_sha1(key, message_bytes)
|
||||
}
|
||||
|
||||
/// Validate the MESSAGE-INTEGRITY attribute of a STUN message.
|
||||
///
|
||||
/// Recomputes the HMAC-SHA1 over the message bytes and compares it
|
||||
/// with the provided MESSAGE-INTEGRITY value in constant time.
|
||||
pub fn validate_message_integrity(
|
||||
message_integrity: &[u8; 20],
|
||||
key: &[u8],
|
||||
message_bytes: &[u8],
|
||||
) -> bool {
|
||||
let expected = compute_message_integrity(key, message_bytes);
|
||||
constant_time_eq(&expected, message_integrity)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Nonce generation / validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate a NONCE that embeds a timestamp for stateless staleness detection.
|
||||
///
|
||||
/// Format: `{hex_timestamp}-{hex_hmac_of_timestamp}`
|
||||
///
|
||||
/// The server can validate the nonce without storing state by recomputing
|
||||
/// the HMAC. The timestamp allows detecting stale nonces.
|
||||
pub fn compute_nonce(timestamp: u64, secret: &[u8]) -> String {
|
||||
let ts_bytes = timestamp.to_be_bytes();
|
||||
let hmac = hmac_sha1(secret, &ts_bytes);
|
||||
// Use first 8 bytes of HMAC for a shorter nonce
|
||||
let hmac_short = &hmac[..8];
|
||||
format!("{}-{}", hex_encode(&ts_bytes), hex_encode(hmac_short))
|
||||
}
|
||||
|
||||
/// Validate a NONCE and check that it hasn't expired.
|
||||
///
|
||||
/// 1. Parse the timestamp from the nonce
|
||||
/// 2. Recompute HMAC and verify it matches
|
||||
/// 3. Check that the nonce age doesn't exceed `max_age_secs`
|
||||
pub fn validate_nonce(
|
||||
nonce: &str,
|
||||
secret: &[u8],
|
||||
max_age_secs: u64,
|
||||
) -> Result<(), TurnError> {
|
||||
let parts: Vec<&str> = nonce.splitn(2, '-').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(TurnError::StaleNonce);
|
||||
}
|
||||
|
||||
let ts_hex = parts[0];
|
||||
let hmac_hex = parts[1];
|
||||
|
||||
// Decode timestamp
|
||||
let ts_bytes = hex_decode(ts_hex).ok_or(TurnError::StaleNonce)?;
|
||||
if ts_bytes.len() != 8 {
|
||||
return Err(TurnError::StaleNonce);
|
||||
}
|
||||
|
||||
let timestamp = u64::from_be_bytes([
|
||||
ts_bytes[0], ts_bytes[1], ts_bytes[2], ts_bytes[3],
|
||||
ts_bytes[4], ts_bytes[5], ts_bytes[6], ts_bytes[7],
|
||||
]);
|
||||
|
||||
// Verify HMAC
|
||||
let expected_hmac = hmac_sha1(secret, &ts_bytes);
|
||||
let expected_hmac_hex = hex_encode(&expected_hmac[..8]);
|
||||
|
||||
if !constant_time_eq(hmac_hex.as_bytes(), expected_hmac_hex.as_bytes()) {
|
||||
return Err(TurnError::StaleNonce);
|
||||
}
|
||||
|
||||
// Check age
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
if now > timestamp + max_age_secs {
|
||||
return Err(TurnError::StaleNonce);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Constant-time comparison
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Compare two byte slices in constant time to prevent timing attacks.
|
||||
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
|
||||
if a.len() != b.len() {
|
||||
return false;
|
||||
}
|
||||
let mut diff = 0u8;
|
||||
for (x, y) in a.iter().zip(b.iter()) {
|
||||
diff |= x ^ y;
|
||||
}
|
||||
diff == 0
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_sha1_empty() {
|
||||
// SHA-1("") = da39a3ee5e6b4b0d3255bfef95601890afd80709
|
||||
let hash = sha1_digest(b"");
|
||||
let hex = hex_encode(&hash);
|
||||
assert_eq!(hex, "da39a3ee5e6b4b0d3255bfef95601890afd80709");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sha1_abc() {
|
||||
// SHA-1("abc") = a9993e364706816aba3e25717850c26c9cd0d89d
|
||||
let hash = sha1_digest(b"abc");
|
||||
let hex = hex_encode(&hash);
|
||||
assert_eq!(hex, "a9993e364706816aba3e25717850c26c9cd0d89d");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_md5_empty() {
|
||||
// MD5("") = d41d8cd98f00b204e9800998ecf8427e
|
||||
let hash = md5_digest(b"");
|
||||
let hex = hex_encode(&hash);
|
||||
assert_eq!(hex, "d41d8cd98f00b204e9800998ecf8427e");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_md5_abc() {
|
||||
// MD5("abc") = 900150983cd24fb0d6963f7d28e17f72
|
||||
let hash = md5_digest(b"abc");
|
||||
let hex = hex_encode(&hash);
|
||||
assert_eq!(hex, "900150983cd24fb0d6963f7d28e17f72");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hmac_sha1_rfc2202_test1() {
|
||||
// RFC 2202 Test Case 1:
|
||||
// Key = 0x0b repeated 20 times
|
||||
// Data = "Hi There"
|
||||
// HMAC = b617318655057264e28bc0b6fb378c8ef146be00
|
||||
let key = [0x0bu8; 20];
|
||||
let data = b"Hi There";
|
||||
let hmac = hmac_sha1(&key, data);
|
||||
let hex = hex_encode(&hmac);
|
||||
assert_eq!(hex, "b617318655057264e28bc0b6fb378c8ef146be00");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hmac_sha1_rfc2202_test2() {
|
||||
// RFC 2202 Test Case 2:
|
||||
// Key = "Jefe"
|
||||
// Data = "what do ya want for nothing?"
|
||||
// HMAC = effcdf6ae5eb2fa2d27416d5f184df9c259a7c79
|
||||
let key = b"Jefe";
|
||||
let data = b"what do ya want for nothing?";
|
||||
let hmac = hmac_sha1(key, data);
|
||||
let hex = hex_encode(&hmac);
|
||||
assert_eq!(hex, "effcdf6ae5eb2fa2d27416d5f184df9c259a7c79");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_encode() {
|
||||
assert_eq!(base64_encode(b""), "");
|
||||
assert_eq!(base64_encode(b"f"), "Zg==");
|
||||
assert_eq!(base64_encode(b"fo"), "Zm8=");
|
||||
assert_eq!(base64_encode(b"foo"), "Zm9v");
|
||||
assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_decode() {
|
||||
assert_eq!(base64_decode("").unwrap(), b"");
|
||||
assert_eq!(base64_decode("Zg==").unwrap(), b"f");
|
||||
assert_eq!(base64_decode("Zm8=").unwrap(), b"fo");
|
||||
assert_eq!(base64_decode("Zm9v").unwrap(), b"foo");
|
||||
assert_eq!(base64_decode("Zm9vYmFy").unwrap(), b"foobar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_base64_roundtrip() {
|
||||
let data = b"Hello, TURN server!";
|
||||
let encoded = base64_encode(data);
|
||||
let decoded = base64_decode(&encoded).unwrap();
|
||||
assert_eq!(&decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_validate_credentials() {
|
||||
let secret = b"supersecretkey";
|
||||
let peer_id = "12D3KooWTestPeerId";
|
||||
let ttl = 3600; // 1 hour
|
||||
|
||||
let (username, password) = generate_credentials(peer_id, secret, ttl);
|
||||
|
||||
// Username should contain the peer_id
|
||||
assert!(username.contains(peer_id));
|
||||
assert!(username.contains(':'));
|
||||
|
||||
// Password should be valid base64
|
||||
assert!(base64_decode(&password).is_some());
|
||||
|
||||
// Validation should succeed
|
||||
let result = validate_credentials(&username, &password, secret);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_credentials_wrong_password() {
|
||||
let secret = b"supersecretkey";
|
||||
let (username, _) = generate_credentials("test", secret, 3600);
|
||||
|
||||
let result = validate_credentials(&username, "wrongpassword", secret);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_credentials_wrong_secret() {
|
||||
let secret = b"supersecretkey";
|
||||
let (username, password) = generate_credentials("test", secret, 3600);
|
||||
|
||||
let result = validate_credentials(&username, &password, b"wrongsecret");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_long_term_key() {
|
||||
// RFC 5389 example-ish: key = MD5("user:realm:pass")
|
||||
let key = compute_long_term_key("user", "realm", "pass");
|
||||
assert_eq!(key.len(), 16);
|
||||
|
||||
// Should be deterministic
|
||||
let key2 = compute_long_term_key("user", "realm", "pass");
|
||||
assert_eq!(key, key2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_integrity_roundtrip() {
|
||||
let key = compute_long_term_key("alice", "duskchat.app", "password123");
|
||||
let message = b"fake stun message bytes for testing";
|
||||
|
||||
let integrity = compute_message_integrity(&key, message);
|
||||
assert!(validate_message_integrity(&integrity, &key, message));
|
||||
|
||||
// Tampered message should fail
|
||||
let tampered = b"tampered stun message bytes for testing";
|
||||
assert!(!validate_message_integrity(&integrity, &key, tampered));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonce_roundtrip() {
|
||||
let secret = b"nonce_secret_key";
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
let nonce = compute_nonce(now, secret);
|
||||
assert!(nonce.contains('-'));
|
||||
|
||||
// Should validate successfully with generous max age
|
||||
let result = validate_nonce(&nonce, secret, 3600);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonce_stale() {
|
||||
let secret = b"nonce_secret_key";
|
||||
// Use a timestamp from long ago
|
||||
let old_timestamp = 1000000;
|
||||
|
||||
let nonce = compute_nonce(old_timestamp, secret);
|
||||
let result = validate_nonce(&nonce, secret, 3600);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonce_tampered() {
|
||||
let secret = b"nonce_secret_key";
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
let nonce = compute_nonce(now, secret);
|
||||
|
||||
// Validate with wrong secret
|
||||
let result = validate_nonce(&nonce, b"wrong_secret", 3600);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hex_roundtrip() {
|
||||
let data = &[0xDE, 0xAD, 0xBE, 0xEF];
|
||||
let encoded = hex_encode(data);
|
||||
assert_eq!(encoded, "deadbeef");
|
||||
let decoded = hex_decode(&encoded).unwrap();
|
||||
assert_eq!(&decoded, data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_time_eq() {
|
||||
assert!(constant_time_eq(b"hello", b"hello"));
|
||||
assert!(!constant_time_eq(b"hello", b"world"));
|
||||
assert!(!constant_time_eq(b"hello", b"hell"));
|
||||
assert!(constant_time_eq(b"", b""));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,102 @@
|
|||
// TURN server error types
|
||||
//
|
||||
// Covers all error conditions that can arise during STUN/TURN message
|
||||
// processing, credential validation, and allocation management.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
/// Comprehensive error type for TURN server operations.
|
||||
///
|
||||
/// Each variant maps to a specific failure condition in the STUN/TURN
|
||||
/// protocol stack, from low-level parse errors to high-level allocation
|
||||
/// policy violations.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TurnError {
|
||||
/// Invalid STUN message format: bad header, truncated message,
|
||||
/// missing magic cookie, or malformed TLV attributes.
|
||||
StunParseError(String),
|
||||
|
||||
/// MESSAGE-INTEGRITY attribute does not match the computed HMAC-SHA1.
|
||||
/// This means either the password is wrong or the message was tampered with.
|
||||
InvalidMessageIntegrity,
|
||||
|
||||
/// The NONCE has expired. The client should retry with a fresh nonce
|
||||
/// from the 438 Stale Nonce error response.
|
||||
StaleNonce,
|
||||
|
||||
/// The client already has an allocation on this 5-tuple, or is
|
||||
/// attempting an operation that conflicts with existing allocation state
|
||||
/// (RFC 5766 §6.2).
|
||||
AllocationMismatch,
|
||||
|
||||
/// The server has reached its per-user or global allocation quota.
|
||||
/// Maps to TURN error code 486.
|
||||
AllocationQuotaReached,
|
||||
|
||||
/// The server cannot fulfill the request due to resource constraints
|
||||
/// (e.g., no relay ports available). Maps to TURN error code 508.
|
||||
InsufficientCapacity,
|
||||
|
||||
/// The request lacks valid credentials, or the credentials have expired.
|
||||
/// Maps to STUN error code 401.
|
||||
Unauthorized,
|
||||
|
||||
/// The peer address in the request is forbidden by server policy
|
||||
/// (e.g., loopback or private IP filtering). Maps to TURN error code 403.
|
||||
ForbiddenIp,
|
||||
|
||||
/// The REQUESTED-TRANSPORT attribute specifies a transport protocol
|
||||
/// that the server does not support (e.g., TCP relay when only UDP
|
||||
/// is available). Maps to TURN error code 442.
|
||||
UnsupportedTransport,
|
||||
|
||||
/// An I/O error occurred on a socket or file operation.
|
||||
IoError(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for TurnError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TurnError::StunParseError(msg) => write!(f, "STUN parse error: {}", msg),
|
||||
TurnError::InvalidMessageIntegrity => write!(f, "invalid MESSAGE-INTEGRITY"),
|
||||
TurnError::StaleNonce => write!(f, "stale nonce"),
|
||||
TurnError::AllocationMismatch => write!(f, "allocation mismatch"),
|
||||
TurnError::AllocationQuotaReached => write!(f, "allocation quota reached"),
|
||||
TurnError::InsufficientCapacity => write!(f, "insufficient capacity"),
|
||||
TurnError::Unauthorized => write!(f, "unauthorized"),
|
||||
TurnError::ForbiddenIp => write!(f, "forbidden IP address"),
|
||||
TurnError::UnsupportedTransport => write!(f, "unsupported transport protocol"),
|
||||
TurnError::IoError(msg) => write!(f, "I/O error: {}", msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for TurnError {}
|
||||
|
||||
impl From<std::io::Error> for TurnError {
|
||||
fn from(err: std::io::Error) -> Self {
|
||||
TurnError::IoError(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Maps a [`TurnError`] to the corresponding STUN/TURN error code
|
||||
/// for use in error response messages.
|
||||
///
|
||||
/// Error codes follow RFC 5389 §15.6 and RFC 5766 §6.
|
||||
impl TurnError {
|
||||
/// Returns the STUN/TURN error code and default reason phrase for this error.
|
||||
pub fn to_error_code(&self) -> (u16, &'static str) {
|
||||
match self {
|
||||
TurnError::StunParseError(_) => (400, "Bad Request"),
|
||||
TurnError::InvalidMessageIntegrity => (401, "Unauthorized"),
|
||||
TurnError::StaleNonce => (438, "Stale Nonce"),
|
||||
TurnError::AllocationMismatch => (437, "Allocation Mismatch"),
|
||||
TurnError::AllocationQuotaReached => (486, "Allocation Quota Reached"),
|
||||
TurnError::InsufficientCapacity => (508, "Insufficient Capacity"),
|
||||
TurnError::Unauthorized => (401, "Unauthorized"),
|
||||
TurnError::ForbiddenIp => (403, "Forbidden"),
|
||||
TurnError::UnsupportedTransport => (442, "Unsupported Transport Protocol"),
|
||||
TurnError::IoError(_) => (500, "Server Error"),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,30 @@
|
|||
// TURN server module for Dusk relay
|
||||
//
|
||||
// Implements RFC 5389 (STUN) and RFC 5766 (TURN) protocol types,
|
||||
// message parsing/serialization, credential management, port allocation,
|
||||
// and the full server with UDP/TCP listeners.
|
||||
//
|
||||
// This module is organized as follows:
|
||||
// - stun: STUN message format, parsing, serialization
|
||||
// - attributes: STUN/TURN attribute types and encoding
|
||||
// - credentials: HMAC-SHA1 time-limited credential generation/validation
|
||||
// - allocation: TURN allocation state machine
|
||||
// - port_pool: Relay port allocation pool
|
||||
// - handler: TURN message handler (request dispatch + response building)
|
||||
// - udp_listener: UDP listener task (receives datagrams, spawns relay receivers)
|
||||
// - tcp_listener: TCP listener task (accepts connections, frames STUN/ChannelData)
|
||||
// - server: Top-level TURN server orchestration (config, startup, handle)
|
||||
// - error: Error types
|
||||
|
||||
pub mod stun;
|
||||
pub mod attributes;
|
||||
pub mod credentials;
|
||||
pub mod error;
|
||||
pub mod port_pool;
|
||||
pub mod allocation;
|
||||
pub mod handler;
|
||||
pub mod udp_listener;
|
||||
pub mod tcp_listener;
|
||||
pub mod server;
|
||||
|
||||
pub use server::{TurnServer, TurnServerConfig, TurnServerHandle};
|
||||
|
|
@ -0,0 +1,254 @@
|
|||
// Relay port allocation pool for TURN allocations
|
||||
//
|
||||
// Manages a pool of UDP port numbers available for TURN relay transport
|
||||
// addresses. Each TURN allocation requires a dedicated relay port.
|
||||
//
|
||||
// The pool pre-shuffles available ports on creation to avoid predictable
|
||||
// allocation patterns. Ports are allocated from the front of the queue
|
||||
// and returned to the back on release.
|
||||
//
|
||||
// Default range: 49152-65535 (IANA dynamic/private port range)
|
||||
// This gives approximately 16,383 relay ports.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
/// Manages a pool of available relay port numbers.
|
||||
///
|
||||
/// Ports are pre-shuffled on construction to avoid predictable patterns.
|
||||
/// Allocation is O(1) from a Vec (pop), release is O(1) (push).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use relay::turn::port_pool::PortPool;
|
||||
///
|
||||
/// let mut pool = PortPool::new(49152, 49160);
|
||||
/// assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive
|
||||
///
|
||||
/// let port = pool.allocate().unwrap();
|
||||
/// assert!(port >= 49152 && port <= 49160);
|
||||
/// assert_eq!(pool.available_count(), 8);
|
||||
///
|
||||
/// pool.release(port);
|
||||
/// assert_eq!(pool.available_count(), 9);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PortPool {
|
||||
/// Ports available for allocation (shuffled on construction).
|
||||
available: Vec<u16>,
|
||||
/// Currently allocated ports (for tracking and preventing double-release).
|
||||
allocated: HashSet<u16>,
|
||||
}
|
||||
|
||||
impl PortPool {
|
||||
/// Create a new port pool with ports in the range `[range_start, range_end]` inclusive.
|
||||
///
|
||||
/// The available ports are shuffled using a simple Fisher-Yates-like shuffle
|
||||
/// seeded from the system clock. For cryptographic randomness, use the
|
||||
/// `rand` crate version below.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `range_start > range_end`.
|
||||
pub fn new(range_start: u16, range_end: u16) -> Self {
|
||||
assert!(
|
||||
range_start <= range_end,
|
||||
"range_start ({}) must be <= range_end ({})",
|
||||
range_start,
|
||||
range_end
|
||||
);
|
||||
|
||||
let mut available: Vec<u16> = (range_start..=range_end).collect();
|
||||
|
||||
// Shuffle using a simple PRNG seeded from system time.
|
||||
// This is sufficient for port randomization (not security-critical).
|
||||
let seed = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos();
|
||||
|
||||
let mut state = seed as u64;
|
||||
let len = available.len();
|
||||
if len > 1 {
|
||||
for i in (1..len).rev() {
|
||||
// Simple xorshift64 PRNG
|
||||
state ^= state << 13;
|
||||
state ^= state >> 7;
|
||||
state ^= state << 17;
|
||||
let j = (state as usize) % (i + 1);
|
||||
available.swap(i, j);
|
||||
}
|
||||
}
|
||||
|
||||
PortPool {
|
||||
available,
|
||||
allocated: HashSet::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocate a port from the pool.
|
||||
///
|
||||
/// Returns `None` if no ports are available.
|
||||
/// Allocated ports are tracked to prevent double-allocation.
|
||||
pub fn allocate(&mut self) -> Option<u16> {
|
||||
let port = self.available.pop()?;
|
||||
self.allocated.insert(port);
|
||||
Some(port)
|
||||
}
|
||||
|
||||
/// Release a previously allocated port back to the pool.
|
||||
///
|
||||
/// If the port was not currently allocated (double-release or unknown port),
|
||||
/// this is a no-op to prevent pool corruption.
|
||||
pub fn release(&mut self, port: u16) {
|
||||
if self.allocated.remove(&port) {
|
||||
self.available.push(port);
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of ports currently available for allocation.
|
||||
pub fn available_count(&self) -> usize {
|
||||
self.available.len()
|
||||
}
|
||||
|
||||
/// Returns the number of ports currently allocated.
|
||||
pub fn allocated_count(&self) -> usize {
|
||||
self.allocated.len()
|
||||
}
|
||||
|
||||
/// Returns the total capacity of the pool (available + allocated).
|
||||
pub fn total_capacity(&self) -> usize {
|
||||
self.available.len() + self.allocated.len()
|
||||
}
|
||||
|
||||
/// Returns `true` if the given port is currently allocated.
|
||||
pub fn is_allocated(&self, port: u16) -> bool {
|
||||
self.allocated.contains(&port)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_new_pool_correct_count() {
|
||||
let pool = PortPool::new(49152, 49160);
|
||||
assert_eq!(pool.available_count(), 9); // 49152..=49160 inclusive
|
||||
assert_eq!(pool.allocated_count(), 0);
|
||||
assert_eq!(pool.total_capacity(), 9);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_single_port_pool() {
|
||||
let mut pool = PortPool::new(5000, 5000);
|
||||
assert_eq!(pool.available_count(), 1);
|
||||
|
||||
let port = pool.allocate().unwrap();
|
||||
assert_eq!(port, 5000);
|
||||
assert_eq!(pool.available_count(), 0);
|
||||
assert!(pool.is_allocated(5000));
|
||||
|
||||
assert!(pool.allocate().is_none());
|
||||
|
||||
pool.release(5000);
|
||||
assert_eq!(pool.available_count(), 1);
|
||||
assert!(!pool.is_allocated(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allocate_all_ports() {
|
||||
let mut pool = PortPool::new(10000, 10004);
|
||||
let mut allocated = Vec::new();
|
||||
|
||||
for _ in 0..5 {
|
||||
let port = pool.allocate().unwrap();
|
||||
assert!((10000..=10004).contains(&port));
|
||||
allocated.push(port);
|
||||
}
|
||||
|
||||
assert_eq!(pool.available_count(), 0);
|
||||
assert_eq!(pool.allocated_count(), 5);
|
||||
assert!(pool.allocate().is_none());
|
||||
|
||||
// All ports should be unique
|
||||
let unique: HashSet<u16> = allocated.iter().copied().collect();
|
||||
assert_eq!(unique.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_release_and_reallocate() {
|
||||
let mut pool = PortPool::new(8000, 8002);
|
||||
|
||||
let p1 = pool.allocate().unwrap();
|
||||
let p2 = pool.allocate().unwrap();
|
||||
let p3 = pool.allocate().unwrap();
|
||||
assert!(pool.allocate().is_none());
|
||||
|
||||
pool.release(p2);
|
||||
assert_eq!(pool.available_count(), 1);
|
||||
|
||||
let p4 = pool.allocate().unwrap();
|
||||
assert_eq!(p4, p2); // released port gets reused
|
||||
assert!(pool.allocate().is_none());
|
||||
|
||||
pool.release(p1);
|
||||
pool.release(p3);
|
||||
pool.release(p4);
|
||||
assert_eq!(pool.available_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_double_release_is_noop() {
|
||||
let mut pool = PortPool::new(7000, 7002);
|
||||
|
||||
let port = pool.allocate().unwrap();
|
||||
pool.release(port);
|
||||
assert_eq!(pool.available_count(), 3);
|
||||
|
||||
// Double release should not add a duplicate
|
||||
pool.release(port);
|
||||
assert_eq!(pool.available_count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_release_unknown_port_is_noop() {
|
||||
let mut pool = PortPool::new(7000, 7002);
|
||||
|
||||
// Release a port that was never allocated
|
||||
pool.release(9999);
|
||||
assert_eq!(pool.available_count(), 3);
|
||||
assert_eq!(pool.allocated_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "range_start")]
|
||||
fn test_invalid_range_panics() {
|
||||
let _ = PortPool::new(100, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_large_pool() {
|
||||
let pool = PortPool::new(49152, 65535);
|
||||
assert_eq!(pool.available_count(), 16384);
|
||||
assert_eq!(pool.total_capacity(), 16384);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_allocated() {
|
||||
let mut pool = PortPool::new(6000, 6005);
|
||||
|
||||
let port = pool.allocate().unwrap();
|
||||
assert!(pool.is_allocated(port));
|
||||
|
||||
pool.release(port);
|
||||
assert!(!pool.is_allocated(port));
|
||||
|
||||
// Never-allocated port
|
||||
assert!(!pool.is_allocated(9999));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,377 @@
|
|||
// TURN server entry point
|
||||
//
|
||||
// Ties together the TurnHandler, AllocationManager, PortPool, and
|
||||
// UDP/TCP listeners into a clean, configurable server.
|
||||
//
|
||||
// Configuration can be loaded from environment variables via
|
||||
// `TurnServerConfig::from_env()` or constructed programmatically.
|
||||
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::allocation::{AllocationConfig, AllocationManager};
|
||||
use super::handler::TurnHandler;
|
||||
use super::port_pool::PortPool;
|
||||
use super::tcp_listener::TcpTurnListener;
|
||||
use super::udp_listener::UdpTurnListener;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TurnServerConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Configuration for the TURN server.
|
||||
///
|
||||
/// All fields have sensible defaults. The only required field for production
|
||||
/// use is `public_ip` — without it, relay addresses will use 0.0.0.0 which
|
||||
/// won't work for external peers.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TurnServerConfig {
|
||||
/// UDP listen address (default: 0.0.0.0:3478).
|
||||
pub udp_addr: SocketAddr,
|
||||
/// TCP listen address (default: 0.0.0.0:3478).
|
||||
pub tcp_addr: SocketAddr,
|
||||
/// Public IP address for XOR-RELAYED-ADDRESS (required for production).
|
||||
pub public_ip: IpAddr,
|
||||
/// Shared secret for HMAC credential generation.
|
||||
pub shared_secret: Vec<u8>,
|
||||
/// Authentication realm (default: "duskchat.app").
|
||||
pub realm: String,
|
||||
/// Relay port range start (default: 49152).
|
||||
pub relay_port_start: u16,
|
||||
/// Relay port range end (default: 65535).
|
||||
pub relay_port_end: u16,
|
||||
/// Allocation configuration (lifetimes, quotas, etc.).
|
||||
pub allocation_config: AllocationConfig,
|
||||
}
|
||||
|
||||
impl Default for TurnServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478),
|
||||
tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 3478),
|
||||
public_ip: IpAddr::V4(Ipv4Addr::UNSPECIFIED),
|
||||
shared_secret: Vec::new(),
|
||||
realm: "duskchat.app".to_string(),
|
||||
relay_port_start: 49152,
|
||||
relay_port_end: 65535,
|
||||
allocation_config: AllocationConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TurnServerConfig {
|
||||
/// Create a configuration from environment variables with sensible defaults.
|
||||
///
|
||||
/// Supported environment variables:
|
||||
/// - `DUSK_TURN_UDP_PORT` — UDP listen port (default: 3478)
|
||||
/// - `DUSK_TURN_TCP_PORT` — TCP listen port (default: 3478)
|
||||
/// - `DUSK_TURN_PUBLIC_IP` — Public IP for relay addresses (required in prod)
|
||||
/// - `DUSK_TURN_SECRET` — Shared secret for credentials (auto-generated if unset)
|
||||
/// - `DUSK_TURN_REALM` — Authentication realm (default: "duskchat.app")
|
||||
/// - `DUSK_TURN_PORT_RANGE_START` — Relay port range start (default: 49152)
|
||||
/// - `DUSK_TURN_PORT_RANGE_END` — Relay port range end (default: 65535)
|
||||
/// - `DUSK_TURN_MAX_ALLOCATIONS` — Global allocation limit (default: 1000)
|
||||
/// - `DUSK_TURN_MAX_PER_USER` — Per-user allocation limit (default: 10)
|
||||
pub fn from_env() -> Self {
|
||||
let udp_port: u16 = std::env::var("DUSK_TURN_UDP_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(3478);
|
||||
|
||||
let tcp_port: u16 = std::env::var("DUSK_TURN_TCP_PORT")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(3478);
|
||||
|
||||
let public_ip: IpAddr = std::env::var("DUSK_TURN_PUBLIC_IP")
|
||||
.ok()
|
||||
.and_then(|ip| ip.parse().ok())
|
||||
.unwrap_or_else(|| {
|
||||
log::warn!(
|
||||
"[TURN] DUSK_TURN_PUBLIC_IP not set — relay addresses will use 0.0.0.0. \
|
||||
Set this to the server's public IP for production use."
|
||||
);
|
||||
IpAddr::V4(Ipv4Addr::UNSPECIFIED)
|
||||
});
|
||||
|
||||
let shared_secret = std::env::var("DUSK_TURN_SECRET")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(|s| s.into_bytes())
|
||||
.unwrap_or_else(|| {
|
||||
let secret = generate_random_bytes(32);
|
||||
log::warn!(
|
||||
"[TURN] DUSK_TURN_SECRET not set — using random secret. \
|
||||
This won't work across multiple relay instances."
|
||||
);
|
||||
secret
|
||||
});
|
||||
|
||||
let realm = std::env::var("DUSK_TURN_REALM")
|
||||
.ok()
|
||||
.filter(|r| !r.is_empty())
|
||||
.unwrap_or_else(|| "duskchat.app".to_string());
|
||||
|
||||
let port_range_start: u16 = std::env::var("DUSK_TURN_PORT_RANGE_START")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(49152);
|
||||
|
||||
let port_range_end: u16 = std::env::var("DUSK_TURN_PORT_RANGE_END")
|
||||
.ok()
|
||||
.and_then(|p| p.parse().ok())
|
||||
.unwrap_or(65535);
|
||||
|
||||
let max_allocations: usize = std::env::var("DUSK_TURN_MAX_ALLOCATIONS")
|
||||
.ok()
|
||||
.and_then(|n| n.parse().ok())
|
||||
.unwrap_or(1000);
|
||||
|
||||
let max_per_user: usize = std::env::var("DUSK_TURN_MAX_PER_USER")
|
||||
.ok()
|
||||
.and_then(|n| n.parse().ok())
|
||||
.unwrap_or(10);
|
||||
|
||||
let allocation_config = AllocationConfig {
|
||||
max_allocations,
|
||||
max_per_user,
|
||||
realm: realm.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Self {
|
||||
udp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), udp_port),
|
||||
tcp_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), tcp_port),
|
||||
public_ip,
|
||||
shared_secret,
|
||||
realm,
|
||||
relay_port_start: port_range_start,
|
||||
relay_port_end: port_range_end,
|
||||
allocation_config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the TURN server is enabled via environment variable.
|
||||
///
|
||||
/// Returns `false` only if `DUSK_TURN_ENABLED=false`. Defaults to `true`.
|
||||
pub fn is_enabled() -> bool {
|
||||
std::env::var("DUSK_TURN_ENABLED")
|
||||
.map(|v| v != "false" && v != "0")
|
||||
.unwrap_or(true)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TurnServer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// The TURN server. Owns the configuration and provides a [`run`](TurnServer::run)
|
||||
/// method that starts all listener tasks and returns a [`TurnServerHandle`].
|
||||
pub struct TurnServer {
|
||||
config: TurnServerConfig,
|
||||
}
|
||||
|
||||
impl TurnServer {
|
||||
/// Create a new TURN server with the given configuration.
|
||||
pub fn new(config: TurnServerConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Start the TURN server.
|
||||
///
|
||||
/// Binds UDP and TCP sockets, creates the handler and allocation manager,
|
||||
/// spawns listener and cleanup tasks, and returns a handle for monitoring.
|
||||
pub async fn run(self) -> Result<TurnServerHandle, Box<dyn std::error::Error>> {
|
||||
let port_pool = Arc::new(tokio::sync::Mutex::new(PortPool::new(
|
||||
self.config.relay_port_start,
|
||||
self.config.relay_port_end,
|
||||
)));
|
||||
|
||||
let alloc_mgr = Arc::new(AllocationManager::new(
|
||||
self.config.allocation_config.clone(),
|
||||
));
|
||||
|
||||
// Generate nonce secret (separate from shared secret)
|
||||
let nonce_secret = generate_random_bytes(32);
|
||||
|
||||
let handler = Arc::new(TurnHandler::new(
|
||||
Arc::clone(&alloc_mgr),
|
||||
Arc::clone(&port_pool),
|
||||
self.config.shared_secret.clone(),
|
||||
self.config.realm.clone(),
|
||||
nonce_secret,
|
||||
self.config.public_ip,
|
||||
));
|
||||
|
||||
// Bind UDP socket
|
||||
let udp_socket = Arc::new(
|
||||
tokio::net::UdpSocket::bind(self.config.udp_addr).await?,
|
||||
);
|
||||
log::info!(
|
||||
"[TURN] UDP listening on {}",
|
||||
udp_socket.local_addr().unwrap_or(self.config.udp_addr)
|
||||
);
|
||||
|
||||
let udp_listener = Arc::new(UdpTurnListener::new(
|
||||
Arc::clone(&udp_socket),
|
||||
Arc::clone(&handler),
|
||||
Arc::clone(&alloc_mgr),
|
||||
self.config.udp_addr,
|
||||
self.config.public_ip,
|
||||
));
|
||||
|
||||
// Spawn UDP listener
|
||||
let udp_handle = tokio::spawn({
|
||||
let listener = Arc::clone(&udp_listener);
|
||||
async move {
|
||||
listener.run().await;
|
||||
}
|
||||
});
|
||||
|
||||
// Bind TCP listener
|
||||
let mut tcp_listener = TcpTurnListener::bind(
|
||||
self.config.tcp_addr,
|
||||
Arc::clone(&handler),
|
||||
Arc::clone(&alloc_mgr),
|
||||
self.config.public_ip,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Give the TCP listener access to the UDP socket for relay receiver tasks
|
||||
tcp_listener.set_udp_socket(Arc::clone(&udp_socket));
|
||||
|
||||
log::info!(
|
||||
"[TURN] TCP listening on {}",
|
||||
self.config.tcp_addr
|
||||
);
|
||||
|
||||
// Spawn TCP listener
|
||||
let tcp_handle = tokio::spawn(async move {
|
||||
tcp_listener.run().await;
|
||||
});
|
||||
|
||||
// Spawn cleanup task (runs every 30 seconds)
|
||||
let cleanup_alloc_mgr = Arc::clone(&alloc_mgr);
|
||||
let cleanup_port_pool = Arc::clone(&port_pool);
|
||||
let cleanup_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||
let freed_ports = cleanup_alloc_mgr.cleanup_expired().await;
|
||||
if !freed_ports.is_empty() {
|
||||
let mut pool = cleanup_port_pool.lock().await;
|
||||
for port in &freed_ports {
|
||||
pool.release(*port);
|
||||
}
|
||||
log::info!(
|
||||
"[TURN] cleaned up {} expired allocations",
|
||||
freed_ports.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
log::info!(
|
||||
"[TURN] server started (public_ip={}, realm={}, relay_ports={}-{})",
|
||||
self.config.public_ip,
|
||||
self.config.realm,
|
||||
self.config.relay_port_start,
|
||||
self.config.relay_port_end,
|
||||
);
|
||||
|
||||
Ok(TurnServerHandle {
|
||||
udp_handle,
|
||||
tcp_handle,
|
||||
cleanup_handle,
|
||||
handler,
|
||||
alloc_mgr,
|
||||
port_pool,
|
||||
shared_secret: self.config.shared_secret,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the shared secret (for credential generation by the libp2p protocol).
|
||||
pub fn shared_secret(&self) -> &[u8] {
|
||||
&self.config.shared_secret
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TurnServerHandle
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Handle to a running TURN server.
|
||||
///
|
||||
/// Provides access to the server's shared state and credential generation.
|
||||
/// The server runs in the background via spawned tasks; dropping the handle
|
||||
/// does NOT stop the server (the tasks keep running).
|
||||
pub struct TurnServerHandle {
|
||||
/// UDP listener task handle.
|
||||
pub udp_handle: tokio::task::JoinHandle<()>,
|
||||
/// TCP listener task handle.
|
||||
pub tcp_handle: tokio::task::JoinHandle<()>,
|
||||
/// Cleanup task handle.
|
||||
pub cleanup_handle: tokio::task::JoinHandle<()>,
|
||||
/// The shared TURN message handler.
|
||||
pub handler: Arc<TurnHandler>,
|
||||
/// The shared allocation manager.
|
||||
pub alloc_mgr: Arc<AllocationManager>,
|
||||
/// The shared port pool.
|
||||
pub port_pool: Arc<tokio::sync::Mutex<PortPool>>,
|
||||
/// The shared secret (for credential generation).
|
||||
shared_secret: Vec<u8>,
|
||||
}
|
||||
|
||||
impl TurnServerHandle {
|
||||
/// Generate time-limited TURN credentials for a peer.
|
||||
///
|
||||
/// Returns `(username, password)` suitable for use with WebRTC ICE servers.
|
||||
/// The credentials expire after `ttl_secs` seconds.
|
||||
pub fn generate_credentials(&self, peer_id: &str, ttl_secs: u64) -> (String, String) {
|
||||
super::credentials::generate_credentials(peer_id, &self.shared_secret, ttl_secs)
|
||||
}
|
||||
|
||||
/// Get the current number of active allocations.
|
||||
pub async fn allocation_count(&self) -> usize {
|
||||
self.alloc_mgr.allocation_count().await
|
||||
}
|
||||
|
||||
/// Get the shared secret bytes (for external credential generation).
|
||||
pub fn shared_secret(&self) -> &[u8] {
|
||||
&self.shared_secret
|
||||
}
|
||||
|
||||
/// Abort all server tasks. The server will stop after current in-flight
|
||||
/// operations complete.
|
||||
pub fn shutdown(&self) {
|
||||
self.udp_handle.abort();
|
||||
self.tcp_handle.abort();
|
||||
self.cleanup_handle.abort();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Utility: random byte generation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Generate `len` pseudo-random bytes using a simple xorshift64 PRNG
|
||||
/// seeded from system time.
|
||||
///
|
||||
/// This is NOT cryptographically secure. It's used for nonce secrets
|
||||
/// and auto-generated shared secrets when none is configured. In production,
|
||||
/// configure `DUSK_TURN_SECRET` explicitly.
|
||||
fn generate_random_bytes(len: usize) -> Vec<u8> {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let seed = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_nanos() as u64;
|
||||
let mut state = seed;
|
||||
let mut bytes = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
state ^= state << 13;
|
||||
state ^= state >> 7;
|
||||
state ^= state << 17;
|
||||
bytes.push(state as u8);
|
||||
}
|
||||
bytes
|
||||
}
|
||||
|
|
@ -0,0 +1,799 @@
|
|||
// STUN message parser and serializer per RFC 5389
|
||||
//
|
||||
// Implements the STUN message format with support for TURN methods (RFC 5766).
|
||||
// Also handles ChannelData framing for TURN channel bindings.
|
||||
//
|
||||
// STUN message header layout (20 bytes):
|
||||
// Bytes 0-1: Message Type (method + class encoding)
|
||||
// Bytes 2-3: Message Length (excludes 20-byte header)
|
||||
// Bytes 4-7: Magic Cookie (0x2112A442)
|
||||
// Bytes 8-19: Transaction ID (96 bits)
|
||||
//
|
||||
// Message Type encoding (RFC 5389 §6):
|
||||
// Bits: 13 12 11 10 9 8 7 6 5 4 3 2 1 0
|
||||
// M11 M10 M9 M8 M7 C1 M6 M5 M4 C0 M3 M2 M1 M0
|
||||
// Where M0-M11 = method bits, C0-C1 = class bits.
|
||||
|
||||
use crate::turn::attributes::StunAttribute;
|
||||
use crate::turn::error::TurnError;
|
||||
|
||||
/// STUN magic cookie value (RFC 5389 §6).
|
||||
pub const MAGIC_COOKIE: u32 = 0x2112A442;
|
||||
|
||||
/// STUN header size in bytes.
|
||||
pub const STUN_HEADER_SIZE: usize = 20;
|
||||
|
||||
/// XOR value for FINGERPRINT CRC32 (RFC 5389 §15.5).
|
||||
pub const FINGERPRINT_XOR: u32 = 0x5354554e;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Method
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// STUN/TURN method identifiers.
|
||||
///
|
||||
/// Methods 0x0001 (Binding) are defined in RFC 5389; the rest are
|
||||
/// TURN extensions from RFC 5766.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Method {
|
||||
/// STUN Binding (0x0001) — NAT discovery
|
||||
Binding,
|
||||
/// TURN Allocate (0x0003) — create relay allocation
|
||||
Allocate,
|
||||
/// TURN Refresh (0x0004) — refresh allocation lifetime
|
||||
Refresh,
|
||||
/// TURN Send (0x0006) — send data indication to peer
|
||||
Send,
|
||||
/// TURN Data (0x0007) — data indication from peer
|
||||
Data,
|
||||
/// TURN CreatePermission (0x0008) — install relay permission
|
||||
CreatePermission,
|
||||
/// TURN ChannelBind (0x0009) — bind channel number to peer
|
||||
ChannelBind,
|
||||
}
|
||||
|
||||
impl Method {
|
||||
/// Returns the 12-bit method number (M0-M11).
|
||||
pub fn as_u16(self) -> u16 {
|
||||
match self {
|
||||
Method::Binding => 0x0001,
|
||||
Method::Allocate => 0x0003,
|
||||
Method::Refresh => 0x0004,
|
||||
Method::Send => 0x0006,
|
||||
Method::Data => 0x0007,
|
||||
Method::CreatePermission => 0x0008,
|
||||
Method::ChannelBind => 0x0009,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses a 12-bit method number into a [`Method`].
|
||||
pub fn from_u16(val: u16) -> Result<Self, TurnError> {
|
||||
match val {
|
||||
0x0001 => Ok(Method::Binding),
|
||||
0x0003 => Ok(Method::Allocate),
|
||||
0x0004 => Ok(Method::Refresh),
|
||||
0x0006 => Ok(Method::Send),
|
||||
0x0007 => Ok(Method::Data),
|
||||
0x0008 => Ok(Method::CreatePermission),
|
||||
0x0009 => Ok(Method::ChannelBind),
|
||||
_ => Err(TurnError::StunParseError(format!(
|
||||
"unknown STUN method: 0x{:04x}",
|
||||
val
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Class
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// STUN message class (RFC 5389 §6).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Class {
|
||||
/// Request (C0=0, C1=0)
|
||||
Request,
|
||||
/// Indication (C0=1, C1=0)
|
||||
Indication,
|
||||
/// Success Response (C0=0, C1=1)
|
||||
SuccessResponse,
|
||||
/// Error Response (C0=1, C1=1)
|
||||
ErrorResponse,
|
||||
}
|
||||
|
||||
impl Class {
|
||||
/// Returns the 2-bit class value (C0 in bit 0, C1 in bit 1).
|
||||
pub fn as_u8(self) -> u8 {
|
||||
match self {
|
||||
Class::Request => 0b00,
|
||||
Class::Indication => 0b01,
|
||||
Class::SuccessResponse => 0b10,
|
||||
Class::ErrorResponse => 0b11,
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses a 2-bit class value.
|
||||
pub fn from_u8(val: u8) -> Result<Self, TurnError> {
|
||||
match val & 0x03 {
|
||||
0b00 => Ok(Class::Request),
|
||||
0b01 => Ok(Class::Indication),
|
||||
0b10 => Ok(Class::SuccessResponse),
|
||||
0b11 => Ok(Class::ErrorResponse),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MessageType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Combined STUN message type (method + class).
|
||||
///
|
||||
/// The 14-bit message type field encodes both the method and class with
|
||||
/// an interleaved bit layout per RFC 5389 §6.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub struct MessageType {
|
||||
pub method: Method,
|
||||
pub class: Class,
|
||||
}
|
||||
|
||||
impl MessageType {
|
||||
pub fn new(method: Method, class: Class) -> Self {
|
||||
Self { method, class }
|
||||
}
|
||||
|
||||
/// Encode method + class into the 14-bit STUN message type field.
|
||||
///
|
||||
/// Bit layout of the 16-bit type field (top 2 bits always 0 for STUN):
|
||||
/// ```text
|
||||
/// 15 14 | 13 12 11 10 9 | 8 | 7 6 5 | 4 | 3 2 1 0
|
||||
/// 0 0 | M11 M10 M9 M8 M7 | C1 | M6 M5 M4 | C0 | M3 M2 M1 M0
|
||||
/// ```
|
||||
pub fn to_u16(self) -> u16 {
|
||||
let m = self.method.as_u16();
|
||||
let c = self.class.as_u8() as u16;
|
||||
|
||||
// Extract method bit groups
|
||||
let m0_3 = m & 0x000F; // bits 0-3
|
||||
let m4_6 = (m >> 4) & 0x0007; // bits 4-6
|
||||
let m7_11 = (m >> 7) & 0x001F; // bits 7-11
|
||||
|
||||
// Extract class bits
|
||||
let c0 = c & 0x01;
|
||||
let c1 = (c >> 1) & 0x01;
|
||||
|
||||
// Assemble: M0-M3 in bits 0-3, C0 in bit 4, M4-M6 in bits 5-7,
|
||||
// C1 in bit 8, M7-M11 in bits 9-13
|
||||
m0_3 | (c0 << 4) | (m4_6 << 5) | (c1 << 8) | (m7_11 << 9)
|
||||
}
|
||||
|
||||
/// Decode the 14-bit STUN message type field into method + class.
|
||||
pub fn from_u16(val: u16) -> Result<Self, TurnError> {
|
||||
// Top 2 bits must be 00 for STUN messages
|
||||
if val & 0xC000 != 0 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"top 2 bits of message type must be 00 for STUN".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Extract class bits
|
||||
let c0 = (val >> 4) & 0x01;
|
||||
let c1 = (val >> 8) & 0x01;
|
||||
let class_bits = (c0 | (c1 << 1)) as u8;
|
||||
|
||||
// Extract method bits
|
||||
let m0_3 = val & 0x000F;
|
||||
let m4_6 = (val >> 5) & 0x0007;
|
||||
let m7_11 = (val >> 9) & 0x001F;
|
||||
let method_bits = m0_3 | (m4_6 << 4) | (m7_11 << 7);
|
||||
|
||||
Ok(MessageType {
|
||||
method: Method::from_u16(method_bits)?,
|
||||
class: Class::from_u8(class_bits)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// StunMessage
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A parsed STUN message with header fields and attributes.
|
||||
///
|
||||
/// The wire format is a 20-byte header followed by zero or more TLV-encoded
|
||||
/// attributes. The message length field in the header covers only the
|
||||
/// attribute portion (not the 20-byte header itself).
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StunMessage {
|
||||
pub msg_type: MessageType,
|
||||
pub transaction_id: [u8; 12],
|
||||
pub attributes: Vec<StunAttribute>,
|
||||
}
|
||||
|
||||
impl StunMessage {
|
||||
/// Create a new STUN message with the given type and transaction ID.
|
||||
pub fn new(msg_type: MessageType, transaction_id: [u8; 12]) -> Self {
|
||||
Self {
|
||||
msg_type,
|
||||
transaction_id,
|
||||
attributes: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new STUN message with a random transaction ID.
|
||||
pub fn new_random(msg_type: MessageType) -> Self {
|
||||
let mut transaction_id = [0u8; 12];
|
||||
// Use simple random fill; callers with `rand` can use proper RNG
|
||||
#[cfg(feature = "rand")]
|
||||
{
|
||||
use rand::RngCore;
|
||||
rand::thread_rng().fill_bytes(&mut transaction_id);
|
||||
}
|
||||
#[cfg(not(feature = "rand"))]
|
||||
{
|
||||
// Fallback: use std time-based entropy (not cryptographically secure,
|
||||
// but sufficient for transaction IDs in development/testing)
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default();
|
||||
let seed = now.as_nanos();
|
||||
for (i, byte) in transaction_id.iter_mut().enumerate() {
|
||||
*byte = ((seed >> (i * 5)) & 0xFF) as u8;
|
||||
}
|
||||
}
|
||||
Self {
|
||||
msg_type,
|
||||
transaction_id,
|
||||
attributes: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an attribute to the message.
|
||||
pub fn add_attribute(&mut self, attr: StunAttribute) {
|
||||
self.attributes.push(attr);
|
||||
}
|
||||
|
||||
/// Decode a STUN message from raw bytes.
|
||||
///
|
||||
/// Validates the magic cookie and parses the header and all TLV attributes.
|
||||
/// Returns an error if the message is truncated, has an invalid cookie,
|
||||
/// or contains malformed attributes.
|
||||
pub fn decode(bytes: &[u8]) -> Result<Self, TurnError> {
|
||||
if bytes.len() < STUN_HEADER_SIZE {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"message too short: {} bytes, need at least {}",
|
||||
bytes.len(),
|
||||
STUN_HEADER_SIZE
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse message type (first 2 bytes)
|
||||
let type_val = u16::from_be_bytes([bytes[0], bytes[1]]);
|
||||
let msg_type = MessageType::from_u16(type_val)?;
|
||||
|
||||
// Parse message length (bytes 2-3) — length of attributes only
|
||||
let msg_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
|
||||
|
||||
// Validate magic cookie (bytes 4-7)
|
||||
let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
|
||||
if cookie != MAGIC_COOKIE {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"invalid magic cookie: 0x{:08x}, expected 0x{:08x}",
|
||||
cookie, MAGIC_COOKIE
|
||||
)));
|
||||
}
|
||||
|
||||
// Extract transaction ID (bytes 8-19)
|
||||
let mut transaction_id = [0u8; 12];
|
||||
transaction_id.copy_from_slice(&bytes[8..20]);
|
||||
|
||||
// Validate total length
|
||||
let total_expected = STUN_HEADER_SIZE + msg_length;
|
||||
if bytes.len() < total_expected {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"message truncated: have {} bytes, header says {}",
|
||||
bytes.len(),
|
||||
total_expected
|
||||
)));
|
||||
}
|
||||
|
||||
// Parse attributes from the body
|
||||
let attr_bytes = &bytes[STUN_HEADER_SIZE..total_expected];
|
||||
let attributes = Self::decode_attributes(attr_bytes, &transaction_id)?;
|
||||
|
||||
Ok(StunMessage {
|
||||
msg_type,
|
||||
transaction_id,
|
||||
attributes,
|
||||
})
|
||||
}
|
||||
|
||||
/// Parse the TLV attribute list from the message body.
|
||||
fn decode_attributes(
|
||||
mut data: &[u8],
|
||||
transaction_id: &[u8; 12],
|
||||
) -> Result<Vec<StunAttribute>, TurnError> {
|
||||
let mut attributes = Vec::new();
|
||||
|
||||
while data.len() >= 4 {
|
||||
// Each attribute: type (2 bytes) + length (2 bytes) + value + padding
|
||||
let attr_type = u16::from_be_bytes([data[0], data[1]]);
|
||||
let attr_len = u16::from_be_bytes([data[2], data[3]]) as usize;
|
||||
|
||||
if data.len() < 4 + attr_len {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"attribute 0x{:04x} truncated: need {} bytes, have {}",
|
||||
attr_type,
|
||||
attr_len,
|
||||
data.len() - 4
|
||||
)));
|
||||
}
|
||||
|
||||
let attr_value = &data[4..4 + attr_len];
|
||||
let attr = crate::turn::attributes::decode_attribute(
|
||||
attr_type,
|
||||
attr_value,
|
||||
transaction_id,
|
||||
)?;
|
||||
attributes.push(attr);
|
||||
|
||||
// Advance past value + padding to 4-byte boundary
|
||||
let padded_len = (attr_len + 3) & !3;
|
||||
let total_attr_size = 4 + padded_len;
|
||||
if total_attr_size > data.len() {
|
||||
break;
|
||||
}
|
||||
data = &data[total_attr_size..];
|
||||
}
|
||||
|
||||
Ok(attributes)
|
||||
}
|
||||
|
||||
/// Encode this STUN message into wire format bytes.
|
||||
///
|
||||
/// The message length field is computed from the encoded attributes.
|
||||
/// Attributes are TLV-encoded with 4-byte padding.
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
// Encode all attributes first to determine total length
|
||||
let mut attr_bytes = Vec::new();
|
||||
for attr in &self.attributes {
|
||||
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
|
||||
attr_bytes.extend_from_slice(&encoded);
|
||||
}
|
||||
|
||||
let msg_length = attr_bytes.len() as u16;
|
||||
|
||||
// Build the 20-byte header
|
||||
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
|
||||
|
||||
// Message type (2 bytes)
|
||||
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
|
||||
// Message length (2 bytes)
|
||||
buf.extend_from_slice(&msg_length.to_be_bytes());
|
||||
// Magic cookie (4 bytes)
|
||||
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
|
||||
// Transaction ID (12 bytes)
|
||||
buf.extend_from_slice(&self.transaction_id);
|
||||
// Attributes
|
||||
buf.extend_from_slice(&attr_bytes);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode the message, computing the correct length field as if a
|
||||
/// MESSAGE-INTEGRITY attribute of 24 bytes (4-byte TLV header + 20-byte HMAC)
|
||||
/// will be appended. This is used to generate the bytes over which
|
||||
/// MESSAGE-INTEGRITY is computed per RFC 5389 §15.4.
|
||||
///
|
||||
/// The returned bytes do NOT include the MESSAGE-INTEGRITY attribute itself.
|
||||
pub fn encode_for_integrity(&self) -> Vec<u8> {
|
||||
// Encode attributes up to (but not including) MESSAGE-INTEGRITY
|
||||
let mut attr_bytes = Vec::new();
|
||||
for attr in &self.attributes {
|
||||
if matches!(attr, StunAttribute::MessageIntegrity(_)) {
|
||||
break;
|
||||
}
|
||||
if matches!(attr, StunAttribute::Fingerprint(_)) {
|
||||
break;
|
||||
}
|
||||
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
|
||||
attr_bytes.extend_from_slice(&encoded);
|
||||
}
|
||||
|
||||
// The length field must include the MESSAGE-INTEGRITY attribute that will follow
|
||||
// (4 bytes TLV header + 20 bytes HMAC = 24 bytes)
|
||||
let msg_length = (attr_bytes.len() + 24) as u16;
|
||||
|
||||
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
|
||||
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&msg_length.to_be_bytes());
|
||||
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
|
||||
buf.extend_from_slice(&self.transaction_id);
|
||||
buf.extend_from_slice(&attr_bytes);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Encode the message for FINGERPRINT computation per RFC 5389 §15.5.
|
||||
///
|
||||
/// Returns all bytes up to (but not including) the FINGERPRINT attribute,
|
||||
/// with the length field adjusted to include the FINGERPRINT (8 bytes).
|
||||
pub fn encode_for_fingerprint(&self) -> Vec<u8> {
|
||||
// Encode all attributes except FINGERPRINT
|
||||
let mut attr_bytes = Vec::new();
|
||||
for attr in &self.attributes {
|
||||
if matches!(attr, StunAttribute::Fingerprint(_)) {
|
||||
break;
|
||||
}
|
||||
let encoded = crate::turn::attributes::encode_attribute(attr, &self.transaction_id);
|
||||
attr_bytes.extend_from_slice(&encoded);
|
||||
}
|
||||
|
||||
// Length includes the FINGERPRINT attribute (4 header + 4 value = 8 bytes)
|
||||
let msg_length = (attr_bytes.len() + 8) as u16;
|
||||
|
||||
let mut buf = Vec::with_capacity(STUN_HEADER_SIZE + attr_bytes.len());
|
||||
buf.extend_from_slice(&self.msg_type.to_u16().to_be_bytes());
|
||||
buf.extend_from_slice(&msg_length.to_be_bytes());
|
||||
buf.extend_from_slice(&MAGIC_COOKIE.to_be_bytes());
|
||||
buf.extend_from_slice(&self.transaction_id);
|
||||
buf.extend_from_slice(&attr_bytes);
|
||||
|
||||
buf
|
||||
}
|
||||
|
||||
/// Find the first attribute of a given type in the message.
|
||||
pub fn get_attribute(&self, predicate: impl Fn(&StunAttribute) -> bool) -> Option<&StunAttribute> {
|
||||
self.attributes.iter().find(|a| predicate(a))
|
||||
}
|
||||
|
||||
/// Get the USERNAME attribute value, if present.
|
||||
pub fn get_username(&self) -> Option<&str> {
|
||||
for attr in &self.attributes {
|
||||
if let StunAttribute::Username(ref u) = attr {
|
||||
return Some(u.as_str());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the REALM attribute value, if present.
|
||||
pub fn get_realm(&self) -> Option<&str> {
|
||||
for attr in &self.attributes {
|
||||
if let StunAttribute::Realm(ref r) = attr {
|
||||
return Some(r.as_str());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the NONCE attribute value, if present.
|
||||
pub fn get_nonce(&self) -> Option<&str> {
|
||||
for attr in &self.attributes {
|
||||
if let StunAttribute::Nonce(ref n) = attr {
|
||||
return Some(n.as_str());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Get the MESSAGE-INTEGRITY attribute value, if present.
|
||||
pub fn get_message_integrity(&self) -> Option<&[u8; 20]> {
|
||||
for attr in &self.attributes {
|
||||
if let StunAttribute::MessageIntegrity(ref mi) = attr {
|
||||
return Some(mi);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ChannelData
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// ChannelData message for TURN channel bindings (RFC 5766 §11.4).
|
||||
///
|
||||
/// ChannelData uses a compact 4-byte header instead of the STUN format:
|
||||
/// Bytes 0-1: Channel Number (0x4000-0x7FFF)
|
||||
/// Bytes 2-3: Data Length
|
||||
/// Bytes 4+: Application Data (padded to 4 bytes over UDP)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChannelData {
|
||||
pub channel_number: u16,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ChannelData {
|
||||
/// Decode a ChannelData message from raw bytes.
|
||||
pub fn decode(bytes: &[u8]) -> Result<Self, TurnError> {
|
||||
if bytes.len() < 4 {
|
||||
return Err(TurnError::StunParseError(
|
||||
"ChannelData message too short".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let channel_number = u16::from_be_bytes([bytes[0], bytes[1]]);
|
||||
let data_length = u16::from_be_bytes([bytes[2], bytes[3]]) as usize;
|
||||
|
||||
// Channel numbers must be in range 0x4000-0x7FFF
|
||||
if !(0x4000..=0x7FFF).contains(&channel_number) {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"invalid channel number: 0x{:04x}, must be in 0x4000-0x7FFF",
|
||||
channel_number
|
||||
)));
|
||||
}
|
||||
|
||||
if bytes.len() < 4 + data_length {
|
||||
return Err(TurnError::StunParseError(format!(
|
||||
"ChannelData truncated: need {} data bytes, have {}",
|
||||
data_length,
|
||||
bytes.len() - 4
|
||||
)));
|
||||
}
|
||||
|
||||
let data = bytes[4..4 + data_length].to_vec();
|
||||
|
||||
Ok(ChannelData {
|
||||
channel_number,
|
||||
data,
|
||||
})
|
||||
}
|
||||
|
||||
/// Encode this ChannelData message into wire format.
|
||||
///
|
||||
/// The data portion is padded to a 4-byte boundary (for UDP transport).
|
||||
pub fn encode(&self) -> Vec<u8> {
|
||||
let data_len = self.data.len();
|
||||
let padded_len = (data_len + 3) & !3;
|
||||
let mut buf = Vec::with_capacity(4 + padded_len);
|
||||
|
||||
buf.extend_from_slice(&self.channel_number.to_be_bytes());
|
||||
buf.extend_from_slice(&(data_len as u16).to_be_bytes());
|
||||
buf.extend_from_slice(&self.data);
|
||||
|
||||
// Pad to 4-byte boundary with zeros
|
||||
let padding = padded_len - data_len;
|
||||
for _ in 0..padding {
|
||||
buf.push(0);
|
||||
}
|
||||
|
||||
buf
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Detection helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Returns `true` if the buffer begins with a ChannelData message
|
||||
/// (first two bits are NOT `00`).
|
||||
///
|
||||
/// STUN messages always have the first two bits as `00` (message type field).
|
||||
/// ChannelData starts with channel numbers 0x4000-0x7FFF, which have the
|
||||
/// first two bits as `01`.
|
||||
pub fn is_channel_data(bytes: &[u8]) -> bool {
|
||||
if bytes.is_empty() {
|
||||
return false;
|
||||
}
|
||||
// First two bits: 00 = STUN, 01 = ChannelData
|
||||
(bytes[0] & 0xC0) != 0x00
|
||||
}
|
||||
|
||||
/// Returns `true` if the buffer looks like a valid STUN message
|
||||
/// (first two bits are `00` and magic cookie is present).
|
||||
pub fn is_stun_message(bytes: &[u8]) -> bool {
|
||||
if bytes.len() < STUN_HEADER_SIZE {
|
||||
return false;
|
||||
}
|
||||
// First two bits must be 00
|
||||
if (bytes[0] & 0xC0) != 0x00 {
|
||||
return false;
|
||||
}
|
||||
// Check magic cookie
|
||||
let cookie = u32::from_be_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
|
||||
cookie == MAGIC_COOKIE
|
||||
}
|
||||
|
||||
/// Compute the CRC32 fingerprint value for a STUN message.
|
||||
///
|
||||
/// The fingerprint is `CRC32(message_bytes) XOR 0x5354554e` per RFC 5389 §15.5.
|
||||
/// `message_bytes` should be the entire message up to (but not including)
|
||||
/// the FINGERPRINT attribute, with the length field adjusted to include it.
|
||||
pub fn compute_fingerprint(message_bytes: &[u8]) -> u32 {
|
||||
let crc = crc32_compute(message_bytes);
|
||||
crc ^ FINGERPRINT_XOR
|
||||
}
|
||||
|
||||
/// Simple CRC32 (ISO 3309 / ITU-T V.42) implementation.
|
||||
///
|
||||
/// This uses the standard polynomial 0xEDB88320 (reflected form).
|
||||
/// In production, the `crc32fast` crate should be used instead for
|
||||
/// SIMD-accelerated performance.
|
||||
fn crc32_compute(data: &[u8]) -> u32 {
|
||||
let mut crc: u32 = 0xFFFFFFFF;
|
||||
for &byte in data {
|
||||
crc ^= byte as u32;
|
||||
for _ in 0..8 {
|
||||
if crc & 1 != 0 {
|
||||
crc = (crc >> 1) ^ 0xEDB88320;
|
||||
} else {
|
||||
crc >>= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
!crc
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_message_type_encoding_binding_request() {
|
||||
let mt = MessageType::new(Method::Binding, Class::Request);
|
||||
let encoded = mt.to_u16();
|
||||
assert_eq!(encoded, 0x0001);
|
||||
let decoded = MessageType::from_u16(encoded).unwrap();
|
||||
assert_eq!(decoded.method, Method::Binding);
|
||||
assert_eq!(decoded.class, Class::Request);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_type_encoding_binding_success() {
|
||||
let mt = MessageType::new(Method::Binding, Class::SuccessResponse);
|
||||
let encoded = mt.to_u16();
|
||||
assert_eq!(encoded, 0x0101);
|
||||
let decoded = MessageType::from_u16(encoded).unwrap();
|
||||
assert_eq!(decoded.method, Method::Binding);
|
||||
assert_eq!(decoded.class, Class::SuccessResponse);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_type_encoding_allocate_request() {
|
||||
let mt = MessageType::new(Method::Allocate, Class::Request);
|
||||
let encoded = mt.to_u16();
|
||||
assert_eq!(encoded, 0x0003);
|
||||
let decoded = MessageType::from_u16(encoded).unwrap();
|
||||
assert_eq!(decoded.method, Method::Allocate);
|
||||
assert_eq!(decoded.class, Class::Request);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_type_encoding_allocate_error() {
|
||||
let mt = MessageType::new(Method::Allocate, Class::ErrorResponse);
|
||||
let encoded = mt.to_u16();
|
||||
assert_eq!(encoded, 0x0113);
|
||||
let decoded = MessageType::from_u16(encoded).unwrap();
|
||||
assert_eq!(decoded.method, Method::Allocate);
|
||||
assert_eq!(decoded.class, Class::ErrorResponse);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_type_roundtrip_all_methods() {
|
||||
let methods = [
|
||||
Method::Binding,
|
||||
Method::Allocate,
|
||||
Method::Refresh,
|
||||
Method::Send,
|
||||
Method::Data,
|
||||
Method::CreatePermission,
|
||||
Method::ChannelBind,
|
||||
];
|
||||
let classes = [
|
||||
Class::Request,
|
||||
Class::Indication,
|
||||
Class::SuccessResponse,
|
||||
Class::ErrorResponse,
|
||||
];
|
||||
|
||||
for method in &methods {
|
||||
for class in &classes {
|
||||
let mt = MessageType::new(*method, *class);
|
||||
let encoded = mt.to_u16();
|
||||
let decoded = MessageType::from_u16(encoded).unwrap();
|
||||
assert_eq!(decoded.method, *method, "method mismatch for {:?}/{:?}", method, class);
|
||||
assert_eq!(decoded.class, *class, "class mismatch for {:?}/{:?}", method, class);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_channel_data() {
|
||||
// STUN message: first byte has top 2 bits = 00
|
||||
assert!(!is_channel_data(&[0x00, 0x01]));
|
||||
// ChannelData: channel 0x4000 -> first byte = 0x40 -> top 2 bits = 01
|
||||
assert!(is_channel_data(&[0x40, 0x00]));
|
||||
// Empty
|
||||
assert!(!is_channel_data(&[]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_stun_message() {
|
||||
// Valid STUN header: type=0x0001, len=0, cookie=0x2112A442, txn_id=zeros
|
||||
let mut buf = [0u8; 20];
|
||||
buf[0] = 0x00;
|
||||
buf[1] = 0x01;
|
||||
// length = 0
|
||||
buf[4] = 0x21;
|
||||
buf[5] = 0x12;
|
||||
buf[6] = 0xA4;
|
||||
buf[7] = 0x42;
|
||||
assert!(is_stun_message(&buf));
|
||||
|
||||
// Wrong cookie
|
||||
buf[4] = 0x00;
|
||||
assert!(!is_stun_message(&buf));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_data_roundtrip() {
|
||||
let cd = ChannelData {
|
||||
channel_number: 0x4001,
|
||||
data: vec![1, 2, 3, 4, 5],
|
||||
};
|
||||
let encoded = cd.encode();
|
||||
let decoded = ChannelData::decode(&encoded).unwrap();
|
||||
assert_eq!(decoded.channel_number, 0x4001);
|
||||
assert_eq!(decoded.data, vec![1, 2, 3, 4, 5]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_data_padding() {
|
||||
let cd = ChannelData {
|
||||
channel_number: 0x4001,
|
||||
data: vec![1, 2, 3], // 3 bytes -> padded to 4
|
||||
};
|
||||
let encoded = cd.encode();
|
||||
// 4 header + 4 padded data = 8 bytes
|
||||
assert_eq!(encoded.len(), 8);
|
||||
assert_eq!(encoded[7], 0); // padding byte
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stun_message_encode_decode_empty() {
|
||||
let msg = StunMessage {
|
||||
msg_type: MessageType::new(Method::Binding, Class::Request),
|
||||
transaction_id: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
attributes: vec![],
|
||||
};
|
||||
let encoded = msg.encode();
|
||||
assert_eq!(encoded.len(), STUN_HEADER_SIZE);
|
||||
|
||||
let decoded = StunMessage::decode(&encoded).unwrap();
|
||||
assert_eq!(decoded.msg_type.method, Method::Binding);
|
||||
assert_eq!(decoded.msg_type.class, Class::Request);
|
||||
assert_eq!(decoded.transaction_id, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
|
||||
assert!(decoded.attributes.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_crc32_known_value() {
|
||||
// CRC32 of "123456789" is 0xCBF43926
|
||||
let data = b"123456789";
|
||||
let crc = crc32_compute(data);
|
||||
assert_eq!(crc, 0xCBF43926);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_too_short() {
|
||||
let result = StunMessage::decode(&[0; 10]);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_decode_bad_cookie() {
|
||||
let mut buf = [0u8; 20];
|
||||
buf[0] = 0x00;
|
||||
buf[1] = 0x01;
|
||||
// Bad cookie
|
||||
buf[4] = 0xFF;
|
||||
let result = StunMessage::decode(&buf);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,329 @@
|
|||
// TCP TURN listener
|
||||
//
|
||||
// Accepts TCP connections from TURN clients and frames STUN/ChannelData
|
||||
// messages over the stream. Per RFC 5766 §2.1, when TURN is used over
|
||||
// TCP, the client connects to the TURN server via TCP but the relay
|
||||
// still uses UDP to communicate with peers.
|
||||
//
|
||||
// TCP framing: STUN messages and ChannelData are self-delimiting on
|
||||
// a TCP stream. The reader peeks at the first two bytes to determine
|
||||
// the message type:
|
||||
// - If the first two bits are 00 → STUN message. Read 20-byte header,
|
||||
// extract message length from bytes 2-3, then read the attribute body.
|
||||
// - Otherwise → ChannelData. The first 2 bytes are the channel number,
|
||||
// next 2 bytes are the data length, then read the data (padded to 4 bytes).
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use super::allocation::{AllocationManager, TransportProtocol};
|
||||
use super::handler::{HandleResult, MessageContext, TurnHandler};
|
||||
use super::stun::{is_channel_data, ChannelData, StunMessage};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TcpTurnListener
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// TCP listener for the TURN server.
|
||||
///
|
||||
/// Accepts TCP connections from clients and processes STUN/TURN messages
|
||||
/// framed on the TCP stream. Each accepted connection is handled in a
|
||||
/// separate spawned task.
|
||||
pub struct TcpTurnListener {
|
||||
/// The bound TCP listener.
|
||||
listener: TcpListener,
|
||||
/// The shared TURN message handler.
|
||||
handler: Arc<TurnHandler>,
|
||||
/// The shared allocation manager (for relay receiver spawning).
|
||||
allocations: Arc<AllocationManager>,
|
||||
/// The server's listen address.
|
||||
server_addr: SocketAddr,
|
||||
/// The server's public IP (for MessageContext).
|
||||
server_public_ip: std::net::IpAddr,
|
||||
/// The primary UDP socket (for spawning relay receiver tasks).
|
||||
/// TCP clients still use UDP relay sockets for the peer-facing side.
|
||||
udp_socket: Option<Arc<tokio::net::UdpSocket>>,
|
||||
}
|
||||
|
||||
impl TcpTurnListener {
|
||||
/// Bind a TCP listener on the given address.
|
||||
pub async fn bind(
|
||||
addr: SocketAddr,
|
||||
handler: Arc<TurnHandler>,
|
||||
allocations: Arc<AllocationManager>,
|
||||
server_public_ip: std::net::IpAddr,
|
||||
) -> std::io::Result<Self> {
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
let server_addr = listener.local_addr()?;
|
||||
Ok(Self {
|
||||
listener,
|
||||
handler,
|
||||
allocations,
|
||||
server_addr,
|
||||
server_public_ip,
|
||||
udp_socket: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Set the primary UDP socket (used for relay receiver tasks spawned
|
||||
/// from TCP-originated allocations).
|
||||
pub fn set_udp_socket(&mut self, socket: Arc<tokio::net::UdpSocket>) {
|
||||
self.udp_socket = Some(socket);
|
||||
}
|
||||
|
||||
/// Run the TCP listener loop.
|
||||
///
|
||||
/// Accepts connections in a loop and spawns a handler task for each one.
|
||||
pub async fn run(self) {
|
||||
let handler = self.handler;
|
||||
let server_addr = self.server_addr;
|
||||
let server_public_ip = self.server_public_ip;
|
||||
let allocations = self.allocations;
|
||||
let udp_socket = self.udp_socket;
|
||||
|
||||
loop {
|
||||
match self.listener.accept().await {
|
||||
Ok((stream, client_addr)) => {
|
||||
let handler = Arc::clone(&handler);
|
||||
let allocations = Arc::clone(&allocations);
|
||||
let udp_socket = udp_socket.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_tcp_client(
|
||||
stream,
|
||||
client_addr,
|
||||
server_addr,
|
||||
server_public_ip,
|
||||
handler,
|
||||
allocations,
|
||||
udp_socket,
|
||||
)
|
||||
.await
|
||||
{
|
||||
log::debug!("[TURN-TCP] client {} disconnected: {}", client_addr, e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("[TURN-TCP] accept error: {}", e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Per-client TCP handler
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Handle a single TCP client connection.
|
||||
///
|
||||
/// Reads STUN/ChannelData messages from the TCP stream in a loop, dispatches
|
||||
/// them to the handler, and writes responses back.
|
||||
async fn handle_tcp_client(
|
||||
stream: TcpStream,
|
||||
client_addr: SocketAddr,
|
||||
server_addr: SocketAddr,
|
||||
server_public_ip: std::net::IpAddr,
|
||||
handler: Arc<TurnHandler>,
|
||||
allocations: Arc<AllocationManager>,
|
||||
udp_socket: Option<Arc<tokio::net::UdpSocket>>,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let ctx = MessageContext {
|
||||
client_addr,
|
||||
server_addr,
|
||||
protocol: TransportProtocol::Tcp,
|
||||
server_public_ip,
|
||||
};
|
||||
|
||||
let (mut reader, mut writer) = stream.into_split();
|
||||
let mut header_buf = [0u8; 4];
|
||||
|
||||
loop {
|
||||
// Read the first 2 bytes to determine message type
|
||||
match reader.read_exact(&mut header_buf[..2]).await {
|
||||
Ok(_) => {}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
|
||||
// Clean disconnect
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
}
|
||||
|
||||
if is_channel_data(&header_buf[..2]) {
|
||||
// ChannelData: first 2 bytes = channel number, next 2 = data length
|
||||
reader.read_exact(&mut header_buf[2..4]).await?;
|
||||
let data_len = u16::from_be_bytes([header_buf[2], header_buf[3]]) as usize;
|
||||
|
||||
// TCP pads ChannelData to 4-byte boundary
|
||||
let padded_len = (data_len + 3) & !3;
|
||||
let mut data_buf = vec![0u8; padded_len];
|
||||
reader.read_exact(&mut data_buf[..padded_len]).await?;
|
||||
|
||||
// Reconstruct the full ChannelData for decoding (header + unpadded data)
|
||||
let mut full_msg = Vec::with_capacity(4 + data_len);
|
||||
full_msg.extend_from_slice(&header_buf);
|
||||
full_msg.extend_from_slice(&data_buf[..data_len]);
|
||||
|
||||
match ChannelData::decode(&full_msg) {
|
||||
Ok(channel_data) => {
|
||||
if let Some(result) =
|
||||
handler.handle_channel_data(&channel_data, &ctx).await
|
||||
{
|
||||
execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!(
|
||||
"[TURN-TCP] failed to parse ChannelData from {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// STUN message: first 2 bytes are message type, next 2 are message length
|
||||
// We need to read the remaining 18 bytes of the 20-byte header
|
||||
let mut stun_header = [0u8; 20];
|
||||
stun_header[0] = header_buf[0];
|
||||
stun_header[1] = header_buf[1];
|
||||
reader.read_exact(&mut stun_header[2..20]).await?;
|
||||
|
||||
let msg_len = u16::from_be_bytes([stun_header[2], stun_header[3]]) as usize;
|
||||
|
||||
// Read the attribute body
|
||||
let mut msg_buf = Vec::with_capacity(20 + msg_len);
|
||||
msg_buf.extend_from_slice(&stun_header);
|
||||
if msg_len > 0 {
|
||||
let mut body = vec![0u8; msg_len];
|
||||
reader.read_exact(&mut body).await?;
|
||||
msg_buf.extend_from_slice(&body);
|
||||
}
|
||||
|
||||
match StunMessage::decode(&msg_buf) {
|
||||
Ok(msg) => {
|
||||
let results = handler.handle_message(&msg, &ctx).await;
|
||||
for result in results {
|
||||
execute_tcp_result(&mut writer, result, client_addr, &allocations, &udp_socket)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!(
|
||||
"[TURN-TCP] failed to parse STUN message from {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a single [`HandleResult`] for a TCP client.
|
||||
///
|
||||
/// Responses are written directly to the TCP stream. Relay operations
|
||||
/// use the UDP relay socket (relay is always UDP, even for TCP clients).
|
||||
async fn execute_tcp_result(
|
||||
writer: &mut tokio::net::tcp::OwnedWriteHalf,
|
||||
result: HandleResult,
|
||||
client_addr: SocketAddr,
|
||||
allocations: &Arc<AllocationManager>,
|
||||
udp_socket: &Option<Arc<tokio::net::UdpSocket>>,
|
||||
) {
|
||||
match result {
|
||||
HandleResult::Response(data) => {
|
||||
if let Err(e) = writer.write_all(&data).await {
|
||||
log::error!(
|
||||
"[TURN-TCP] failed to write response to {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::RelayToPeer {
|
||||
peer_addr,
|
||||
data,
|
||||
relay_socket,
|
||||
} => {
|
||||
// Relay is always UDP even for TCP clients
|
||||
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
|
||||
log::error!(
|
||||
"[TURN-TCP] failed to relay to peer {}: {}",
|
||||
peer_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::ChannelDataToPeer {
|
||||
peer_addr,
|
||||
data,
|
||||
relay_socket,
|
||||
} => {
|
||||
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
|
||||
log::error!(
|
||||
"[TURN-TCP] failed to send channel data to peer {}: {}",
|
||||
peer_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::AllocationCreated {
|
||||
response,
|
||||
relay_socket,
|
||||
relay_addr,
|
||||
five_tuple,
|
||||
} => {
|
||||
// Send the success response to the TCP client
|
||||
if let Err(e) = writer.write_all(&response).await {
|
||||
log::error!(
|
||||
"[TURN-TCP] failed to write allocate response to {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
// For TCP clients, we still need a relay receiver task to handle
|
||||
// peer → relay socket → client. However, for TCP the data needs to
|
||||
// go back over the TCP stream. For simplicity, if a UDP socket is
|
||||
// available we use the UDP-based relay receiver. In a full
|
||||
// implementation, the relay receiver would write to the TCP stream.
|
||||
//
|
||||
// NOTE: In the current architecture, TCP allocations still relay
|
||||
// data via UDP. Peer data arriving on the relay socket will be
|
||||
// forwarded to the client's address via the main UDP socket (if
|
||||
// the client also has a UDP path). For pure TCP clients, a more
|
||||
// sophisticated approach would be needed.
|
||||
if let Some(main_udp) = udp_socket {
|
||||
let main_socket = Arc::clone(main_udp);
|
||||
let allocs = Arc::clone(allocations);
|
||||
let ft = five_tuple.clone();
|
||||
log::debug!(
|
||||
"[TURN-TCP] spawning relay receiver for {} (relay {})",
|
||||
client_addr,
|
||||
relay_addr
|
||||
);
|
||||
tokio::spawn(async move {
|
||||
super::udp_listener::relay_receiver_task(
|
||||
relay_socket,
|
||||
relay_addr,
|
||||
ft,
|
||||
main_socket,
|
||||
allocs,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
} else {
|
||||
log::warn!(
|
||||
"[TURN-TCP] no UDP socket available for relay receiver (allocation {})",
|
||||
relay_addr
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::None => {}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,339 @@
|
|||
// UDP TURN listener
|
||||
//
|
||||
// Listens on a UDP socket for incoming STUN/TURN messages and ChannelData.
|
||||
// Dispatches to the TurnHandler for processing and executes the resulting
|
||||
// I/O actions (send responses, relay data to peers).
|
||||
//
|
||||
// Also manages relay receiver tasks — when a new allocation is created,
|
||||
// a task is spawned to read data arriving on the allocation's relay socket
|
||||
// from peers. That data is wrapped as either ChannelData (if a channel
|
||||
// binding exists) or a STUN Data indication, and sent back to the client
|
||||
// via the main UDP socket.
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use super::allocation::{AllocationManager, FiveTuple, TransportProtocol};
|
||||
use super::attributes::StunAttribute;
|
||||
use super::handler::{HandleResult, MessageContext, TurnHandler};
|
||||
use super::stun::{
|
||||
is_channel_data, is_stun_message, ChannelData, Class, MessageType, Method, StunMessage,
|
||||
compute_fingerprint,
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// UdpTurnListener
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// UDP listener for the TURN server.
|
||||
///
|
||||
/// Reads datagrams from the primary UDP socket, dispatches them to the
|
||||
/// [`TurnHandler`], and executes I/O actions. Also spawns relay receiver
|
||||
/// tasks for each new allocation.
|
||||
pub struct UdpTurnListener {
|
||||
/// The primary UDP socket bound to the TURN server port (e.g., 3478).
|
||||
socket: Arc<UdpSocket>,
|
||||
/// The shared TURN message handler.
|
||||
handler: Arc<TurnHandler>,
|
||||
/// The allocation manager (needed for relay receiver lookups).
|
||||
allocations: Arc<AllocationManager>,
|
||||
/// The server's listen address.
|
||||
server_addr: SocketAddr,
|
||||
/// The server's public IP (for MessageContext).
|
||||
server_public_ip: std::net::IpAddr,
|
||||
}
|
||||
|
||||
impl UdpTurnListener {
|
||||
/// Create a new UDP TURN listener.
|
||||
///
|
||||
/// - `socket`: The primary UDP socket (already bound to the listen address).
|
||||
/// - `handler`: The shared TURN message handler.
|
||||
/// - `allocations`: The shared allocation manager.
|
||||
/// - `server_addr`: The server's listen address.
|
||||
/// - `server_public_ip`: The server's public IP for relay addresses.
|
||||
pub fn new(
|
||||
socket: Arc<UdpSocket>,
|
||||
handler: Arc<TurnHandler>,
|
||||
allocations: Arc<AllocationManager>,
|
||||
server_addr: SocketAddr,
|
||||
server_public_ip: std::net::IpAddr,
|
||||
) -> Self {
|
||||
Self {
|
||||
socket,
|
||||
handler,
|
||||
allocations,
|
||||
server_addr,
|
||||
server_public_ip,
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the UDP listener loop.
|
||||
///
|
||||
/// This reads datagrams from the primary UDP socket in a loop. Each
|
||||
/// incoming message is dispatched to a spawned task for processing,
|
||||
/// so the receive loop is never blocked by handler logic.
|
||||
///
|
||||
/// When the handler returns an [`HandleResult::AllocationCreated`], this
|
||||
/// also spawns a relay receiver task for the new allocation's relay socket.
|
||||
pub async fn run(self: Arc<Self>) {
|
||||
let mut buf = vec![0u8; 65536]; // 64KB max UDP datagram
|
||||
|
||||
loop {
|
||||
match self.socket.recv_from(&mut buf).await {
|
||||
Ok((len, client_addr)) => {
|
||||
let data = buf[..len].to_vec();
|
||||
let this = Arc::clone(&self);
|
||||
// Spawn a task per message to avoid blocking the receive loop
|
||||
tokio::spawn(async move {
|
||||
this.process_message(&data, client_addr).await;
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!("[TURN-UDP] recv error: {}", e);
|
||||
// Brief sleep to avoid tight error loop
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single incoming UDP message.
|
||||
///
|
||||
/// Determines whether the message is ChannelData or a STUN message,
|
||||
/// dispatches to the handler, and executes the resulting actions.
|
||||
async fn process_message(&self, data: &[u8], client_addr: SocketAddr) {
|
||||
let ctx = MessageContext {
|
||||
client_addr,
|
||||
server_addr: self.server_addr,
|
||||
protocol: TransportProtocol::Udp,
|
||||
server_public_ip: self.server_public_ip,
|
||||
};
|
||||
|
||||
if is_channel_data(data) {
|
||||
// Parse ChannelData and handle
|
||||
match ChannelData::decode(data) {
|
||||
Ok(channel_data) => {
|
||||
if let Some(result) = self.handler.handle_channel_data(&channel_data, &ctx).await
|
||||
{
|
||||
self.execute_result(result, client_addr).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!(
|
||||
"[TURN-UDP] failed to parse ChannelData from {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if is_stun_message(data) {
|
||||
// Parse STUN message and handle
|
||||
match StunMessage::decode(data) {
|
||||
Ok(msg) => {
|
||||
let results = self.handler.handle_message(&msg, &ctx).await;
|
||||
for result in results {
|
||||
self.execute_result(result, client_addr).await;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
log::debug!(
|
||||
"[TURN-UDP] failed to parse STUN message from {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// else: unknown message format, silently ignored
|
||||
}
|
||||
|
||||
/// Execute a single [`HandleResult`] action.
|
||||
///
|
||||
/// Sends responses to clients, relays data to peers, and spawns relay
|
||||
/// receiver tasks for new allocations.
|
||||
async fn execute_result(&self, result: HandleResult, client_addr: SocketAddr) {
|
||||
match result {
|
||||
HandleResult::Response(data) => {
|
||||
if let Err(e) = self.socket.send_to(&data, client_addr).await {
|
||||
log::error!(
|
||||
"[TURN-UDP] failed to send response to {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::RelayToPeer {
|
||||
peer_addr,
|
||||
data,
|
||||
relay_socket,
|
||||
} => {
|
||||
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
|
||||
log::error!(
|
||||
"[TURN-UDP] failed to relay to peer {}: {}",
|
||||
peer_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::ChannelDataToPeer {
|
||||
peer_addr,
|
||||
data,
|
||||
relay_socket,
|
||||
} => {
|
||||
if let Err(e) = relay_socket.send_to(&data, peer_addr).await {
|
||||
log::error!(
|
||||
"[TURN-UDP] failed to send channel data to peer {}: {}",
|
||||
peer_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
HandleResult::AllocationCreated {
|
||||
response,
|
||||
relay_socket,
|
||||
relay_addr,
|
||||
five_tuple,
|
||||
} => {
|
||||
// Send the success response to the client
|
||||
if let Err(e) = self.socket.send_to(&response, client_addr).await {
|
||||
log::error!(
|
||||
"[TURN-UDP] failed to send allocate response to {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
|
||||
// Spawn a relay receiver task for this allocation
|
||||
let main_socket = Arc::clone(&self.socket);
|
||||
let allocations = Arc::clone(&self.allocations);
|
||||
let ft = five_tuple.clone();
|
||||
log::debug!(
|
||||
"[TURN-UDP] spawning relay receiver for {} (relay {})",
|
||||
client_addr,
|
||||
relay_addr
|
||||
);
|
||||
tokio::spawn(async move {
|
||||
relay_receiver_task(
|
||||
relay_socket,
|
||||
relay_addr,
|
||||
ft,
|
||||
main_socket,
|
||||
allocations,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
}
|
||||
HandleResult::None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Relay receiver task
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Relay receiver task for a single allocation.
|
||||
///
|
||||
/// Reads data arriving on the allocation's relay socket from peers and
|
||||
/// forwards it to the client. The data is wrapped as:
|
||||
/// - **ChannelData** if there's a channel binding for the peer address
|
||||
/// - **STUN Data indication** otherwise (with XOR-PEER-ADDRESS and DATA)
|
||||
///
|
||||
/// The task runs until the relay socket encounters an unrecoverable error
|
||||
/// or the allocation is cleaned up (at which point recv_from will fail
|
||||
/// because the socket is dropped).
|
||||
///
|
||||
/// This function is `pub` so it can be reused by the TCP listener for
|
||||
/// TCP-originated allocations that still use UDP relay sockets.
|
||||
pub async fn relay_receiver_task(
|
||||
relay_socket: Arc<UdpSocket>,
|
||||
relay_addr: SocketAddr,
|
||||
five_tuple: FiveTuple,
|
||||
main_socket: Arc<UdpSocket>,
|
||||
allocations: Arc<AllocationManager>,
|
||||
) {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let client_addr = five_tuple.client_addr;
|
||||
|
||||
loop {
|
||||
match relay_socket.recv_from(&mut buf).await {
|
||||
Ok((len, peer_addr)) => {
|
||||
let data = &buf[..len];
|
||||
|
||||
// Check that a permission exists for this peer's IP
|
||||
if !allocations
|
||||
.has_permission(&five_tuple, &peer_addr.ip())
|
||||
.await
|
||||
{
|
||||
log::debug!(
|
||||
"[TURN-RELAY] dropping packet from {} to relay {}: no permission",
|
||||
peer_addr,
|
||||
relay_addr
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if there's a channel binding for this peer
|
||||
let wrapped = if let Some(channel_number) = allocations
|
||||
.get_channel_for_peer(&five_tuple, &peer_addr)
|
||||
.await
|
||||
{
|
||||
// Wrap as ChannelData
|
||||
let cd = ChannelData {
|
||||
channel_number,
|
||||
data: data.to_vec(),
|
||||
};
|
||||
cd.encode()
|
||||
} else {
|
||||
// Wrap as a STUN Data indication
|
||||
build_data_indication(peer_addr, data)
|
||||
};
|
||||
|
||||
// Send to the client via the main UDP socket
|
||||
if let Err(e) = main_socket.send_to(&wrapped, client_addr).await {
|
||||
log::error!(
|
||||
"[TURN-RELAY] failed to send to client {}: {}",
|
||||
client_addr,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Socket error — likely the allocation was cleaned up and the
|
||||
// socket was dropped. Exit the task.
|
||||
log::debug!(
|
||||
"[TURN-RELAY] relay socket {} recv error (allocation likely expired): {}",
|
||||
relay_addr,
|
||||
e
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log::debug!(
|
||||
"[TURN-RELAY] relay receiver task for {} exiting",
|
||||
relay_addr
|
||||
);
|
||||
}
|
||||
|
||||
/// Build a STUN Data indication message (Method::Data, Class::Indication)
|
||||
/// with XOR-PEER-ADDRESS and DATA attributes.
|
||||
///
|
||||
/// Data indications are used when there's no channel binding for the peer.
|
||||
/// Per RFC 5766 §10.3, they don't require MESSAGE-INTEGRITY.
|
||||
fn build_data_indication(peer_addr: SocketAddr, data: &[u8]) -> Vec<u8> {
|
||||
let mut msg = StunMessage::new_random(MessageType::new(Method::Data, Class::Indication));
|
||||
|
||||
msg.add_attribute(StunAttribute::XorPeerAddress(peer_addr));
|
||||
msg.add_attribute(StunAttribute::Data(data.to_vec()));
|
||||
|
||||
// Add FINGERPRINT for demultiplexing
|
||||
let fp_bytes = msg.encode_for_fingerprint();
|
||||
let fingerprint = compute_fingerprint(&fp_bytes);
|
||||
msg.add_attribute(StunAttribute::Fingerprint(fingerprint));
|
||||
|
||||
msg.encode()
|
||||
}
|
||||
Loading…
Reference in New Issue