diff --git a/Cargo.toml b/Cargo.toml index f435dbf..dd4bde5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,10 +27,10 @@ clap-verbosity-flag = { version = "3", default-features = false, features = [ http = "1" miette = { version = "7", features = ["fancy"] } ring = { version = "0.17", features = ["std"] } -serde = { version = "1.0.217", features = ["derive"] } -serde_json = "1.0.137" +serde = { version = "1", features = ["derive"] } +serde_json = "1" tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } -tower-http = { version = "0.6.2", features = ["validate-request"] } +tower-http = { version = "0.6", features = ["validate-request"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/main.rs b/src/main.rs index bea4268..7718175 100644 --- a/src/main.rs +++ b/src/main.rs @@ -366,25 +366,25 @@ fn main() -> Result<()> { rt.block_on(async { // Update DNS record with previous IPs (if available) let ips = state.last_ips.lock().await.clone(); - for ip in ips.ips() { - if !ip_type.valid_for_type(ip) { - continue; - } - match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); - } - } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); + let actions = ips + .ips() + .filter(|ip| ip_type.valid_for_type(*ip)) + .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)); + + match nsupdate::nsupdate(state.key_file, actions).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); } } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); + } } // Create services @@ -541,7 +541,8 @@ async fn trigger_update( ip: IpAddr, state: &AppState<'static>, ) -> axum::response::Result<&'static str> { - match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { + let actions = nsupdate::Action::from_records(ip, state.ttl, state.records); + match nsupdate::nsupdate(state.key_file, actions).await { Ok(status) if status.success() => { let ips = { // Update state diff --git a/src/nsupdate.rs b/src/nsupdate.rs index 62395b7..74397fa 100644 --- a/src/nsupdate.rs +++ b/src/nsupdate.rs @@ -9,12 +9,51 @@ use std::{ use tokio::io::AsyncWriteExt; use tracing::{debug, warn}; -#[tracing::instrument(level = "trace", ret(level = "warn"))] +pub enum Action<'a> { + // Reassign a domain to a different IP + Reassign { + domain: &'a str, + to: IpAddr, + ttl: Duration, + }, +} + +impl<'a> Action<'a> { + /// Create a set of [`Action`]s reassigning the domains in `records` to the specified + /// [`IpAddr`] + pub fn from_records( + to: IpAddr, + ttl: Duration, + records: &'a [&'a str], + ) -> impl IntoIterator + 'a { + records + .iter() + .map(move |&domain| Action::Reassign { domain, to, ttl }) + } +} + +impl std::fmt::Display for Action<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Action::Reassign { domain, to, ttl } => { + let ttl = ttl.as_secs(); + let typ = match to { + IpAddr::V4(_) => "A", + IpAddr::V6(_) => "AAAA", + }; + // Delete previous record of type `typ` + writeln!(f, "update delete {domain} {ttl} IN {typ}")?; + // Add record with new IP + writeln!(f, "update add {domain} {ttl} IN {typ} {to}") + } + } + } +} + +#[tracing::instrument(level = "trace", skip(actions), ret(level = "warn"))] pub async fn nsupdate( - ip: IpAddr, - ttl: Duration, key_file: Option<&Path>, - records: &[&str], + actions: impl IntoIterator>, ) -> std::io::Result { let mut cmd = tokio::process::Command::new("nsupdate"); if let Some(key_file) = key_file { @@ -27,10 +66,13 @@ pub async fn nsupdate( .inspect_err(|err| warn!("failed to spawn child: {err}"))?; let mut stdin = child.stdin.take().expect("stdin not present"); debug!("sending update request"); + let mut buf = Vec::new(); + update_ns_records(&mut buf, actions).unwrap(); stdin - .write_all(update_ns_records(ip, ttl, records).as_bytes()) + .write_all(&buf) .await .inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?; + debug!("closing stdin"); stdin .shutdown() @@ -43,21 +85,16 @@ pub async fn nsupdate( .inspect_err(|err| warn!("failed to wait for child: {err}")) } -fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { - use std::fmt::Write; - let ttl_s: u64 = ttl.as_secs(); - - let rec_type = match ip { - IpAddr::V4(_) => "A", - IpAddr::V6(_) => "AAAA", - }; - let mut cmds = String::from("server 127.0.0.1\n"); - for &record in records { - writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap(); - writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap(); +fn update_ns_records<'a>( + mut buf: impl std::io::Write, + actions: impl IntoIterator>, +) -> std::io::Result<()> { + writeln!(buf, "server 127.0.0.1")?; + for action in actions { + writeln!(buf, "{action}")?; } - writeln!(cmds, "send\nquit").unwrap(); - cmds + writeln!(buf, "send")?; + writeln!(buf, "quit") } #[cfg(test)] @@ -66,17 +103,21 @@ mod test { use insta::assert_snapshot; - use super::update_ns_records; + use super::{update_ns_records, Action}; use crate::DEFAULT_TTL; #[test] #[allow(non_snake_case)] fn expected_update_string_A() { - assert_snapshot!(update_ns_records( - IpAddr::V4(Ipv4Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ), @r###" + let mut buf = Vec::new(); + let actions = Action::from_records( + IpAddr::V4(Ipv4Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ); + update_ns_records(&mut buf, actions).unwrap(); + + assert_snapshot!(String::from_utf8(buf).unwrap(), @r###" server 127.0.0.1 update delete example.com. 60 IN A update add example.com. 60 IN A 127.0.0.1 @@ -92,11 +133,15 @@ mod test { #[test] #[allow(non_snake_case)] fn expected_update_string_AAAA() { - assert_snapshot!(update_ns_records( - IpAddr::V6(Ipv6Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ), @r###" + let mut buf = Vec::new(); + let actions = Action::from_records( + IpAddr::V6(Ipv6Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ); + update_ns_records(&mut buf, actions).unwrap(); + + assert_snapshot!(String::from_utf8(buf).unwrap(), @r###" server 127.0.0.1 update delete example.com. 60 IN AAAA update add example.com. 60 IN AAAA ::1