diff --git a/Cargo.lock b/Cargo.lock index 2b8f45e..333b083 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -120,18 +120,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-auth" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8169113a185f54f68614fcfc3581df585d30bf8542bcb99496990e1025e4120a" -dependencies = [ - "async-trait", - "axum-core", - "base64 0.21.7", - "http", -] - [[package]] name = "axum-client-ip" version = "0.6.1" @@ -188,12 +176,6 @@ dependencies = [ "backtrace", ] -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -1044,6 +1026,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower-http" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697" +dependencies = [ + "bitflags", + "bytes", + "http", + "mime", + "pin-project-lite", + "tower-layer", + "tower-service", +] + [[package]] name = "tower-layer" version = "0.3.3" @@ -1165,9 +1162,8 @@ name = "webnsupdate" version = "0.3.2-dev" dependencies = [ "axum", - "axum-auth", "axum-client-ip", - "base64 0.22.1", + "base64", "clap", "clap-verbosity-flag", "http", @@ -1175,6 +1171,7 @@ dependencies = [ "miette", "ring", "tokio", + "tower-http", "tracing", "tracing-subscriber", ] diff --git a/Cargo.toml b/Cargo.toml index d977c4e..cf93cf8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,6 @@ edition = "2021" [dependencies] axum = "0.7" -axum-auth = { version = "0.7", default-features = false, features = [ - "auth-basic", -] } axum-client-ip = "0.6" base64 = "0.22" clap = { version = "4", features = ["derive", "env"] } @@ -21,6 +18,7 @@ http = "1" miette = { version = "7", features = ["fancy"] } ring = { version = "0.17", features = ["std"] } tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } +tower-http = { version = "0.6.2", features = ["validate-request"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..846b3ff --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,104 @@ +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use base64::Engine; +use tower_http::validate_request::ValidateRequestHeaderLayer; +use tracing::{trace, warn}; + +use crate::password; + +pub fn auth_layer<'a, ResBody>( + user_pass_hash: &'a [u8], + salt: &'a str, +) -> ValidateRequestHeaderLayer> { + ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt)) +} + +#[derive(Copy)] +pub struct BasicAuth<'a, ResBody> { + pass: &'a [u8], + salt: &'a str, + _ty: std::marker::PhantomData ResBody>, +} + +impl std::fmt::Debug for BasicAuth<'_, ResBody> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BasicAuth") + .field("pass", &self.pass) + .field("salt", &self.salt) + .field("_ty", &self._ty) + .finish() + } +} + +impl Clone for BasicAuth<'_, ResBody> { + fn clone(&self) -> Self { + Self { + pass: self.pass, + salt: self.salt, + _ty: std::marker::PhantomData, + } + } +} + +impl<'a, ResBody> BasicAuth<'a, ResBody> { + pub fn new(pass: &'a [u8], salt: &'a str) -> Self { + Self { + pass, + salt, + _ty: std::marker::PhantomData, + } + } + + fn check_headers(&self, headers: &http::HeaderMap) -> bool { + let Some(auth) = headers.get(http::header::AUTHORIZATION) else { + return false; + }; + + // Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once + let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else { + return false; + }; + let user_pass = &auth.as_bytes()[index + 1..]; + + match base64::engine::general_purpose::URL_SAFE.decode(user_pass) { + Ok(user_pass) => { + let hashed = password::hash_basic_auth(&user_pass, self.salt); + if hashed.as_ref() == self.pass { + return true; + } + warn!("rejected update"); + trace!( + "mismatched hashes:\nprovided: {}\nstored: {}", + URL_SAFE_NO_PAD.encode(hashed.as_ref()), + URL_SAFE_NO_PAD.encode(self.pass), + ); + false + } + Err(err) => { + warn!("received invalid base64 when decoding Basic header: {err}"); + false + } + } + } +} + +impl tower_http::validate_request::ValidateRequest for BasicAuth<'_, ResBody> +where + ResBody: Default, +{ + type ResponseBody = ResBody; + + fn validate( + &mut self, + request: &mut http::Request, + ) -> std::result::Result<(), http::Response> { + if self.check_headers(request.headers()) { + return Ok(()); + } + + let mut res = http::Response::new(ResBody::default()); + *res.status_mut() = http::status::StatusCode::UNAUTHORIZED; + res.headers_mut() + .insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap()); + Err(res) + } +} diff --git a/src/main.rs b/src/main.rs index e216b58..b850e02 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,8 +7,7 @@ use std::{ time::Duration, }; -use axum::{extract::State, routing::get, Json, Router}; -use axum_auth::AuthBasic; +use axum::{extract::State, routing::get, Router}; use axum_client_ip::{SecureClientIp, SecureClientIpSource}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{Parser, Subcommand}; @@ -16,9 +15,10 @@ use clap_verbosity_flag::Verbosity; use http::StatusCode; use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use tokio::io::AsyncWriteExt; -use tracing::{debug, error, info, trace, warn}; +use tracing::{debug, error, info, warn}; use tracing_subscriber::EnvFilter; +mod auth; mod password; mod records; @@ -108,18 +108,12 @@ struct AppState<'a> { /// TTL set on the Zonefile ttl: Duration, - /// Salt added to the password - salt: &'a str, - /// The IN A/AAAA records that should have their IPs updated records: &'a [&'a str], /// The TSIG key file key_file: Option<&'a Path>, - /// The password hash - password_hash: Option<&'a [u8]>, - /// The file where the last IP is stored ip_file: &'a Path, } @@ -195,9 +189,23 @@ fn main() -> Result<()> { // Use last registered IP address if available let ip_file = data_dir.join("last-ip"); + // Load password hash + let password_hash = password_file + .map(|path| -> miette::Result<_> { + let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?; + + let pass: Box<[u8]> = URL_SAFE_NO_PAD + .decode(pass.trim().as_bytes()) + .into_diagnostic() + .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? + .into(); + + Ok(pass) + }) + .transpose()?; + let state = AppState { ttl, - salt: salt.leak(), // Load DNS records records: records::load_no_verify(&records)?, // Load keyfile @@ -212,25 +220,11 @@ fn main() -> Result<()> { Ok(&*Box::leak(key_file.into_boxed_path())) }) .transpose()?, - // Load password hash - password_hash: password_file - .map(|path| -> miette::Result<_> { - let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?; - - let pass: Box<[u8]> = URL_SAFE_NO_PAD - .decode(pass.trim().as_bytes()) - .into_diagnostic() - .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? - .into(); - - Ok(&*Box::leak(pass)) - }) - .transpose()?, ip_file: Box::leak(ip_file.into_boxed_path()), }; ensure!( - state.password_hash.is_some() || insecure, + password_hash.is_some() || insecure, "a password must be used" ); @@ -270,11 +264,18 @@ fn main() -> Result<()> { } }; + // Create services + let app = Router::new().route("/update", get(update_records)); + // if a password is provided, validate it + let app = if let Some(pass) = password_hash { + app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt))) + } else { + app + } + .layer(ip_source.into_extension()) + .with_state(state); + // Start services - let app = Router::new() - .route("/update", get(update_records)) - .layer(ip_source.into_extension()) - .with_state(state); info!("starting listener on {ip}:{port}"); let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) .await @@ -289,31 +290,12 @@ fn main() -> Result<()> { }) } -#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))] +#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))] async fn update_records( State(state): State>, - AuthBasic((username, pass)): AuthBasic, SecureClientIp(ip): SecureClientIp, ) -> axum::response::Result<&'static str> { debug!("received update request from {ip}"); - let Some(pass) = pass else { - return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into()); - }; - - if let Some(stored_pass) = state.password_hash { - let password = pass.trim().to_string(); - let pass_hash = password::hash_identity(&username, &password, state.salt); - if pass_hash.as_ref() != stored_pass { - warn!("rejected update"); - trace!( - "mismatched hashes:\n{}\n{}", - URL_SAFE_NO_PAD.encode(pass_hash.as_ref()), - URL_SAFE_NO_PAD.encode(stored_pass), - ); - return Err((StatusCode::UNAUTHORIZED, "invalid identity").into()); - } - } - info!("accepted update"); match nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) if status.success() => { diff --git a/src/password.rs b/src/password.rs index 84574a2..8d965ba 100644 --- a/src/password.rs +++ b/src/password.rs @@ -28,10 +28,20 @@ impl Mkpasswd { } } +pub fn hash_basic_auth(user_pass: &[u8], salt: &str) -> Digest { + let mut context = ring::digest::Context::new(&ring::digest::SHA256); + context.update(user_pass); + context.update(salt.as_bytes()); + context.finish() +} + pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest { - let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1); - write!(data, "{username}:{password}{salt}").unwrap(); - ring::digest::digest(&ring::digest::SHA256, &data) + let mut context = ring::digest::Context::new(&ring::digest::SHA256); + context.update(username.as_bytes()); + context.update(b":"); + context.update(password.as_bytes()); + context.update(salt.as_bytes()); + context.finish() } pub fn mkpasswd(