Compare commits

...

2 commits

Author SHA1 Message Date
8c8ca6cd60
chore(deps): lock file maintenance
All checks were successful
/ build (push) Successful in 3s
/ check (clippy) (push) Successful in 2s
/ check (module-ipv4-only-test) (push) Successful in 7s
/ check (module-ipv4-test) (push) Successful in 7s
/ check (module-ipv6-only-test) (push) Successful in 7s
/ check (module-ipv6-test) (push) Successful in 7s
/ check (module-nginx-test) (push) Successful in 7s
/ check (nextest) (push) Successful in 2s
/ check (treefmt) (push) Successful in 2s
/ report-size (push) Successful in 7s
2025-02-05 23:20:42 +01:00
b775f8e811
refactor(nsupdate): send all commands at once
All checks were successful
/ build (push) Successful in 1s
/ check (clippy) (push) Successful in 2s
/ check (module-ipv4-only-test) (push) Successful in 6s
/ check (module-ipv4-test) (push) Successful in 6s
/ check (module-ipv6-only-test) (push) Successful in 6s
/ check (module-ipv6-test) (push) Successful in 6s
/ check (module-nginx-test) (push) Successful in 6s
/ check (nextest) (push) Successful in 3s
/ check (treefmt) (push) Successful in 2s
/ report-size (push) Successful in 2s
This ensures `nsupdate` is only called once per IP update (even for both
IPv4 and IPv6 in a single call).
2025-02-05 22:47:13 +01:00
4 changed files with 99 additions and 53 deletions

View file

@ -27,10 +27,10 @@ clap-verbosity-flag = { version = "3", default-features = false, features = [
http = "1" http = "1"
miette = { version = "7", features = ["fancy"] } miette = { version = "7", features = ["fancy"] }
ring = { version = "0.17", features = ["std"] } ring = { version = "0.17", features = ["std"] }
serde = { version = "1.0.217", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1.0.137" serde_json = "1"
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
tower-http = { version = "0.6.2", features = ["validate-request"] } tower-http = { version = "0.6", features = ["validate-request"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }

6
flake.lock generated
View file

@ -37,11 +37,11 @@
}, },
"nixpkgs": { "nixpkgs": {
"locked": { "locked": {
"lastModified": 1738546358, "lastModified": 1738680400,
"narHash": "sha256-nLivjIygCiqLp5QcL7l56Tca/elVqM9FG1hGd9ZSsrg=", "narHash": "sha256-ooLh+XW8jfa+91F1nhf9OF7qhuA/y1ChLx6lXDNeY5U=",
"owner": "NixOS", "owner": "NixOS",
"repo": "nixpkgs", "repo": "nixpkgs",
"rev": "c6e957d81b96751a3d5967a0fd73694f303cc914", "rev": "799ba5bffed04ced7067a91798353d360788b30d",
"type": "github" "type": "github"
}, },
"original": { "original": {

View file

@ -366,12 +366,13 @@ fn main() -> Result<()> {
rt.block_on(async { rt.block_on(async {
// Update DNS record with previous IPs (if available) // Update DNS record with previous IPs (if available)
let ips = state.last_ips.lock().await.clone(); let ips = state.last_ips.lock().await.clone();
for ip in ips.ips() {
if !ip_type.valid_for_type(ip) {
continue;
}
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { let actions = ips
.ips()
.filter(|ip| ip_type.valid_for_type(*ip))
.flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records));
match nsupdate::nsupdate(state.key_file, actions).await {
Ok(status) => { Ok(status) => {
if !status.success() { if !status.success() {
error!("nsupdate failed: code {status}"); error!("nsupdate failed: code {status}");
@ -385,7 +386,6 @@ fn main() -> Result<()> {
.wrap_err("failed to update records with previous IP"); .wrap_err("failed to update records with previous IP");
} }
} }
}
// Create services // Create services
let app = Router::new().route("/update", get(update_records)); let app = Router::new().route("/update", get(update_records));
@ -541,7 +541,8 @@ async fn trigger_update(
ip: IpAddr, ip: IpAddr,
state: &AppState<'static>, state: &AppState<'static>,
) -> axum::response::Result<&'static str> { ) -> axum::response::Result<&'static str> {
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { let actions = nsupdate::Action::from_records(ip, state.ttl, state.records);
match nsupdate::nsupdate(state.key_file, actions).await {
Ok(status) if status.success() => { Ok(status) if status.success() => {
let ips = { let ips = {
// Update state // Update state

View file

@ -9,12 +9,51 @@ use std::{
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tracing::{debug, warn}; use tracing::{debug, warn};
#[tracing::instrument(level = "trace", ret(level = "warn"))] pub enum Action<'a> {
pub async fn nsupdate( // Reassign a domain to a different IP
ip: IpAddr, Reassign {
domain: &'a str,
to: IpAddr,
ttl: Duration, ttl: Duration,
},
}
impl<'a> Action<'a> {
/// Create a set of [`Action`]s reassigning the domains in `records` to the specified
/// [`IpAddr`]
pub fn from_records(
to: IpAddr,
ttl: Duration,
records: &'a [&'a str],
) -> impl IntoIterator<Item = Self> + 'a {
records
.iter()
.map(move |&domain| Action::Reassign { domain, to, ttl })
}
}
impl std::fmt::Display for Action<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Action::Reassign { domain, to, ttl } => {
let ttl = ttl.as_secs();
let typ = match to {
IpAddr::V4(_) => "A",
IpAddr::V6(_) => "AAAA",
};
// Delete previous record of type `typ`
writeln!(f, "update delete {domain} {ttl} IN {typ}")?;
// Add record with new IP
writeln!(f, "update add {domain} {ttl} IN {typ} {to}")
}
}
}
}
#[tracing::instrument(level = "trace", skip(actions), ret(level = "warn"))]
pub async fn nsupdate(
key_file: Option<&Path>, key_file: Option<&Path>,
records: &[&str], actions: impl IntoIterator<Item = Action<'_>>,
) -> std::io::Result<ExitStatus> { ) -> std::io::Result<ExitStatus> {
let mut cmd = tokio::process::Command::new("nsupdate"); let mut cmd = tokio::process::Command::new("nsupdate");
if let Some(key_file) = key_file { if let Some(key_file) = key_file {
@ -27,10 +66,13 @@ pub async fn nsupdate(
.inspect_err(|err| warn!("failed to spawn child: {err}"))?; .inspect_err(|err| warn!("failed to spawn child: {err}"))?;
let mut stdin = child.stdin.take().expect("stdin not present"); let mut stdin = child.stdin.take().expect("stdin not present");
debug!("sending update request"); debug!("sending update request");
let mut buf = Vec::new();
update_ns_records(&mut buf, actions).unwrap();
stdin stdin
.write_all(update_ns_records(ip, ttl, records).as_bytes()) .write_all(&buf)
.await .await
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?; .inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
debug!("closing stdin"); debug!("closing stdin");
stdin stdin
.shutdown() .shutdown()
@ -43,21 +85,16 @@ pub async fn nsupdate(
.inspect_err(|err| warn!("failed to wait for child: {err}")) .inspect_err(|err| warn!("failed to wait for child: {err}"))
} }
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { fn update_ns_records<'a>(
use std::fmt::Write; mut buf: impl std::io::Write,
let ttl_s: u64 = ttl.as_secs(); actions: impl IntoIterator<Item = Action<'a>>,
) -> std::io::Result<()> {
let rec_type = match ip { writeln!(buf, "server 127.0.0.1")?;
IpAddr::V4(_) => "A", for action in actions {
IpAddr::V6(_) => "AAAA", writeln!(buf, "{action}")?;
};
let mut cmds = String::from("server 127.0.0.1\n");
for &record in records {
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
} }
writeln!(cmds, "send\nquit").unwrap(); writeln!(buf, "send")?;
cmds writeln!(buf, "quit")
} }
#[cfg(test)] #[cfg(test)]
@ -66,17 +103,21 @@ mod test {
use insta::assert_snapshot; use insta::assert_snapshot;
use super::update_ns_records; use super::{update_ns_records, Action};
use crate::DEFAULT_TTL; use crate::DEFAULT_TTL;
#[test] #[test]
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn expected_update_string_A() { fn expected_update_string_A() {
assert_snapshot!(update_ns_records( let mut buf = Vec::new();
let actions = Action::from_records(
IpAddr::V4(Ipv4Addr::LOCALHOST), IpAddr::V4(Ipv4Addr::LOCALHOST),
DEFAULT_TTL, DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."], &["example.com.", "example.org.", "example.net."],
), @r###" );
update_ns_records(&mut buf, actions).unwrap();
assert_snapshot!(String::from_utf8(buf).unwrap(), @r###"
server 127.0.0.1 server 127.0.0.1
update delete example.com. 60 IN A update delete example.com. 60 IN A
update add example.com. 60 IN A 127.0.0.1 update add example.com. 60 IN A 127.0.0.1
@ -92,11 +133,15 @@ mod test {
#[test] #[test]
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn expected_update_string_AAAA() { fn expected_update_string_AAAA() {
assert_snapshot!(update_ns_records( let mut buf = Vec::new();
let actions = Action::from_records(
IpAddr::V6(Ipv6Addr::LOCALHOST), IpAddr::V6(Ipv6Addr::LOCALHOST),
DEFAULT_TTL, DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."], &["example.com.", "example.org.", "example.net."],
), @r###" );
update_ns_records(&mut buf, actions).unwrap();
assert_snapshot!(String::from_utf8(buf).unwrap(), @r###"
server 127.0.0.1 server 127.0.0.1
update delete example.com. 60 IN AAAA update delete example.com. 60 IN AAAA
update add example.com. 60 IN AAAA ::1 update add example.com. 60 IN AAAA ::1