refactor: reorganize main.rs #20
2 changed files with 189 additions and 154 deletions
212
src/main.rs
212
src/main.rs
|
@ -1,9 +1,7 @@
|
||||||
use std::{
|
use std::{
|
||||||
ffi::OsStr,
|
|
||||||
io::ErrorKind,
|
io::ErrorKind,
|
||||||
net::{IpAddr, SocketAddr},
|
net::{IpAddr, SocketAddr},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
process::{ExitStatus, Stdio},
|
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,11 +12,11 @@ use clap::{Parser, Subcommand};
|
||||||
use clap_verbosity_flag::Verbosity;
|
use clap_verbosity_flag::Verbosity;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
|
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
|
||||||
use tokio::io::AsyncWriteExt;
|
use tracing::{debug, error, info};
|
||||||
use tracing::{debug, error, info, warn};
|
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
mod auth;
|
mod auth;
|
||||||
|
mod nsupdate;
|
||||||
mod password;
|
mod password;
|
||||||
mod records;
|
mod records;
|
||||||
|
|
||||||
|
@ -118,6 +116,57 @@ struct AppState<'a> {
|
||||||
ip_file: &'a Path,
|
ip_file: &'a Path,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl AppState<'static> {
|
||||||
|
fn from_args(args: &Opts) -> miette::Result<Self> {
|
||||||
|
let Opts {
|
||||||
|
verbosity: _,
|
||||||
|
address: _,
|
||||||
|
port: _,
|
||||||
|
password_file: _,
|
||||||
|
data_dir,
|
||||||
|
key_file,
|
||||||
|
insecure,
|
||||||
|
subcommand: _,
|
||||||
|
records,
|
||||||
|
salt: _,
|
||||||
|
ttl,
|
||||||
|
ip_source: _,
|
||||||
|
} = args;
|
||||||
|
|
||||||
|
// Set state
|
||||||
|
let ttl = Duration::from_secs(*ttl);
|
||||||
|
|
||||||
|
// Use last registered IP address if available
|
||||||
|
let ip_file = data_dir.join("last-ip");
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
ttl,
|
||||||
|
// Load DNS records
|
||||||
|
records: records::load_no_verify(records)?,
|
||||||
|
// Load keyfile
|
||||||
|
key_file: key_file
|
||||||
|
.as_deref()
|
||||||
|
.map(|path| -> miette::Result<_> {
|
||||||
|
std::fs::File::open(path)
|
||||||
|
.into_diagnostic()
|
||||||
|
.wrap_err_with(|| {
|
||||||
|
format!("{} is not readable by the current user", path.display())
|
||||||
|
})?;
|
||||||
|
Ok(&*Box::leak(path.into()))
|
||||||
|
})
|
||||||
|
.transpose()?,
|
||||||
|
ip_file: Box::leak(ip_file.into_boxed_path()),
|
||||||
|
};
|
||||||
|
|
||||||
|
ensure!(
|
||||||
|
state.key_file.is_some() || *insecure,
|
||||||
|
"a key file must be used"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn load_ip(path: &Path) -> Result<Option<IpAddr>> {
|
fn load_ip(path: &Path) -> Result<Option<IpAddr>> {
|
||||||
debug!("loading last IP from {}", path.display());
|
debug!("loading last IP from {}", path.display());
|
||||||
let data = match std::fs::read_to_string(path) {
|
let data = match std::fs::read_to_string(path) {
|
||||||
|
@ -166,33 +215,31 @@ fn main() -> Result<()> {
|
||||||
return cmd.process(&args);
|
return cmd.process(&args);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize state
|
||||||
|
let state = AppState::from_args(&args)?;
|
||||||
|
|
||||||
let Opts {
|
let Opts {
|
||||||
verbosity: _,
|
verbosity: _,
|
||||||
address: ip,
|
address: ip,
|
||||||
port,
|
port,
|
||||||
password_file,
|
password_file,
|
||||||
data_dir,
|
data_dir: _,
|
||||||
key_file,
|
key_file: _,
|
||||||
insecure,
|
insecure,
|
||||||
subcommand: _,
|
subcommand: _,
|
||||||
records,
|
records: _,
|
||||||
salt,
|
salt,
|
||||||
ttl,
|
ttl: _,
|
||||||
ip_source,
|
ip_source,
|
||||||
} = args;
|
} = args;
|
||||||
|
|
||||||
info!("checking environment");
|
info!("checking environment");
|
||||||
|
|
||||||
// Set state
|
|
||||||
let ttl = Duration::from_secs(ttl);
|
|
||||||
|
|
||||||
// Use last registered IP address if available
|
|
||||||
let ip_file = data_dir.join("last-ip");
|
|
||||||
|
|
||||||
// Load password hash
|
// Load password hash
|
||||||
let password_hash = password_file
|
let password_hash = password_file
|
||||||
.map(|path| -> miette::Result<_> {
|
.map(|path| -> miette::Result<_> {
|
||||||
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
|
let path = path.as_path();
|
||||||
|
let pass = std::fs::read_to_string(path).into_diagnostic()?;
|
||||||
|
|
||||||
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
||||||
.decode(pass.trim().as_bytes())
|
.decode(pass.trim().as_bytes())
|
||||||
|
@ -204,35 +251,11 @@ fn main() -> Result<()> {
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
|
||||||
let state = AppState {
|
|
||||||
ttl,
|
|
||||||
// Load DNS records
|
|
||||||
records: records::load_no_verify(&records)?,
|
|
||||||
// Load keyfile
|
|
||||||
key_file: key_file
|
|
||||||
.map(|key_file| -> miette::Result<_> {
|
|
||||||
let path = key_file.as_path();
|
|
||||||
std::fs::File::open(path)
|
|
||||||
.into_diagnostic()
|
|
||||||
.wrap_err_with(|| {
|
|
||||||
format!("{} is not readable by the current user", path.display())
|
|
||||||
})?;
|
|
||||||
Ok(&*Box::leak(key_file.into_boxed_path()))
|
|
||||||
})
|
|
||||||
.transpose()?,
|
|
||||||
ip_file: Box::leak(ip_file.into_boxed_path()),
|
|
||||||
};
|
|
||||||
|
|
||||||
ensure!(
|
ensure!(
|
||||||
password_hash.is_some() || insecure,
|
password_hash.is_some() || insecure,
|
||||||
"a password must be used"
|
"a password must be used"
|
||||||
);
|
);
|
||||||
|
|
||||||
ensure!(
|
|
||||||
state.key_file.is_some() || insecure,
|
|
||||||
"a key file must be used"
|
|
||||||
);
|
|
||||||
|
|
||||||
let rt = tokio::runtime::Builder::new_current_thread()
|
let rt = tokio::runtime::Builder::new_current_thread()
|
||||||
.enable_all()
|
.enable_all()
|
||||||
.build()
|
.build()
|
||||||
|
@ -242,7 +265,8 @@ fn main() -> Result<()> {
|
||||||
rt.block_on(async {
|
rt.block_on(async {
|
||||||
// Load previous IP and update DNS record to point to it (if available)
|
// Load previous IP and update DNS record to point to it (if available)
|
||||||
match load_ip(state.ip_file) {
|
match load_ip(state.ip_file) {
|
||||||
Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await {
|
Ok(Some(ip)) => {
|
||||||
|
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
||||||
Ok(status) => {
|
Ok(status) => {
|
||||||
if !status.success() {
|
if !status.success() {
|
||||||
error!("nsupdate failed: code {status}");
|
error!("nsupdate failed: code {status}");
|
||||||
|
@ -255,7 +279,8 @@ fn main() -> Result<()> {
|
||||||
.into_diagnostic()
|
.into_diagnostic()
|
||||||
.wrap_err("failed to update records with previous IP");
|
.wrap_err("failed to update records with previous IP");
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
|
}
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
info!("No previous IP address set");
|
info!("No previous IP address set");
|
||||||
}
|
}
|
||||||
|
@ -295,9 +320,8 @@ async fn update_records(
|
||||||
State(state): State<AppState<'static>>,
|
State(state): State<AppState<'static>>,
|
||||||
SecureClientIp(ip): SecureClientIp,
|
SecureClientIp(ip): SecureClientIp,
|
||||||
) -> axum::response::Result<&'static str> {
|
) -> axum::response::Result<&'static str> {
|
||||||
debug!("received update request from {ip}");
|
info!("accepted update from {ip}");
|
||||||
info!("accepted update");
|
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
||||||
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
|
||||||
Ok(status) if status.success() => {
|
Ok(status) if status.success() => {
|
||||||
tokio::task::spawn_blocking(move || {
|
tokio::task::spawn_blocking(move || {
|
||||||
info!("updating last ip to {ip}");
|
info!("updating last ip to {ip}");
|
||||||
|
@ -323,103 +347,3 @@ async fn update_records(
|
||||||
.into()),
|
.into()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(level = "trace", ret(level = "warn"))]
|
|
||||||
async fn nsupdate(
|
|
||||||
ip: IpAddr,
|
|
||||||
ttl: Duration,
|
|
||||||
key_file: Option<&Path>,
|
|
||||||
records: &[&str],
|
|
||||||
) -> std::io::Result<ExitStatus> {
|
|
||||||
let mut cmd = tokio::process::Command::new("nsupdate");
|
|
||||||
if let Some(key_file) = key_file {
|
|
||||||
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
|
|
||||||
}
|
|
||||||
debug!("spawning new process");
|
|
||||||
let mut child = cmd
|
|
||||||
.stdin(Stdio::piped())
|
|
||||||
.spawn()
|
|
||||||
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
|
|
||||||
let mut stdin = child.stdin.take().expect("stdin not present");
|
|
||||||
debug!("sending update request");
|
|
||||||
stdin
|
|
||||||
.write_all(update_ns_records(ip, ttl, records).as_bytes())
|
|
||||||
.await
|
|
||||||
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
|
|
||||||
debug!("closing stdin");
|
|
||||||
stdin
|
|
||||||
.shutdown()
|
|
||||||
.await
|
|
||||||
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
|
|
||||||
debug!("waiting for nsupdate to exit");
|
|
||||||
child
|
|
||||||
.wait()
|
|
||||||
.await
|
|
||||||
.inspect_err(|err| warn!("failed to wait for child: {err}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
|
|
||||||
use std::fmt::Write;
|
|
||||||
let ttl_s: u64 = ttl.as_secs();
|
|
||||||
|
|
||||||
let rec_type = match ip {
|
|
||||||
IpAddr::V4(_) => "A",
|
|
||||||
IpAddr::V6(_) => "AAAA",
|
|
||||||
};
|
|
||||||
let mut cmds = String::from("server 127.0.0.1\n");
|
|
||||||
for &record in records {
|
|
||||||
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
|
|
||||||
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
|
|
||||||
}
|
|
||||||
writeln!(cmds, "send\nquit").unwrap();
|
|
||||||
cmds
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod test {
|
|
||||||
use insta::assert_snapshot;
|
|
||||||
|
|
||||||
use crate::{update_ns_records, DEFAULT_TTL};
|
|
||||||
|
|
||||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
fn expected_update_string_A() {
|
|
||||||
assert_snapshot!(update_ns_records(
|
|
||||||
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
|
||||||
DEFAULT_TTL,
|
|
||||||
&["example.com.", "example.org.", "example.net."],
|
|
||||||
), @r###"
|
|
||||||
server 127.0.0.1
|
|
||||||
update delete example.com. 60 IN A
|
|
||||||
update add example.com. 60 IN A 127.0.0.1
|
|
||||||
update delete example.org. 60 IN A
|
|
||||||
update add example.org. 60 IN A 127.0.0.1
|
|
||||||
update delete example.net. 60 IN A
|
|
||||||
update add example.net. 60 IN A 127.0.0.1
|
|
||||||
send
|
|
||||||
quit
|
|
||||||
"###);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
fn expected_update_string_AAAA() {
|
|
||||||
assert_snapshot!(update_ns_records(
|
|
||||||
IpAddr::V6(Ipv6Addr::LOCALHOST),
|
|
||||||
DEFAULT_TTL,
|
|
||||||
&["example.com.", "example.org.", "example.net."],
|
|
||||||
), @r###"
|
|
||||||
server 127.0.0.1
|
|
||||||
update delete example.com. 60 IN AAAA
|
|
||||||
update add example.com. 60 IN AAAA ::1
|
|
||||||
update delete example.org. 60 IN AAAA
|
|
||||||
update add example.org. 60 IN AAAA ::1
|
|
||||||
update delete example.net. 60 IN AAAA
|
|
||||||
update add example.net. 60 IN AAAA ::1
|
|
||||||
send
|
|
||||||
quit
|
|
||||||
"###);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
111
src/nsupdate.rs
Normal file
111
src/nsupdate.rs
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
use std::{
|
||||||
|
ffi::OsStr,
|
||||||
|
net::IpAddr,
|
||||||
|
path::Path,
|
||||||
|
process::{ExitStatus, Stdio},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
#[tracing::instrument(level = "trace", ret(level = "warn"))]
|
||||||
|
pub async fn nsupdate(
|
||||||
|
ip: IpAddr,
|
||||||
|
ttl: Duration,
|
||||||
|
key_file: Option<&Path>,
|
||||||
|
records: &[&str],
|
||||||
|
) -> std::io::Result<ExitStatus> {
|
||||||
|
let mut cmd = tokio::process::Command::new("nsupdate");
|
||||||
|
if let Some(key_file) = key_file {
|
||||||
|
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
|
||||||
|
}
|
||||||
|
debug!("spawning new process");
|
||||||
|
let mut child = cmd
|
||||||
|
.stdin(Stdio::piped())
|
||||||
|
.spawn()
|
||||||
|
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
|
||||||
|
let mut stdin = child.stdin.take().expect("stdin not present");
|
||||||
|
debug!("sending update request");
|
||||||
|
stdin
|
||||||
|
.write_all(update_ns_records(ip, ttl, records).as_bytes())
|
||||||
|
.await
|
||||||
|
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
|
||||||
|
debug!("closing stdin");
|
||||||
|
stdin
|
||||||
|
.shutdown()
|
||||||
|
.await
|
||||||
|
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
|
||||||
|
debug!("waiting for nsupdate to exit");
|
||||||
|
child
|
||||||
|
.wait()
|
||||||
|
.await
|
||||||
|
.inspect_err(|err| warn!("failed to wait for child: {err}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
|
||||||
|
use std::fmt::Write;
|
||||||
|
let ttl_s: u64 = ttl.as_secs();
|
||||||
|
|
||||||
|
let rec_type = match ip {
|
||||||
|
IpAddr::V4(_) => "A",
|
||||||
|
IpAddr::V6(_) => "AAAA",
|
||||||
|
};
|
||||||
|
let mut cmds = String::from("server 127.0.0.1\n");
|
||||||
|
for &record in records {
|
||||||
|
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
|
||||||
|
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
|
||||||
|
}
|
||||||
|
writeln!(cmds, "send\nquit").unwrap();
|
||||||
|
cmds
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||||
|
|
||||||
|
use insta::assert_snapshot;
|
||||||
|
|
||||||
|
use super::update_ns_records;
|
||||||
|
use crate::DEFAULT_TTL;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn expected_update_string_A() {
|
||||||
|
assert_snapshot!(update_ns_records(
|
||||||
|
IpAddr::V4(Ipv4Addr::LOCALHOST),
|
||||||
|
DEFAULT_TTL,
|
||||||
|
&["example.com.", "example.org.", "example.net."],
|
||||||
|
), @r###"
|
||||||
|
server 127.0.0.1
|
||||||
|
update delete example.com. 60 IN A
|
||||||
|
update add example.com. 60 IN A 127.0.0.1
|
||||||
|
update delete example.org. 60 IN A
|
||||||
|
update add example.org. 60 IN A 127.0.0.1
|
||||||
|
update delete example.net. 60 IN A
|
||||||
|
update add example.net. 60 IN A 127.0.0.1
|
||||||
|
send
|
||||||
|
quit
|
||||||
|
"###);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
fn expected_update_string_AAAA() {
|
||||||
|
assert_snapshot!(update_ns_records(
|
||||||
|
IpAddr::V6(Ipv6Addr::LOCALHOST),
|
||||||
|
DEFAULT_TTL,
|
||||||
|
&["example.com.", "example.org.", "example.net."],
|
||||||
|
), @r###"
|
||||||
|
server 127.0.0.1
|
||||||
|
update delete example.com. 60 IN AAAA
|
||||||
|
update add example.com. 60 IN AAAA ::1
|
||||||
|
update delete example.org. 60 IN AAAA
|
||||||
|
update add example.org. 60 IN AAAA ::1
|
||||||
|
update delete example.net. 60 IN AAAA
|
||||||
|
update add example.net. 60 IN AAAA ::1
|
||||||
|
send
|
||||||
|
quit
|
||||||
|
"###);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue