feat: refactor and add ip saving #8

Merged
jalil merged 1 commit from save-last-ip into main 2024-10-13 00:59:30 +02:00
12 changed files with 646 additions and 353 deletions

View file

@ -1 +0,0 @@
config.toml

View file

@ -1,4 +1,5 @@
cargo-features = ["codegen-backend"] cargo-features = ["codegen-backend"]
[package] [package]
description = "An HTTP server using HTTP basic auth to make secure calls to nsupdate" description = "An HTTP server using HTTP basic auth to make secure calls to nsupdate"
name = "webnsupdate" name = "webnsupdate"
@ -7,24 +8,26 @@ edition = "2021"
[dependencies] [dependencies]
axum = "0.7.7" axum = "0.7.7"
axum-auth = { version = "0.7.0", default-features = false, features = [
"auth-basic",
] }
axum-client-ip = "0.6.1" axum-client-ip = "0.6.1"
base64 = "0.22.1" base64 = "0.22.1"
clap = { version = "4.5.20", features = ["derive", "env"] } clap = { version = "4.5.20", features = ["derive", "env"] }
http = "1.1.0" http = "1.1.0"
insta = "1.40.0"
miette = { version = "7.2.0", features = ["fancy"] } miette = { version = "7.2.0", features = ["fancy"] }
ring = { version = "0.17.8", features = ["std"] } ring = { version = "0.17.8", features = ["std"] }
tokio = { version = "1.40.0", features = [
"macros",
"rt",
"process",
"io-util",
] }
tracing = "0.1.40" tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
[dependencies.axum-auth] [dev-dependencies]
version = "0.7.0" insta = "1.40.0"
default-features = false
features = ["auth-basic"]
[dependencies.tokio]
version = "1.40.0"
features = ["macros", "rt", "process", "io-util"]
[profile.dev] [profile.dev]
debug = 0 debug = 0

View file

@ -19,6 +19,8 @@ let
".rs" ".rs"
# TOML files are often used to configure cargo based tools (e.g. .cargo/config.toml) # TOML files are often used to configure cargo based tools (e.g. .cargo/config.toml)
".toml" ".toml"
# Snapshot tests
".snap"
]; ];
isCargoLock = base == "Cargo.lock"; isCargoLock = base == "Cargo.lock";
in in

View file

