From 5745e1aaf758594f94b898b6a19d94259cffd761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jalil=20David=20Salam=C3=A9=20Messina?= Date: Sat, 12 Oct 2024 22:40:46 +0200 Subject: [PATCH] feat: refactor and add ip saving - clean up code to my new (and improved) standards - print better miette diagnostigs for the error tests - add an IP saving feature: save the last IP that successfully updated the records and usa that when restarting the service. This allows seamles upgrades of `webnsupdate` without having to manually trigger a DNS update --- .cargo/config | 1 - Cargo.toml | 21 +- default.nix | 2 + src/main.rs | 536 +++++++----------- src/password.rs | 60 ++ src/records.rs | 283 +++++++++ ...bnsupdate__records__test__empty_label.snap | 16 + ...ate__records__test__hostname_too_long.snap | 16 + ...supdate__records__test__invalid_ascii.snap | 16 + ...supdate__records__test__invalid_octet.snap | 16 + ...update__records__test__label_too_long.snap | 16 + .../webnsupdate__records__test__not_fqd.snap | 16 + 12 files changed, 646 insertions(+), 353 deletions(-) delete mode 120000 .cargo/config create mode 100644 src/password.rs create mode 100644 src/records.rs create mode 100644 src/snapshots/webnsupdate__records__test__empty_label.snap create mode 100644 src/snapshots/webnsupdate__records__test__hostname_too_long.snap create mode 100644 src/snapshots/webnsupdate__records__test__invalid_ascii.snap create mode 100644 src/snapshots/webnsupdate__records__test__invalid_octet.snap create mode 100644 src/snapshots/webnsupdate__records__test__label_too_long.snap create mode 100644 src/snapshots/webnsupdate__records__test__not_fqd.snap diff --git a/.cargo/config b/.cargo/config deleted file mode 120000 index ab8b69c..0000000 --- a/.cargo/config +++ /dev/null @@ -1 +0,0 @@ -config.toml \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 29b74b4..1622810 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ cargo-features = ["codegen-backend"] + [package] description = "An HTTP server using HTTP basic auth to make secure calls to nsupdate" name = "webnsupdate" @@ -7,24 +8,26 @@ edition = "2021" [dependencies] axum = "0.7.7" +axum-auth = { version = "0.7.0", default-features = false, features = [ + "auth-basic", +] } axum-client-ip = "0.6.1" base64 = "0.22.1" clap = { version = "4.5.20", features = ["derive", "env"] } http = "1.1.0" -insta = "1.40.0" miette = { version = "7.2.0", features = ["fancy"] } ring = { version = "0.17.8", features = ["std"] } +tokio = { version = "1.40.0", features = [ + "macros", + "rt", + "process", + "io-util", +] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -[dependencies.axum-auth] -version = "0.7.0" -default-features = false -features = ["auth-basic"] - -[dependencies.tokio] -version = "1.40.0" -features = ["macros", "rt", "process", "io-util"] +[dev-dependencies] +insta = "1.40.0" [profile.dev] debug = 0 diff --git a/default.nix b/default.nix index 19889c4..e224dd7 100644 --- a/default.nix +++ b/default.nix @@ -19,6 +19,8 @@ let ".rs" # TOML files are often used to configure cargo based tools (e.g. .cargo/config.toml) ".toml" + # Snapshot tests + ".snap" ]; isCargoLock = base == "Cargo.lock"; in diff --git a/src/main.rs b/src/main.rs index 07b5a4f..19c1e66 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,7 @@ use std::{ ffi::OsStr, - io::Write, + io::ErrorKind, net::{IpAddr, SocketAddr}, - os::unix::fs::OpenOptionsExt, path::{Path, PathBuf}, process::{ExitStatus, Stdio}, time::Duration, @@ -12,14 +11,16 @@ use axum::{extract::State, routing::get, Json, Router}; use axum_auth::AuthBasic; use axum_client_ip::{SecureClientIp, SecureClientIpSource}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; -use clap::{Args, Parser, Subcommand}; +use clap::{Parser, Subcommand}; use http::StatusCode; -use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result}; -use ring::digest::Digest; +use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use tokio::io::AsyncWriteExt; use tracing::{debug, error, info, level_filters::LevelFilter, trace, warn}; use tracing_subscriber::EnvFilter; +mod password; +mod records; + const DEFAULT_TTL: Duration = Duration::from_secs(60); const DEFAULT_SALT: &str = "UpdateMyDNS"; @@ -28,20 +29,29 @@ struct Opts { /// Ip address of the server #[arg(long, default_value = "127.0.0.1")] address: IpAddr, + /// Port of the server #[arg(long, default_value_t = 5353)] port: u16, + /// File containing password to match against /// /// Should be of the format `username:password` and contain a single password #[arg(long)] password_file: Option, + /// Salt to get more unique hashed passwords and prevent table based attacks #[arg(long, default_value = DEFAULT_SALT)] salt: String, + /// Time To Live (in seconds) to set on the DNS records #[arg(long, default_value_t = DEFAULT_TTL.as_secs())] ttl: u64, + + /// Data directory + #[arg(long, default_value = ".")] + data_dir: PathBuf, + /// File containing the records that should be updated when an update request is made /// /// There should be one record per line: @@ -52,70 +62,92 @@ struct Opts { /// ``` #[arg(long)] records: PathBuf, + /// Keyfile `nsupdate` should use /// /// If specified, then `webnsupdate` must have read access to the file #[arg(long)] key_file: Option, + /// Allow not setting a password #[arg(long)] insecure: bool, + /// Set client IP source /// /// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html #[clap(long, default_value = "RightmostXForwardedFor")] ip_source: SecureClientIpSource, + #[clap(subcommand)] subcommand: Option, } -#[derive(Debug, Args)] -struct Mkpasswd { - /// The username - username: String, - /// The password - password: String, -} - #[derive(Debug, Subcommand)] enum Cmd { - /// Create a password file - /// - /// If `--password-file` is provided, the password is written to that file - Mkpasswd(Mkpasswd), + Mkpasswd(password::Mkpasswd), /// Verify the records file Verify, } +impl Cmd { + pub fn process(self, args: &Opts) -> Result<()> { + match self { + Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args), + Cmd::Verify => records::load(&args.records).map(drop), + } + } +} + #[derive(Clone)] struct AppState<'a> { /// TTL set on the Zonefile ttl: Duration, + /// Salt added to the password salt: &'a str, + /// The IN A/AAAA records that should have their IPs updated records: &'a [&'a str], + /// The TSIG key file key_file: Option<&'a Path>, + /// The password hash password_hash: Option<&'a [u8]>, + + /// The file where the last IP is stored + ip_file: &'a Path, } -#[tokio::main(flavor = "current_thread")] -async fn main() -> Result<()> { +fn load_ip(path: &Path) -> Result> { + let data = match std::fs::read_to_string(path) { + Ok(ip) => ip, + Err(err) => { + return match err.kind() { + ErrorKind::NotFound => Ok(None), + _ => Err(err).into_diagnostic().wrap_err_with(|| { + format!("failed to load last ip address from {}", path.display()) + }), + } + } + }; + + Ok(Some( + data.parse() + .into_diagnostic() + .wrap_err("failed to parse last ip address")?, + )) +} + +fn main() -> Result<()> { + // set panic hook to pretty print with miette's formatter miette::set_panic_hook(); - let Opts { - address: ip, - port, - password_file, - key_file, - insecure, - subcommand, - records, - salt, - ttl, - ip_source, - } = Opts::parse(); + + // parse cli arguments + let mut args = Opts::parse(); + + // configure logger let subscriber = tracing_subscriber::FmtSubscriber::builder() .without_time() .with_env_filter( @@ -127,74 +159,126 @@ async fn main() -> Result<()> { tracing::subscriber::set_global_default(subscriber) .into_diagnostic() .wrap_err("setting global tracing subscriber")?; - match subcommand { - Some(Cmd::Mkpasswd(args)) => return mkpasswd(args, password_file.as_deref(), &salt), - Some(Cmd::Verify) => { - let data = std::fs::read_to_string(&records) - .into_diagnostic() - .wrap_err_with(|| format!("trying to read {}", records.display()))?; - return verify_records(&data, &records); - } - None => {} + + // process subcommand + if let Some(cmd) = args.subcommand.take() { + return cmd.process(&args); } + + let Opts { + address: ip, + port, + password_file, + data_dir, + key_file, + insecure, + subcommand: _, + records, + salt, + ttl, + ip_source, + } = args; + info!("checking environment"); + // Set state let ttl = Duration::from_secs(ttl); - let mut state = AppState { + + // Use last registered IP address if available + let ip_file = data_dir.join("last-ip"); + + let state = AppState { ttl, salt: salt.leak(), - records: &[], - key_file: None, - password_hash: None, - }; - if let Some(path) = password_file { - let pass = std::fs::read_to_string(&path).into_diagnostic()?; + // Load DNS records + records: records::load_no_verify(&records)?, + // Load keyfile + key_file: key_file + .map(|key_file| -> miette::Result<_> { + let path = key_file.as_path(); + std::fs::File::open(path) + .into_diagnostic() + .wrap_err_with(|| { + format!("{} is not readable by the current user", path.display()) + })?; + Ok(&*Box::leak(key_file.into_boxed_path())) + }) + .transpose()?, + // Load password hash + password_hash: password_file + .map(|path| -> miette::Result<_> { + let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?; - let pass: Box<[u8]> = URL_SAFE_NO_PAD - .decode(pass.trim().as_bytes()) - .into_diagnostic() - .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? - .into(); - state.password_hash = Some(Box::leak(pass)); - } else { - ensure!(insecure, "a password must be used"); - } - if let Some(key_file) = key_file { - let path = key_file.as_path(); - std::fs::File::open(path) - .into_diagnostic() - .wrap_err_with(|| format!("{} is not readable by the current user", path.display()))?; - state.key_file = Some(Box::leak(key_file.into_boxed_path())); - } else { - ensure!(insecure, "a key file must be used"); - } - let data = std::fs::read_to_string(&records) + let pass: Box<[u8]> = URL_SAFE_NO_PAD + .decode(pass.trim().as_bytes()) + .into_diagnostic() + .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? + .into(); + + Ok(&*Box::leak(pass)) + }) + .transpose()?, + ip_file: Box::leak(ip_file.into_boxed_path()), + }; + + ensure!( + state.password_hash.is_some() || insecure, + "a password must be used" + ); + + ensure!( + state.key_file.is_some() || insecure, + "a key file must be used" + ); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() .into_diagnostic() - .wrap_err_with(|| format!("loading records from {}", records.display()))?; - if let Err(err) = verify_records(&data, &records) { - warn!("invalid records found: {err}"); - } - state.records = data - .lines() - .map(|s| &*s.to_string().leak()) - .collect::>() - .leak(); - // Start services - let app = Router::new() - .route("/update", get(update_records)) - .layer(ip_source.into_extension()) - .with_state(state); - info!("starting listener on {ip}:{port}"); - let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) + .wrap_err("failed to start the tokio runtime")?; + + rt.block_on(async { + // Load previous IP and update DNS record to point to it (if available) + match load_ip(state.ip_file) { + Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); + } + } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); + } + }, + Ok(None) => { + info!("No previous IP address set"); + } + Err(err) => { + error!("Failed to load last ip address: {err}") + } + }; + + // Start services + let app = Router::new() + .route("/update", get(update_records)) + .layer(ip_source.into_extension()) + .with_state(state); + info!("starting listener on {ip}:{port}"); + let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) + .await + .into_diagnostic()?; + info!("listening on {ip}:{port}"); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) .await - .into_diagnostic()?; - info!("listening on {ip}:{port}"); - axum::serve( - listener, - app.into_make_service_with_connect_info::(), - ) - .await - .into_diagnostic() + .into_diagnostic() + }) } #[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))] @@ -206,32 +290,38 @@ async fn update_records( let Some(pass) = pass else { return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into()); }; + if let Some(stored_pass) = state.password_hash { let password = pass.trim().to_string(); - let pass_hash = hash_identity(&username, &password, state.salt); + let pass_hash = password::hash_identity(&username, &password, state.salt); if pass_hash.as_ref() != stored_pass { warn!("rejected update"); trace!( "mismatched hashes:\n{}\n{}", URL_SAFE_NO_PAD.encode(pass_hash.as_ref()), - URL_SAFE_NO_PAD.encode(stored_pass.as_ref()), + URL_SAFE_NO_PAD.encode(stored_pass), ); return Err((StatusCode::UNAUTHORIZED, "invalid identity").into()); } } + info!("accepted update"); match nsupdate(ip, state.ttl, state.key_file, state.records).await { + Ok(status) if status.success() => { + tokio::task::spawn_blocking(move || { + if let Err(err) = std::fs::write(state.ip_file, format!("{ip}")) { + error!("Failed to update last IP: {err}"); + } + }); + Ok("successful update") + } Ok(status) => { - if status.success() { - Ok("successful update") - } else { - error!("nsupdate failed"); - Err(( - StatusCode::INTERNAL_SERVER_ERROR, - "nsupdate failed, check server logs", - ) - .into()) - } + error!("nsupdate failed with code {status}"); + Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "nsupdate failed, check server logs", + ) + .into()) } Err(error) => Err(( StatusCode::INTERNAL_SERVER_ERROR, @@ -292,148 +382,11 @@ fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { cmds } -fn hash_identity(username: &str, password: &str, salt: &str) -> Digest { - let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1); - write!(data, "{username}:{password}{salt}").unwrap(); - ring::digest::digest(&ring::digest::SHA256, &data) -} - -fn mkpasswd( - Mkpasswd { username, password }: Mkpasswd, - password_file: Option<&Path>, - salt: &str, -) -> miette::Result<()> { - let hash = hash_identity(&username, &password, salt); - let encoded = URL_SAFE_NO_PAD.encode(hash.as_ref()); - let Some(path) = password_file else { - println!("{encoded}"); - return Ok(()); - }; - let err = || format!("trying to save password hash to {}", path.display()); - std::fs::File::options() - .mode(0o600) - .create_new(true) - .open(path) - .into_diagnostic() - .wrap_err_with(err)? - .write_all(encoded.as_bytes()) - .into_diagnostic() - .wrap_err_with(err)?; - - Ok(()) -} - -fn verify_records(data: &str, path: &Path) -> miette::Result<()> { - let source = || NamedSource::new(path.display().to_string(), data.to_string()); - let mut byte_offset = 0usize; - for line in data.lines() { - if line.is_empty() { - continue; - } - ensure!( - line.len() <= 255, - miette!( - labels = [LabeledSpan::new( - Some("this line".to_string()), - byte_offset, - line.len(), - )], - help = "fully qualified domain names can be at most 255 characters long", - url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", - "hostname too long ({} octets)", - line.len(), - ) - .with_source_code(source()) - ); - ensure!( - line.ends_with('.'), - miette!( - labels = [LabeledSpan::new( - Some("last character".to_string()), - byte_offset + line.len() - 1, - 1, - )], - help = "hostname should be a fully qualified domain name (end with a '.')", - url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", - "not a fully qualified domain name" - ) - .with_source_code(source()) - ); - let mut local_offset = 0usize; - for label in line.strip_suffix('.').unwrap_or(line).split('.') { - ensure!( - !label.is_empty(), - miette!( - labels = [LabeledSpan::new( - Some("label".to_string()), - byte_offset + local_offset, - label.len(), - )], - help = "each label should have at least one character", - url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", - "empty label", - ) - .with_source_code(source()) - ); - ensure!( - label.len() <= 63, - miette!( - labels = [LabeledSpan::new( - Some("label".to_string()), - byte_offset + local_offset, - label.len(), - )], - help = "labels should be at most 63 octets", - url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", - "label too long ({} octets)", - label.len(), - ) - .with_source_code(source()) - ); - for (offset, octet) in label.bytes().enumerate() { - ensure!( - octet.is_ascii(), - miette!( - labels = [LabeledSpan::new( - Some("octet".to_string()), - byte_offset + local_offset + offset, - 1, - )], - help = "we only accept ascii characters", - url = "https://en.wikipedia.org/wiki/Hostname#Syntax", - "'{}' is not ascii", - octet.escape_ascii(), - ) - .with_source_code(source()) - ); - ensure!( - octet.is_ascii_alphanumeric() || octet == b'-' || octet == b'_', - miette!( - labels = [LabeledSpan::new( - Some("octet".to_string()), - byte_offset + local_offset + offset, - 1, - )], - help = "hostnames are only allowed to contain characters in [a-zA-Z0-9_-]", - url = "https://en.wikipedia.org/wiki/Hostname#Syntax", - "invalid octet: '{}'", - octet.escape_ascii(), - ) - .with_source_code(source()) - ); - } - local_offset += label.len() + 1; - } - byte_offset += line.len() + 1; - } - Ok(()) -} - #[cfg(test)] mod test { use insta::assert_snapshot; - use crate::{update_ns_records, verify_records, DEFAULT_TTL}; + use crate::{update_ns_records, DEFAULT_TTL}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -476,107 +429,4 @@ mod test { quit "###); } - - #[test] - fn valid_records() -> miette::Result<()> { - verify_records( - "\ - example.com.\n\ - example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_valid"), - ) - } - - #[test] - fn hostname_too_long() { - let err = verify_records( - "\ - example.com.\n\ - example.org.\n\ - example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @"hostname too long (260 octets)"); - } - - #[test] - fn not_fqd() { - let err = verify_records( - "\ - example.com.\n\ - example.org.\n\ - example.net\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @"not a fully qualified domain name"); - } - - #[test] - fn empty_label() { - let err = verify_records( - "\ - example.com.\n\ - name..example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @"empty label"); - } - - #[test] - fn label_too_long() { - let err = verify_records( - "\ - example.com.\n\ - name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @"label too long (78 octets)"); - } - - #[test] - fn invalid_ascii() { - let err = verify_records( - "\ - example.com.\n\ - name.this-is-not-aßcii.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @r###"'\xc3' is not ascii"###); - } - - #[test] - fn invalid_octet() { - let err = verify_records( - "\ - example.com.\n\ - name.this-character:-is-not-allowed.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); - assert_snapshot!(err, @"invalid octet: ':'"); - } } diff --git a/src/password.rs b/src/password.rs new file mode 100644 index 0000000..84574a2 --- /dev/null +++ b/src/password.rs @@ -0,0 +1,60 @@ +//! Make a password for use with webnsupdate +//! +//! You should call this command an give it's output to the app/script that will update the DNS +//! records +use std::io::Write; +use std::os::unix::fs::OpenOptionsExt; +use std::path::Path; + +use base64::prelude::*; +use miette::{Context, IntoDiagnostic, Result}; +use ring::digest::Digest; + +/// Create a password file +/// +/// If `--password-file` is provided, the password is written to that file +#[derive(Debug, clap::Args)] +pub struct Mkpasswd { + /// The username + username: String, + + /// The password + password: String, +} + +impl Mkpasswd { + pub fn process(self, args: &crate::Opts) -> Result<()> { + mkpasswd(self, args.password_file.as_deref(), &args.salt) + } +} + +pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest { + let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1); + write!(data, "{username}:{password}{salt}").unwrap(); + ring::digest::digest(&ring::digest::SHA256, &data) +} + +pub fn mkpasswd( + Mkpasswd { username, password }: Mkpasswd, + password_file: Option<&Path>, + salt: &str, +) -> miette::Result<()> { + let hash = hash_identity(&username, &password, salt); + let encoded = BASE64_URL_SAFE_NO_PAD.encode(hash.as_ref()); + let Some(path) = password_file else { + println!("{encoded}"); + return Ok(()); + }; + let err = || format!("trying to save password hash to {}", path.display()); + std::fs::File::options() + .mode(0o600) + .create_new(true) + .open(path) + .into_diagnostic() + .wrap_err_with(err)? + .write_all(encoded.as_bytes()) + .into_diagnostic() + .wrap_err_with(err)?; + + Ok(()) +} diff --git a/src/records.rs b/src/records.rs new file mode 100644 index 0000000..5ca6528 --- /dev/null +++ b/src/records.rs @@ -0,0 +1,283 @@ +//! Deal with the DNS records + +use std::path::Path; + +use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result}; + +/// Loads and verifies the records from a file +pub fn load(path: &Path) -> Result<()> { + let records = std::fs::read_to_string(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read records from {}", path.display()))?; + + verify(&records, path)?; + + Ok(()) +} + +/// Load records without verifying them +pub fn load_no_verify(path: &Path) -> Result<&'static [&'static str]> { + let records = std::fs::read_to_string(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed to read records from {}", path.display()))?; + + if let Err(err) = verify(&records, path) { + tracing::error!("Failed to verify records: {err}"); + } + + // leak memory: we only do this here and it prevents a bunch of allocations + let records: &str = records.leak(); + let records: Box<[&str]> = records.lines().collect(); + + Ok(Box::leak(records)) +} + +/// Verifies that a list of records is valid +pub fn verify(data: &str, path: &Path) -> Result<()> { + let mut offset = 0usize; + for line in data.lines() { + validate_line(offset, line).map_err(|err| { + err.with_source_code(NamedSource::new( + path.display().to_string(), + data.to_string(), + )) + })?; + + offset += line.len() + 1; + } + + Ok(()) +} + +fn validate_line(offset: usize, line: &str) -> Result<()> { + if line.is_empty() { + return Ok(()); + } + + ensure!( + line.len() <= 255, + miette!( + labels = [LabeledSpan::new( + Some("this line".to_string()), + offset, + line.len(), + )], + help = "fully qualified domain names can be at most 255 characters long", + url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", + "hostname too long ({} octets)", + line.len(), + ) + ); + ensure!( + line.ends_with('.'), + miette!( + labels = [LabeledSpan::new( + Some("last character".to_string()), + offset + line.len() - 1, + 1, + )], + help = "hostname should be a fully qualified domain name (end with a '.')", + url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", + "not a fully qualified domain name" + ) + ); + + let mut label_offset = 0usize; + for label in line.strip_suffix('.').unwrap_or(line).split('.') { + validate_label(offset + label_offset, label)?; + label_offset += label.len() + 1; + } + + Ok(()) +} + +fn validate_label(offset: usize, label: &str) -> Result<()> { + ensure!( + !label.is_empty(), + miette!( + labels = [LabeledSpan::new( + Some("label".to_string()), + offset, + label.len(), + )], + help = "each label should have at least one character", + url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", + "empty label", + ) + ); + ensure!( + label.len() <= 63, + miette!( + labels = [LabeledSpan::new( + Some("label".to_string()), + offset, + label.len(), + )], + help = "labels should be at most 63 octets", + url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name", + "label too long ({} octets)", + label.len(), + ) + ); + + for (octet_offset, octet) in label.bytes().enumerate() { + validate_octet(offset + octet_offset, octet)?; + } + + Ok(()) +} + +fn validate_octet(offset: usize, octet: u8) -> Result<()> { + let spans = || [LabeledSpan::new(Some("octet".to_string()), offset, 1)]; + ensure!( + octet.is_ascii(), + miette!( + labels = spans(), + help = "we only accept ascii characters", + url = "https://en.wikipedia.org/wiki/Hostname#Syntax", + "invalid octet: '{}'", + octet.escape_ascii(), + ) + ); + + ensure!( + octet.is_ascii_alphanumeric() || octet == b'-' || octet == b'_', + miette!( + labels = spans(), + help = "hostnames are only allowed to contain characters in [a-zA-Z0-9_-]", + url = "https://en.wikipedia.org/wiki/Hostname#Syntax", + "invalid octet: '{}'", + octet.escape_ascii(), + ) + ); + + Ok(()) +} + +#[cfg(test)] +mod test { + use crate::records::verify; + + macro_rules! assert_miette_snapshot { + ($diag:expr) => {{ + use std::borrow::Borrow; + + use insta::{with_settings, assert_snapshot}; + use miette::{GraphicalReportHandler, GraphicalTheme}; + + let mut out = String::new(); + GraphicalReportHandler::new_themed(GraphicalTheme::unicode_nocolor()) + .with_width(80) + .render_report(&mut out, $diag.borrow()) + .unwrap(); + with_settings!({ + description => stringify!($diag) + }, { + assert_snapshot!(out); + }); + }}; + } + + #[test] + fn valid_records() -> miette::Result<()> { + verify( + "\ + example.com.\n\ + example.org.\n\ + example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_valid"), + ) + } + + #[test] + fn hostname_too_long() { + let err = verify( + "\ + example.com.\n\ + example.org.\n\ + example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } + + #[test] + fn not_fqd() { + let err = verify( + "\ + example.com.\n\ + example.org.\n\ + example.net\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } + + #[test] + fn empty_label() { + let err = verify( + "\ + example.com.\n\ + name..example.org.\n\ + example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } + + #[test] + fn label_too_long() { + let err = verify( + "\ + example.com.\n\ + name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.\n\ + example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } + + #[test] + fn invalid_ascii() { + let err = verify( + "\ + example.com.\n\ + name.this-is-not-aßcii.example.org.\n\ + example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } + + #[test] + fn invalid_octet() { + let err = verify( + "\ + example.com.\n\ + name.this-character:-is-not-allowed.example.org.\n\ + example.net.\n\ + subdomain.example.com.\n\ + ", + std::path::Path::new("test_records_invalid"), + ) + .unwrap_err(); + assert_miette_snapshot!(err); + } +} diff --git a/src/snapshots/webnsupdate__records__test__empty_label.snap b/src/snapshots/webnsupdate__records__test__empty_label.snap new file mode 100644 index 0000000..e4d227b --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__empty_label.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ + + × empty label + ╭─[test_records_invalid:2:6] + 1 │ example.com. + 2 │ name..example.org. + · ▲ + · ╰── label + 3 │ example.net. + ╰──── + help: each label should have at least one character diff --git a/src/snapshots/webnsupdate__records__test__hostname_too_long.snap b/src/snapshots/webnsupdate__records__test__hostname_too_long.snap new file mode 100644 index 0000000..051d8ce --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__hostname_too_long.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ + + × hostname too long (260 octets) + ╭─[test_records_invalid:3:1] + 2 │ example.org. + 3 │ example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net. + · ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + · ╰── this line + 4 │ subdomain.example.com. + ╰──── + help: fully qualified domain names can be at most 255 characters long diff --git a/src/snapshots/webnsupdate__records__test__invalid_ascii.snap b/src/snapshots/webnsupdate__records__test__invalid_ascii.snap new file mode 100644 index 0000000..a777b6d --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__invalid_ascii.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\ + + × invalid octet: '\xc3' + ╭─[test_records_invalid:2:19] + 1 │ example.com. + 2 │ name.this-is-not-aßcii.example.org. + · ┬ + · ╰── octet + 3 │ example.net. + ╰──── + help: we only accept ascii characters diff --git a/src/snapshots/webnsupdate__records__test__invalid_octet.snap b/src/snapshots/webnsupdate__records__test__invalid_octet.snap new file mode 100644 index 0000000..2da284e --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__invalid_octet.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\ + + × invalid octet: ':' + ╭─[test_records_invalid:2:20] + 1 │ example.com. + 2 │ name.this-character:-is-not-allowed.example.org. + · ┬ + · ╰── octet + 3 │ example.net. + ╰──── + help: hostnames are only allowed to contain characters in [a-zA-Z0-9_-] diff --git a/src/snapshots/webnsupdate__records__test__label_too_long.snap b/src/snapshots/webnsupdate__records__test__label_too_long.snap new file mode 100644 index 0000000..b529a1f --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__label_too_long.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ + + × label too long (78 octets) + ╭─[test_records_invalid:2:6] + 1 │ example.com. + 2 │ name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org. + · ───────────────────────────────────────┬────────────────────────────────────── + · ╰── label + 3 │ example.net. + ╰──── + help: labels should be at most 63 octets diff --git a/src/snapshots/webnsupdate__records__test__not_fqd.snap b/src/snapshots/webnsupdate__records__test__not_fqd.snap new file mode 100644 index 0000000..bc8270d --- /dev/null +++ b/src/snapshots/webnsupdate__records__test__not_fqd.snap @@ -0,0 +1,16 @@ +--- +source: src/records.rs +description: err +expression: out +--- +]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ + + × not a fully qualified domain name + ╭─[test_records_invalid:3:11] + 2 │ example.org. + 3 │ example.net + · ┬ + · ╰── last character + 4 │ subdomain.example.com. + ╰──── + help: hostname should be a fully qualified domain name (end with a '.')