feat: replace axum-auth with tower_http
Slightly more involde in the auth code, but it makes the rest of the application more straight forward. Fixes #10
This commit is contained in:
parent
60aed649b1
commit
750cbbff93
5 changed files with 166 additions and 75 deletions
37
Cargo.lock
generated
37
Cargo.lock
generated
|
@ -120,18 +120,6 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-auth"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8169113a185f54f68614fcfc3581df585d30bf8542bcb99496990e1025e4120a"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"base64 0.21.7",
|
||||
"http",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-client-ip"
|
||||
version = "0.6.1"
|
||||
|
@ -188,12 +176,6 @@ dependencies = [
|
|||
"backtrace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.22.1"
|
||||
|
@ -1044,6 +1026,21 @@ dependencies = [
|
|||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"bytes",
|
||||
"http",
|
||||
"mime",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.3"
|
||||
|
@ -1165,9 +1162,8 @@ name = "webnsupdate"
|
|||
version = "0.3.2-dev"
|
||||
dependencies = [
|
||||
"axum",
|
||||
"axum-auth",
|
||||
"axum-client-ip",
|
||||
"base64 0.22.1",
|
||||
"base64",
|
||||
"clap",
|
||||
"clap-verbosity-flag",
|
||||
"http",
|
||||
|
@ -1175,6 +1171,7 @@ dependencies = [
|
|||
"miette",
|
||||
"ring",
|
||||
"tokio",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
|
|
@ -8,9 +8,6 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
axum = "0.7"
|
||||
axum-auth = { version = "0.7", default-features = false, features = [
|
||||
"auth-basic",
|
||||
] }
|
||||
axum-client-ip = "0.6"
|
||||
base64 = "0.22"
|
||||
clap = { version = "4", features = ["derive", "env"] }
|
||||
|
@ -21,6 +18,7 @@ http = "1"
|
|||
miette = { version = "7", features = ["fancy"] }
|
||||
ring = { version = "0.17", features = ["std"] }
|
||||
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
|
||||
tower-http = { version = "0.6.2", features = ["validate-request"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
|
||||
|
|
104
src/auth.rs
Normal file
104
src/auth.rs
Normal file
|
@ -0,0 +1,104 @@
|
|||
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||
use base64::Engine;
|
||||
use tower_http::validate_request::ValidateRequestHeaderLayer;
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use crate::password;
|
||||
|
||||
pub fn auth_layer<'a, ResBody>(
|
||||
user_pass_hash: &'a [u8],
|
||||
salt: &'a str,
|
||||
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
|
||||
ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt))
|
||||
}
|
||||
|
||||
#[derive(Copy)]
|
||||
pub struct BasicAuth<'a, ResBody> {
|
||||
pass: &'a [u8],
|
||||
salt: &'a str,
|
||||
_ty: std::marker::PhantomData<fn() -> ResBody>,
|
||||
}
|
||||
|
||||
impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("BasicAuth")
|
||||
.field("pass", &self.pass)
|
||||
.field("salt", &self.salt)
|
||||
.field("_ty", &self._ty)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<ResBody> Clone for BasicAuth<'_, ResBody> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
pass: self.pass,
|
||||
salt: self.salt,
|
||||
_ty: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, ResBody> BasicAuth<'a, ResBody> {
|
||||
pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
|
||||
Self {
|
||||
pass,
|
||||
salt,
|
||||
_ty: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
|
||||
let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
// Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once
|
||||
let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else {
|
||||
return false;
|
||||
};
|
||||
let user_pass = &auth.as_bytes()[index + 1..];
|
||||
|
||||
match base64::engine::general_purpose::URL_SAFE.decode(user_pass) {
|
||||
Ok(user_pass) => {
|
||||
let hashed = password::hash_basic_auth(&user_pass, self.salt);
|
||||
if hashed.as_ref() == self.pass {
|
||||
return true;
|
||||
}
|
||||
warn!("rejected update");
|
||||
trace!(
|
||||
"mismatched hashes:\nprovided: {}\nstored: {}",
|
||||
URL_SAFE_NO_PAD.encode(hashed.as_ref()),
|
||||
URL_SAFE_NO_PAD.encode(self.pass),
|
||||
);
|
||||
false
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("received invalid base64 when decoding Basic header: {err}");
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
|
||||
where
|
||||
ResBody: Default,
|
||||
{
|
||||
type ResponseBody = ResBody;
|
||||
|
||||
fn validate(
|
||||
&mut self,
|
||||
request: &mut http::Request<B>,
|
||||
) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
|
||||
if self.check_headers(request.headers()) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut res = http::Response::new(ResBody::default());
|
||||
*res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
|
||||
res.headers_mut()
|
||||
.insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
|
||||
Err(res)
|
||||
}
|
||||
}
|
78
src/main.rs
78
src/main.rs
|
@ -7,8 +7,7 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use axum::{extract::State, routing::get, Json, Router};
|
||||
use axum_auth::AuthBasic;
|
||||
use axum::{extract::State, routing::get, Router};
|
||||
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
@ -16,9 +15,10 @@ use clap_verbosity_flag::Verbosity;
|
|||
use http::StatusCode;
|
||||
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
mod auth;
|
||||
mod password;
|
||||
mod records;
|
||||
|
||||
|
@ -108,18 +108,12 @@ struct AppState<'a> {
|
|||
/// TTL set on the Zonefile
|
||||
ttl: Duration,
|
||||
|
||||
/// Salt added to the password
|
||||
salt: &'a str,
|
||||
|
||||
/// The IN A/AAAA records that should have their IPs updated
|
||||
records: &'a [&'a str],
|
||||
|
||||
/// The TSIG key file
|
||||
key_file: Option<&'a Path>,
|
||||
|
||||
/// The password hash
|
||||
password_hash: Option<&'a [u8]>,
|
||||
|
||||
/// The file where the last IP is stored
|
||||
ip_file: &'a Path,
|
||||
}
|
||||
|
@ -195,9 +189,23 @@ fn main() -> Result<()> {
|
|||
// Use last registered IP address if available
|
||||
let ip_file = data_dir.join("last-ip");
|
||||
|
||||
// Load password hash
|
||||
let password_hash = password_file
|
||||
.map(|path| -> miette::Result<_> {
|
||||
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
|
||||
|
||||
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
||||
.decode(pass.trim().as_bytes())
|
||||
.into_diagnostic()
|
||||
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
|
||||
.into();
|
||||
|
||||
Ok(pass)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let state = AppState {
|
||||
ttl,
|
||||
salt: salt.leak(),
|
||||
// Load DNS records
|
||||
records: records::load_no_verify(&records)?,
|
||||
// Load keyfile
|
||||
|
@ -212,25 +220,11 @@ fn main() -> Result<()> {
|
|||
Ok(&*Box::leak(key_file.into_boxed_path()))
|
||||
})
|
||||
.transpose()?,
|
||||
// Load password hash
|
||||
password_hash: password_file
|
||||
.map(|path| -> miette::Result<_> {
|
||||
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
|
||||
|
||||
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
||||
.decode(pass.trim().as_bytes())
|
||||
.into_diagnostic()
|
||||
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
|
||||
.into();
|
||||
|
||||
Ok(&*Box::leak(pass))
|
||||
})
|
||||
.transpose()?,
|
||||
ip_file: Box::leak(ip_file.into_boxed_path()),
|
||||
};
|
||||
|
||||
ensure!(
|
||||
state.password_hash.is_some() || insecure,
|
||||
password_hash.is_some() || insecure,
|
||||
"a password must be used"
|
||||
);
|
||||
|
||||
|
@ -270,11 +264,18 @@ fn main() -> Result<()> {
|
|||
}
|
||||
};
|
||||
|
||||
// Start services
|
||||
let app = Router::new()
|
||||
.route("/update", get(update_records))
|
||||
// Create services
|
||||
let app = Router::new().route("/update", get(update_records));
|
||||
// if a password is provided, validate it
|
||||
let app = if let Some(pass) = password_hash {
|
||||
app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt)))
|
||||
} else {
|
||||
app
|
||||
}
|
||||
.layer(ip_source.into_extension())
|
||||
.with_state(state);
|
||||
|
||||
// Start services
|
||||
info!("starting listener on {ip}:{port}");
|
||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
|
||||
.await
|
||||
|
@ -289,31 +290,12 @@ fn main() -> Result<()> {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
|
||||
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
|
||||
async fn update_records(
|
||||
State(state): State<AppState<'static>>,
|
||||
AuthBasic((username, pass)): AuthBasic,
|
||||
SecureClientIp(ip): SecureClientIp,
|
||||
) -> axum::response::Result<&'static str> {
|
||||
debug!("received update request from {ip}");
|
||||
let Some(pass) = pass else {
|
||||
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
|
||||
};
|
||||
|
||||
if let Some(stored_pass) = state.password_hash {
|
||||
let password = pass.trim().to_string();
|
||||
let pass_hash = password::hash_identity(&username, &password, state.salt);
|
||||
if pass_hash.as_ref() != stored_pass {
|
||||
warn!("rejected update");
|
||||
trace!(
|
||||
"mismatched hashes:\n{}\n{}",
|
||||
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
|
||||
URL_SAFE_NO_PAD.encode(stored_pass),
|
||||
);
|
||||
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
|
||||
}
|
||||
}
|
||||
|
||||
info!("accepted update");
|
||||
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
||||
Ok(status) if status.success() => {
|
||||
|
|
|
@ -28,10 +28,20 @@ impl Mkpasswd {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn hash_basic_auth(user_pass: &[u8], salt: &str) -> Digest {
|
||||
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
|
||||
context.update(user_pass);
|
||||
context.update(salt.as_bytes());
|
||||
context.finish()
|
||||
}
|
||||
|
||||
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
|
||||
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1);
|
||||
write!(data, "{username}:{password}{salt}").unwrap();
|
||||
ring::digest::digest(&ring::digest::SHA256, &data)
|
||||
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
|
||||
context.update(username.as_bytes());
|
||||
context.update(b":");
|
||||
context.update(password.as_bytes());
|
||||
context.update(salt.as_bytes());
|
||||
context.finish()
|
||||
}
|
||||
|
||||
pub fn mkpasswd(
|
||||
|
|
Loading…
Reference in a new issue