@ -1,8 +1,7 @@
use std::{ use std::{
ffi::OsStr, ffi::OsStr,
io::Write, io::ErrorKind,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
os::unix::fs::OpenOptionsExt,
path::{Path, PathBuf}, path::{Path, PathBuf},
process::{ExitStatus, Stdio}, process::{ExitStatus, Stdio},
time::Duration, time::Duration,
@ -12,14 +11,16 @@ use axum::{extract::State, routing::get, Json, Router};
use axum_auth::AuthBasic; use axum_auth::AuthBasic;
use axum_client_ip::{SecureClientIp, SecureClientIpSource}; use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Args, Parser, Subcommand}; use clap::{Parser, Subcommand};
use http::StatusCode; use http::StatusCode;
use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result}; use miette::{bail, ensure, Context, IntoDiagnostic, Result};
use ring::digest::Digest;
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info, level_filters::LevelFilter, trace, warn}; use tracing::{debug, error, info, level_filters::LevelFilter, trace, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
mod password;
mod records;
const DEFAULT_TTL: Duration = Duration::from_secs(60); const DEFAULT_TTL: Duration = Duration::from_secs(60);
const DEFAULT_SALT: &str = "UpdateMyDNS"; const DEFAULT_SALT: &str = "UpdateMyDNS";
@ -28,20 +29,29 @@ struct Opts {
/// Ip address of the server /// Ip address of the server
#[arg(long, default_value = "127.0.0.1")] #[arg(long, default_value = "127.0.0.1")]
address: IpAddr, address: IpAddr,
/// Port of the server /// Port of the server
#[arg(long, default_value_t = 5353)] #[arg(long, default_value_t = 5353)]
port: u16, port: u16,
/// File containing password to match against /// File containing password to match against
/// ///
/// Should be of the format `username:password` and contain a single password /// Should be of the format `username:password` and contain a single password
#[arg(long)] #[arg(long)]
password_file: Option<PathBuf>, password_file: Option<PathBuf>,
/// Salt to get more unique hashed passwords and prevent table based attacks /// Salt to get more unique hashed passwords and prevent table based attacks
#[arg(long, default_value = DEFAULT_SALT)] #[arg(long, default_value = DEFAULT_SALT)]
salt: String, salt: String,
/// Time To Live (in seconds) to set on the DNS records /// Time To Live (in seconds) to set on the DNS records
#[arg(long, default_value_t = DEFAULT_TTL.as_secs())] #[arg(long, default_value_t = DEFAULT_TTL.as_secs())]
ttl: u64, ttl: u64,
/// Data directory
#[arg(long, default_value = ".")]
data_dir: PathBuf,
/// File containing the records that should be updated when an update request is made /// File containing the records that should be updated when an update request is made
/// ///
/// There should be one record per line: /// There should be one record per line:
@ -52,70 +62,92 @@ struct Opts {
/// ``` /// ```
#[arg(long)] #[arg(long)]
records: PathBuf, records: PathBuf,
/// Keyfile `nsupdate` should use /// Keyfile `nsupdate` should use
/// ///
/// If specified, then `webnsupdate` must have read access to the file /// If specified, then `webnsupdate` must have read access to the file
#[arg(long)] #[arg(long)]
key_file: Option<PathBuf>, key_file: Option<PathBuf>,
/// Allow not setting a password /// Allow not setting a password
#[arg(long)] #[arg(long)]
insecure: bool, insecure: bool,
/// Set client IP source /// Set client IP source
/// ///
/// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html /// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html
#[clap(long, default_value = "RightmostXForwardedFor")] #[clap(long, default_value = "RightmostXForwardedFor")]
ip_source: SecureClientIpSource, ip_source: SecureClientIpSource,
#[clap(subcommand)] #[clap(subcommand)]
subcommand: Option<Cmd>, subcommand: Option<Cmd>,
} }
#[derive(Debug, Args)]
struct Mkpasswd {
/// The username
username: String,
/// The password
password: String,
}
#[derive(Debug, Subcommand)] #[derive(Debug, Subcommand)]
enum Cmd { enum Cmd {
/// Create a password file Mkpasswd(password::Mkpasswd),
///
/// If `--password-file` is provided, the password is written to that file
Mkpasswd(Mkpasswd),
/// Verify the records file /// Verify the records file
Verify, Verify,
} }
impl Cmd {
pub fn process(self, args: &Opts) -> Result<()> {
match self {
Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args),
Cmd::Verify => records::load(&args.records).map(drop),
}
}
}
#[derive(Clone)] #[derive(Clone)]
struct AppState<'a> { struct AppState<'a> {
/// TTL set on the Zonefile /// TTL set on the Zonefile
ttl: Duration, ttl: Duration,
/// Salt added to the password /// Salt added to the password
salt: &'a str, salt: &'a str,
/// The IN A/AAAA records that should have their IPs updated /// The IN A/AAAA records that should have their IPs updated
records: &'a [&'a str], records: &'a [&'a str],
/// The TSIG key file /// The TSIG key file
key_file: Option<&'a Path>, key_file: Option<&'a Path>,
/// The password hash /// The password hash
password_hash: Option<&'a [u8]>, password_hash: Option<&'a [u8]>,
/// The file where the last IP is stored
ip_file: &'a Path,
} }
#[tokio::main(flavor = "current_thread")] fn load_ip(path: &Path) -> Result<Option<IpAddr>> {
async fn main() -> Result<()> { let data = match std::fs::read_to_string(path) {
Ok(ip) => ip,
Err(err) => {
return match err.kind() {
ErrorKind::NotFound => Ok(None),
_ => Err(err).into_diagnostic().wrap_err_with(|| {
format!("failed to load last ip address from {}", path.display())
}),
}
}
};
Ok(Some(
data.parse()
.into_diagnostic()
.wrap_err("failed to parse last ip address")?,
))
}
fn main() -> Result<()> {
// set panic hook to pretty print with miette's formatter
miette::set_panic_hook(); miette::set_panic_hook();
let Opts {
address: ip, // parse cli arguments
port, let mut args = Opts::parse();
password_file,
key_file, // configure logger
insecure,
subcommand,
records,
salt,
ttl,
ip_source,
} = Opts::parse();
let subscriber = tracing_subscriber::FmtSubscriber::builder() let subscriber = tracing_subscriber::FmtSubscriber::builder()
.without_time() .without_time()
.with_env_filter( .with_env_filter(
@ -127,58 +159,109 @@ async fn main() -> Result<()> {
tracing::subscriber::set_global_default(subscriber) tracing::subscriber::set_global_default(subscriber)
.into_diagnostic() .into_diagnostic()
.wrap_err("setting global tracing subscriber")?; .wrap_err("setting global tracing subscriber")?;
match subcommand {
Some(Cmd::Mkpasswd(args)) => return mkpasswd(args, password_file.as_deref(), &salt), // process subcommand
Some(Cmd::Verify) => { if let Some(cmd) = args.subcommand.take() {
let data = std::fs::read_to_string(&records) return cmd.process(&args);
.into_diagnostic()
.wrap_err_with(|| format!("trying to read {}", records.display()))?;
return verify_records(&data, &records);
}
None => {}
} }
let Opts {
address: ip,
port,
password_file,
data_dir,
key_file,
insecure,
subcommand: _,
records,
salt,
ttl,
ip_source,
} = args;
info!("checking environment"); info!("checking environment");
// Set state // Set state
let ttl = Duration::from_secs(ttl); let ttl = Duration::from_secs(ttl);
let mut state = AppState {
// Use last registered IP address if available
let ip_file = data_dir.join("last-ip");
let state = AppState {
ttl, ttl,
salt: salt.leak(), salt: salt.leak(),
records: &[], // Load DNS records
key_file: None, records: records::load_no_verify(&records)?,
password_hash: None, // Load keyfile
}; key_file: key_file
if let Some(path) = password_file { .map(|key_file| -> miette::Result<_> {
let pass = std::fs::read_to_string(&path).into_diagnostic()?; let path = key_file.as_path();
std::fs::File::open(path)
.into_diagnostic()
.wrap_err_with(|| {
format!("{} is not readable by the current user", path.display())
})?;
Ok(&*Box::leak(key_file.into_boxed_path()))
})
.transpose()?,
// Load password hash
password_hash: password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes()) .decode(pass.trim().as_bytes())
.into_diagnostic() .into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))? .wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into(); .into();
state.password_hash = Some(Box::leak(pass));
} else { Ok(&*Box::leak(pass))
ensure!(insecure, "a password must be used"); })
} .transpose()?,
if let Some(key_file) = key_file { ip_file: Box::leak(ip_file.into_boxed_path()),
let path = key_file.as_path(); };
std::fs::File::open(path)
ensure!(
state.password_hash.is_some() || insecure,
"a password must be used"
);
ensure!(
state.key_file.is_some() || insecure,
"a key file must be used"
);
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.into_diagnostic() .into_diagnostic()
.wrap_err_with(|| format!("{} is not readable by the current user", path.display()))?; .wrap_err("failed to start the tokio runtime")?;
state.key_file = Some(Box::leak(key_file.into_boxed_path()));
} else { rt.block_on(async {
ensure!(insecure, "a key file must be used"); // Load previous IP and update DNS record to point to it (if available)
match load_ip(state.ip_file) {
Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await {
Ok(status) => {
if !status.success() {
error!("nsupdate failed: code {status}");
bail!("nsupdate returned with code {status}");
} }
let data = std::fs::read_to_string(&records) }
Err(err) => {
error!("Failed to update records with previous IP: {err}");
return Err(err)
.into_diagnostic() .into_diagnostic()
.wrap_err_with(|| format!("loading records from {}", records.display()))?; .wrap_err("failed to update records with previous IP");
if let Err(err) = verify_records(&data, &records) {
warn!("invalid records found: {err}");
} }
state.records = data },
.lines() Ok(None) => {
.map(|s| &*s.to_string().leak()) info!("No previous IP address set");
.collect::<Vec<&'static str>>() }
.leak(); Err(err) => {
error!("Failed to load last ip address: {err}")
}
};
// Start services // Start services
let app = Router::new() let app = Router::new()
.route("/update", get(update_records)) .route("/update", get(update_records))
@ -195,6 +278,7 @@ async fn main() -> Result<()> {
) )
.await .await
.into_diagnostic() .into_diagnostic()
})
} }
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))] #[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
@ -206,33 +290,39 @@ async fn update_records(
let Some(pass) = pass else { let Some(pass) = pass else {
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into()); return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
}; };
if let Some(stored_pass) = state.password_hash { if let Some(stored_pass) = state.password_hash {
let password = pass.trim().to_string(); let password = pass.trim().to_string();
let pass_hash = hash_identity(&username, &password, state.salt); let pass_hash = password::hash_identity(&username, &password, state.salt);
if pass_hash.as_ref() != stored_pass { if pass_hash.as_ref() != stored_pass {
warn!("rejected update"); warn!("rejected update");
trace!( trace!(
"mismatched hashes:\n{}\n{}", "mismatched hashes:\n{}\n{}",
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()), URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
URL_SAFE_NO_PAD.encode(stored_pass.as_ref()), URL_SAFE_NO_PAD.encode(stored_pass),
); );
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into()); return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
} }
} }
info!("accepted update"); info!("accepted update");
match nsupdate(ip, state.ttl, state.key_file, state.records).await { match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) => { Ok(status) if status.success() => {
if status.success() { tokio::task::spawn_blocking(move || {
if let Err(err) = std::fs::write(state.ip_file, format!("{ip}")) {
error!("Failed to update last IP: {err}");
}
});
Ok("successful update") Ok("successful update")
} else { }
error!("nsupdate failed"); Ok(status) => {
error!("nsupdate failed with code {status}");
Err(( Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
"nsupdate failed, check server logs", "nsupdate failed, check server logs",
) )
.into()) .into())
} }
}
Err(error) => Err(( Err(error) => Err((
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to update records: {error}"), format!("failed to update records: {error}"),
@ -292,148 +382,11 @@ fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
cmds cmds
} }
fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1);
write!(data, "{username}:{password}{salt}").unwrap();
ring::digest::digest(&ring::digest::SHA256, &data)
}
fn mkpasswd(
Mkpasswd { username, password }: Mkpasswd,
password_file: Option<&Path>,
salt: &str,
) -> miette::Result<()> {
let hash = hash_identity(&username, &password, salt);
let encoded = URL_SAFE_NO_PAD.encode(hash.as_ref());
let Some(path) = password_file else {
println!("{encoded}");
return Ok(());
};
let err = || format!("trying to save password hash to {}", path.display());
std::fs::File::options()
.mode(0o600)
.create_new(true)
.open(path)
.into_diagnostic()
.wrap_err_with(err)?
.write_all(encoded.as_bytes())
.into_diagnostic()
.wrap_err_with(err)?;
Ok(())
}
fn verify_records(data: &str, path: &Path) -> miette::Result<()> {
let source = || NamedSource::new(path.display().to_string(), data.to_string());
let mut byte_offset = 0usize;
for line in data.lines() {
if line.is_empty() {
continue;
}
ensure!(
line.len() <= 255,
miette!(
labels = [LabeledSpan::new(
Some("this line".to_string()),
byte_offset,
line.len(),
)],
help = "fully qualified domain names can be at most 255 characters long",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"hostname too long ({} octets)",
line.len(),
)
.with_source_code(source())
);
ensure!(
line.ends_with('.'),
miette!(
labels = [LabeledSpan::new(
Some("last character".to_string()),
byte_offset + line.len() - 1,
1,
)],
help = "hostname should be a fully qualified domain name (end with a '.')",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"not a fully qualified domain name"
)
.with_source_code(source())
);
let mut local_offset = 0usize;
for label in line.strip_suffix('.').unwrap_or(line).split('.') {
ensure!(
!label.is_empty(),
miette!(
labels = [LabeledSpan::new(
Some("label".to_string()),
byte_offset + local_offset,
label.len(),
)],
help = "each label should have at least one character",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"empty label",
)
.with_source_code(source())
);
ensure!(
label.len() <= 63,
miette!(
labels = [LabeledSpan::new(
Some("label".to_string()),
byte_offset + local_offset,
label.len(),
)],
help = "labels should be at most 63 octets",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"label too long ({} octets)",
label.len(),
)
.with_source_code(source())
);
for (offset, octet) in label.bytes().enumerate() {
ensure!(
octet.is_ascii(),
miette!(
labels = [LabeledSpan::new(
Some("octet".to_string()),
byte_offset + local_offset + offset,
1,
)],
help = "we only accept ascii characters",
url = "https://en.wikipedia.org/wiki/Hostname#Syntax",
"'{}' is not ascii",
octet.escape_ascii(),
)
.with_source_code(source())
);
ensure!(
octet.is_ascii_alphanumeric() || octet == b'-' || octet == b'_',
miette!(
labels = [LabeledSpan::new(
Some("octet".to_string()),
byte_offset + local_offset + offset,
1,
)],
help = "hostnames are only allowed to contain characters in [a-zA-Z0-9_-]",
url = "https://en.wikipedia.org/wiki/Hostname#Syntax",
"invalid octet: '{}'",
octet.escape_ascii(),
)
.with_source_code(source())
);
}
local_offset += label.len() + 1;
}
byte_offset += line.len() + 1;
}
Ok(())
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use insta::assert_snapshot; use insta::assert_snapshot;
use crate::{update_ns_records, verify_records, DEFAULT_TTL}; use crate::{update_ns_records, DEFAULT_TTL};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
@ -476,107 +429,4 @@ mod test {
quit quit
"###); "###);
} }
#[test]
fn valid_records() -> miette::Result<()> {
verify_records(
"\
example.com.\n\
example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_valid"),
)
}
#[test]
fn hostname_too_long() {
let err = verify_records(
"\
example.com.\n\
example.org.\n\
example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @"hostname too long (260 octets)");
}
#[test]
fn not_fqd() {
let err = verify_records(
"\
example.com.\n\
example.org.\n\
example.net\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @"not a fully qualified domain name");
}
#[test]
fn empty_label() {
let err = verify_records(
"\
example.com.\n\
name..example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @"empty label");
}
#[test]
fn label_too_long() {
let err = verify_records(
"\
example.com.\n\
name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @"label too long (78 octets)");
}
#[test]
fn invalid_ascii() {
let err = verify_records(
"\
example.com.\n\
name.this-is-not-aßcii.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @r###"'\xc3' is not ascii"###);
}
#[test]
fn invalid_octet() {
let err = verify_records(
"\
example.com.\n\
name.this-character:-is-not-allowed.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_snapshot!(err, @"invalid octet: ':'");
}
} }

60
src/password.rs Normal file
View file

@ -0,0 +1,60 @@
//! Make a password for use with webnsupdate
//!
//! You should call this command an give it's output to the app/script that will update the DNS
//! records
use std::io::Write;
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
use base64::prelude::*;
use miette::{Context, IntoDiagnostic, Result};
use ring::digest::Digest;
/// Create a password file
///
/// If `--password-file` is provided, the password is written to that file
#[derive(Debug, clap::Args)]
pub struct Mkpasswd {
/// The username
username: String,
/// The password
password: String,
}
impl Mkpasswd {
pub fn process(self, args: &crate::Opts) -> Result<()> {
mkpasswd(self, args.password_file.as_deref(), &args.salt)
}
}
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1);
write!(data, "{username}:{password}{salt}").unwrap();
ring::digest::digest(&ring::digest::SHA256, &data)
}
pub fn mkpasswd(
Mkpasswd { username, password }: Mkpasswd,
password_file: Option<&Path>,
salt: &str,
) -> miette::Result<()> {
let hash = hash_identity(&username, &password, salt);
let encoded = BASE64_URL_SAFE_NO_PAD.encode(hash.as_ref());
let Some(path) = password_file else {
println!("{encoded}");
return Ok(());
};
let err = || format!("trying to save password hash to {}", path.display());
std::fs::File::options()
.mode(0o600)
.create_new(true)
.open(path)
.into_diagnostic()
.wrap_err_with(err)?
.write_all(encoded.as_bytes())
.into_diagnostic()
.wrap_err_with(err)?;
Ok(())
}

283
src/records.rs Normal file
View file

@ -0,0 +1,283 @@
//! Deal with the DNS records
use std::path::Path;
use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result};
/// Loads and verifies the records from a file
pub fn load(path: &Path) -> Result<()> {
let records = std::fs::read_to_string(path)
.into_diagnostic()
.wrap_err_with(|| format!("failed to read records from {}", path.display()))?;
verify(&records, path)?;
Ok(())
}
/// Load records without verifying them
pub fn load_no_verify(path: &Path) -> Result<&'static [&'static str]> {
let records = std::fs::read_to_string(path)
.into_diagnostic()
.wrap_err_with(|| format!("failed to read records from {}", path.display()))?;
if let Err(err) = verify(&records, path) {
tracing::error!("Failed to verify records: {err}");
}
// leak memory: we only do this here and it prevents a bunch of allocations
let records: &str = records.leak();
let records: Box<[&str]> = records.lines().collect();
Ok(Box::leak(records))
}
/// Verifies that a list of records is valid
pub fn verify(data: &str, path: &Path) -> Result<()> {
let mut offset = 0usize;
for line in data.lines() {
validate_line(offset, line).map_err(|err| {
err.with_source_code(NamedSource::new(
path.display().to_string(),
data.to_string(),
))
})?;
offset += line.len() + 1;
}
Ok(())
}
fn validate_line(offset: usize, line: &str) -> Result<()> {
if line.is_empty() {
return Ok(());
}
ensure!(
line.len() <= 255,
miette!(
labels = [LabeledSpan::new(
Some("this line".to_string()),
offset,
line.len(),
)],
help = "fully qualified domain names can be at most 255 characters long",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"hostname too long ({} octets)",
line.len(),
)
);
ensure!(
line.ends_with('.'),
miette!(
labels = [LabeledSpan::new(
Some("last character".to_string()),
offset + line.len() - 1,
1,
)],
help = "hostname should be a fully qualified domain name (end with a '.')",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"not a fully qualified domain name"
)
);
let mut label_offset = 0usize;
for label in line.strip_suffix('.').unwrap_or(line).split('.') {
validate_label(offset + label_offset, label)?;
label_offset += label.len() + 1;
}
Ok(())
}
fn validate_label(offset: usize, label: &str) -> Result<()> {
ensure!(
!label.is_empty(),
miette!(
labels = [LabeledSpan::new(
Some("label".to_string()),
offset,
label.len(),
)],
help = "each label should have at least one character",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"empty label",
)
);
ensure!(
label.len() <= 63,
miette!(
labels = [LabeledSpan::new(
Some("label".to_string()),
offset,
label.len(),
)],
help = "labels should be at most 63 octets",
url = "https://en.wikipedia.org/wiki/Fully_qualified_domain_name",
"label too long ({} octets)",
label.len(),
)
);
for (octet_offset, octet) in label.bytes().enumerate() {
validate_octet(offset + octet_offset, octet)?;
}
Ok(())
}
fn validate_octet(offset: usize, octet: u8) -> Result<()> {
let spans = || [LabeledSpan::new(Some("octet".to_string()), offset, 1)];
ensure!(
octet.is_ascii(),
miette!(
labels = spans(),
help = "we only accept ascii characters",
url = "https://en.wikipedia.org/wiki/Hostname#Syntax",
"invalid octet: '{}'",
octet.escape_ascii(),
)
);
ensure!(
octet.is_ascii_alphanumeric() || octet == b'-' || octet == b'_',
miette!(
labels = spans(),
help = "hostnames are only allowed to contain characters in [a-zA-Z0-9_-]",
url = "https://en.wikipedia.org/wiki/Hostname#Syntax",
"invalid octet: '{}'",
octet.escape_ascii(),
)
);
Ok(())
}
#[cfg(test)]
mod test {
use crate::records::verify;
macro_rules! assert_miette_snapshot {
($diag:expr) => {{
use std::borrow::Borrow;
use insta::{with_settings, assert_snapshot};
use miette::{GraphicalReportHandler, GraphicalTheme};
let mut out = String::new();
GraphicalReportHandler::new_themed(GraphicalTheme::unicode_nocolor())
.with_width(80)
.render_report(&mut out, $diag.borrow())
.unwrap();
with_settings!({
description => stringify!($diag)
}, {
assert_snapshot!(out);
});
}};
}
#[test]
fn valid_records() -> miette::Result<()> {
verify(
"\
example.com.\n\
example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_valid"),
)
}
#[test]
fn hostname_too_long() {
let err = verify(
"\
example.com.\n\
example.org.\n\
example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
#[test]
fn not_fqd() {
let err = verify(
"\
example.com.\n\
example.org.\n\
example.net\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
#[test]
fn empty_label() {
let err = verify(
"\
example.com.\n\
name..example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
#[test]
fn label_too_long() {
let err = verify(
"\
example.com.\n\
name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
#[test]
fn invalid_ascii() {
let err = verify(
"\
example.com.\n\
name.this-is-not-aßcii.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
#[test]
fn invalid_octet() {
let err = verify(
"\
example.com.\n\
name.this-character:-is-not-allowed.example.org.\n\
example.net.\n\
subdomain.example.com.\n\
",
std::path::Path::new("test_records_invalid"),
)
.unwrap_err();
assert_miette_snapshot!(err);
}
}

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\
× empty label
╭─[test_records_invalid:2:6]
1 │ example.com.
2 │ name..example.org.
· ▲
· ╰── label
3 │ example.net.
╰────
help: each label should have at least one character

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\
× hostname too long (260 octets)
╭─[test_records_invalid:3:1]
2 │ example.org.
3 │ example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.
· ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
· ╰── this line
4 │ subdomain.example.com.
╰────
help: fully qualified domain names can be at most 255 characters long

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\
× invalid octet: '\xc3'
╭─[test_records_invalid:2:19]
1 │ example.com.
2 │ name.this-is-not-aßcii.example.org.
· ┬
· ╰── octet
3 │ example.net.
╰────
help: we only accept ascii characters

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\
× invalid octet: ':'
╭─[test_records_invalid:2:20]
1 │ example.com.
2 │ name.this-character:-is-not-allowed.example.org.
· ┬
· ╰── octet
3 │ example.net.
╰────
help: hostnames are only allowed to contain characters in [a-zA-Z0-9_-]

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\
× label too long (78 octets)
╭─[test_records_invalid:2:6]
1 │ example.com.
2 │ name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.
· ───────────────────────────────────────┬──────────────────────────────────────
· ╰── label
3 │ example.net.
╰────
help: labels should be at most 63 octets

View file

@ -0,0 +1,16 @@
---
source: src/records.rs
description: err
expression: out
---
]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\
× not a fully qualified domain name
╭─[test_records_invalid:3:11]
2 │ example.org.
3 │ example.net
· ┬
· ╰── last character
4 │ subdomain.example.com.
╰────
help: hostname should be a fully qualified domain name (end with a '.')