webnsupdate/src/main.rs

542 lines
15 KiB
Rust
Raw Normal View History

2024-05-03 20:29:10 +02:00
use std::{
io::ErrorKind,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
2024-05-03 20:29:10 +02:00
path::{Path, PathBuf},
time::Duration,
};
use axum::{
extract::{Query, State},
routing::get,
Router,
};
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
2024-05-03 20:29:10 +02:00
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Parser, Subcommand};
2024-10-26 13:23:22 +02:00
use clap_verbosity_flag::Verbosity;
2024-05-03 20:29:10 +02:00
use http::StatusCode;
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
2024-11-23 20:56:03 +01:00
use tracing::{debug, error, info};
2024-05-03 20:29:10 +02:00
use tracing_subscriber::EnvFilter;
mod auth;
2024-11-23 20:56:03 +01:00
mod nsupdate;
mod password;
mod records;
2024-05-03 20:29:10 +02:00
/// Default Time To Live applied to updated DNS records (one minute).
const DEFAULT_TTL: Duration = Duration::new(60, 0);
/// Default salt for hashed passwords, used when `--salt` is not given.
const DEFAULT_SALT: &str = "UpdateMyDNS";
#[derive(Debug, Parser)]
struct Opts {
2024-10-26 13:23:22 +02:00
#[command(flatten)]
verbosity: Verbosity<clap_verbosity_flag::InfoLevel>,
2024-10-26 13:23:22 +02:00
2024-05-03 20:29:10 +02:00
/// Ip address of the server
#[arg(long, default_value = "127.0.0.1")]
address: IpAddr,
2024-05-03 20:29:10 +02:00
/// Port of the server
#[arg(long, default_value_t = 5353)]
port: u16,
2024-05-03 20:29:10 +02:00
/// File containing password to match against
///
/// Should be of the format `username:password` and contain a single password
#[arg(long)]
password_file: Option<PathBuf>,
2024-05-03 20:29:10 +02:00
/// Salt to get more unique hashed passwords and prevent table based attacks
#[arg(long, default_value = DEFAULT_SALT)]
salt: String,
2024-05-03 20:29:10 +02:00
/// Time To Live (in seconds) to set on the DNS records
#[arg(long, default_value_t = DEFAULT_TTL.as_secs())]
ttl: u64,
/// Data directory
#[arg(long, default_value = ".")]
data_dir: PathBuf,
2024-05-03 20:29:10 +02:00
/// File containing the records that should be updated when an update request is made
///
/// There should be one record per line:
///
/// ```text
/// example.com.
/// mail.example.com.
/// ```
#[arg(long)]
records: PathBuf,
2024-05-03 20:29:10 +02:00
/// Keyfile `nsupdate` should use
///
/// If specified, then `webnsupdate` must have read access to the file
#[arg(long)]
key_file: Option<PathBuf>,
/// Allow not setting a password
2024-05-03 20:29:10 +02:00
#[arg(long)]
insecure: bool,
/// Set client IP source
///
/// see: <https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html>
#[clap(long, default_value = "RightmostXForwardedFor")]
ip_source: SecureClientIpSource,
/// Set which IPs to allow updating
#[clap(long, default_value_t = IpType::Both)]
ip_type: IpType,
2024-05-03 20:29:10 +02:00
#[clap(subcommand)]
subcommand: Option<Cmd>,
}
/// Which IP address families the server accepts updates for.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum IpType {
    /// Accept both IPv4 and IPv6 addresses.
    #[default]
    Both,
    /// Only accept IPv4 addresses.
    IPv4Only,
    /// Only accept IPv6 addresses.
    IPv6Only,
}

impl IpType {
    /// Returns `true` if `ip` belongs to an address family allowed by `self`.
    fn valid_for_type(self, ip: IpAddr) -> bool {
        match self {
            IpType::Both => true,
            IpType::IPv4Only => ip.is_ipv4(),
            IpType::IPv6Only => ip.is_ipv6(),
        }
    }

    /// The command-line spelling of this variant (also used by `Display`).
    const fn as_str(self) -> &'static str {
        match self {
            IpType::Both => "both",
            IpType::IPv4Only => "ipv4-only",
            IpType::IPv6Only => "ipv6-only",
        }
    }
}

impl std::fmt::Display for IpType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
impl std::str::FromStr for IpType {
    type Err = miette::Error;

    /// Parses the spellings `both`, `ipv4-only` and `ipv6-only`; anything else is an error.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let parsed = match s {
            "ipv4-only" => Self::IPv4Only,
            "ipv6-only" => Self::IPv6Only,
            "both" => Self::Both,
            _ => bail!("expected one of 'ipv4-only', 'ipv6-only' or 'both', got '{s}'"),
        };
        Ok(parsed)
    }
}
2024-05-03 20:29:10 +02:00
#[derive(Debug, Subcommand)]
enum Cmd {
Mkpasswd(password::Mkpasswd),
2024-05-03 20:29:10 +02:00
/// Verify the records file
Verify,
}
impl Cmd {
pub fn process(self, args: &Opts) -> Result<()> {
match self {
Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args),
Cmd::Verify => records::load(&args.records).map(drop),
}
}
}
2024-05-03 20:29:10 +02:00
#[derive(Clone)]
struct AppState<'a> {
/// TTL set on the Zonefile
ttl: Duration,
2024-05-03 20:29:10 +02:00
/// The IN A/AAAA records that should have their IPs updated
records: &'a [&'a str],
2024-05-03 20:29:10 +02:00
/// The TSIG key file
key_file: Option<&'a Path>,
/// The file where the last IP is stored
ip_file: &'a Path,
/// Last recorded IPs
last_ips: std::sync::Arc<tokio::sync::Mutex<SavedIPs>>,
/// The IP type for which to allow updates
ip_type: IpType,
}
/// The most recently applied IPv4 and IPv6 addresses, persisted as JSON in the last-IP file.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
struct SavedIPs {
    /// Last applied IPv4 address; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    ipv4: Option<Ipv4Addr>,
    /// Last applied IPv6 address; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    ipv6: Option<Ipv6Addr>,
}

impl SavedIPs {
    /// Records `ip` as the latest address of its family, leaving the other family untouched.
    fn update(&mut self, ip: IpAddr) {
        match ip {
            IpAddr::V4(ipv4_addr) => self.ipv4 = Some(ipv4_addr),
            IpAddr::V6(ipv6_addr) => self.ipv6 = Some(ipv6_addr),
        }
    }

    /// Iterates over the saved addresses, IPv4 first.
    fn ips(&self) -> impl Iterator<Item = IpAddr> {
        self.ipv4
            .map(IpAddr::V4)
            .into_iter()
            .chain(self.ipv6.map(IpAddr::V6))
    }

    /// Parses the contents of the last-IP file.
    ///
    /// Accepts three forms, and ignores surrounding whitespace in all of them:
    /// - the JSON serialization of `SavedIPs`;
    /// - the old format, a single bare IP address;
    /// - an empty file, which yields no saved IPs.
    ///
    /// # Errors
    ///
    /// Fails if the data is neither an IP address nor valid `SavedIPs` JSON.
    fn from_str(data: &str) -> miette::Result<Self> {
        let data = data.trim();
        if data.is_empty() {
            // e.g. a file truncated by an interrupted write: nothing usable was saved
            return Ok(Self::default());
        }
        match data.parse::<IpAddr>() {
            // Old format: a single bare IP address
            Ok(ip) => {
                let mut saved = Self::default();
                saved.update(ip);
                Ok(saved)
            }
            Err(_) => serde_json::from_str(data).into_diagnostic(),
        }
    }
}
2024-11-23 20:56:03 +01:00
impl AppState<'static> {
fn from_args(args: &Opts) -> miette::Result<Self> {
let Opts {
verbosity: _,
address: _,
port: _,
password_file: _,
data_dir,
key_file,
insecure,
subcommand: _,
records,
salt: _,
ttl,
ip_source: _,
ip_type,
2024-11-23 20:56:03 +01:00
} = args;
// Set state
let ttl = Duration::from_secs(*ttl);
// Use last registered IP address if available
let ip_file = Box::leak(data_dir.join("last-ip.json").into_boxed_path());
2024-11-23 20:56:03 +01:00
let state = AppState {
ttl,
// Load DNS records
records: records::load_no_verify(records)?,
// Load keyfile
key_file: key_file
.as_deref()
.map(|path| -> miette::Result<_> {
std::fs::File::open(path)
.into_diagnostic()
.wrap_err_with(|| {
format!("{} is not readable by the current user", path.display())
})?;
Ok(&*Box::leak(path.into()))
})
.transpose()?,
ip_file,
ip_type: *ip_type,
last_ips: std::sync::Arc::new(tokio::sync::Mutex::new(
load_ip(ip_file)?.unwrap_or_default(),
)),
2024-11-23 20:56:03 +01:00
};
ensure!(
state.key_file.is_some() || *insecure,
"a key file must be used"
);
Ok(state)
}
}
fn load_ip(path: &Path) -> Result<Option<SavedIPs>> {
2024-10-26 13:23:22 +02:00
debug!("loading last IP from {}", path.display());
let data = match std::fs::read_to_string(path) {
Ok(ip) => ip,
Err(err) => {
return match err.kind() {
ErrorKind::NotFound => Ok(None),
_ => Err(err).into_diagnostic().wrap_err_with(|| {
format!("failed to load last ip address from {}", path.display())
}),
}
}
};
SavedIPs::from_str(&data)
.wrap_err_with(|| format!("failed to load last ip address from {}", path.display()))
.map(Some)
2024-05-03 20:29:10 +02:00
}
#[tracing::instrument(err)]
fn main() -> Result<()> {
// set panic hook to pretty print with miette's formatter
2024-05-03 20:29:10 +02:00
miette::set_panic_hook();
// parse cli arguments
let mut args = Opts::parse();
// configure logger
2024-05-03 20:29:10 +02:00
let subscriber = tracing_subscriber::FmtSubscriber::builder()
.without_time()
.with_env_filter(
EnvFilter::builder()
.with_default_directive(args.verbosity.tracing_level_filter().into())
2024-05-03 20:29:10 +02:00
.from_env_lossy(),
)
.finish();
2024-05-03 20:29:10 +02:00
tracing::subscriber::set_global_default(subscriber)
.into_diagnostic()
.wrap_err("failed to set global tracing subscriber")?;
debug!("{args:?}");
// process subcommand
if let Some(cmd) = args.subcommand.take() {
return cmd.process(&args);
2024-05-03 20:29:10 +02:00
}
2024-11-23 20:56:03 +01:00
// Initialize state
let state = AppState::from_args(&args)?;
let Opts {
2024-10-26 13:23:22 +02:00
verbosity: _,
address: ip,
port,
password_file,
2024-11-23 20:56:03 +01:00
data_dir: _,
key_file: _,
insecure,
subcommand: _,
2024-11-23 20:56:03 +01:00
records: _,
salt,
2024-11-23 20:56:03 +01:00
ttl: _,
ip_source,
ip_type,
} = args;
2024-05-03 20:29:10 +02:00
info!("checking environment");
// Load password hash
let password_hash = password_file
.map(|path| -> miette::Result<_> {
2024-11-23 20:56:03 +01:00
let path = path.as_path();
let pass = std::fs::read_to_string(path).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(pass)
})
.transpose()
.wrap_err("failed to load password hash")?;
ensure!(
password_hash.is_some() || insecure,
"a password must be used"
);
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
2024-05-03 20:29:10 +02:00
.into_diagnostic()
.wrap_err("failed to start the tokio runtime")?;
rt.block_on(async {
// Update DNS record with previous IPs (if available)
let ips = state.last_ips.lock().await.clone();
for ip in ips.ips() {
if !ip_type.valid_for_type(ip) {
continue;
}
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) => {
if !status.success() {
error!("nsupdate failed: code {status}");
bail!("nsupdate returned with code {status}");
}
}
Err(err) => {
error!("Failed to update records with previous IP: {err}");
return Err(err)
.into_diagnostic()
.wrap_err("failed to update records with previous IP");
}
2024-11-23 20:56:03 +01:00
}
}
// Create services
let app = Router::new()
.route("/update", get(update_records))
.route("/fritzbox-dyn-dns", get(fritzbox_dyn_dns));
// if a password is provided, validate it
let app = if let Some(pass) = password_hash {
app.layer(auth::layer(Box::leak(pass), String::leak(salt)))
} else {
app
}
.layer(ip_source.into_extension())
.with_state(state);
// Start services
info!("starting listener on {ip}:{port}");
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
.await
.into_diagnostic()?;
info!("listening on {ip}:{port}");
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
2024-05-03 20:29:10 +02:00
.await
.into_diagnostic()
})
.wrap_err("failed to run main loop")
2024-05-03 20:29:10 +02:00
}
/// Query parameters of a FRITZ!Box dynamic-DNS update request.
///
/// Only `ipv4` and `ipv6` drive the update. The remaining fields are only used when
/// the received parameters are logged through the `Debug` impl.
#[derive(Debug, serde::Deserialize)]
struct FritzBoxUpdateParams {
    /// The domain that should be updated
    #[allow(unused)] // only read through the `Debug` impl when logging
    domain: Option<String>,
    /// IPv4 address for the domain
    ipv4: Option<Ipv4Addr>,
    /// IPv6 address for the domain
    ipv6: Option<Ipv6Addr>,
    /// IPv6 prefix for the home network
    #[allow(unused)] // only read through the `Debug` impl when logging
    ipv6prefix: Option<String>,
    /// Whether the network uses both IPv4 and IPv6
    #[allow(unused)] // only read through the `Debug` impl when logging
    dualstack: Option<String>,
}
/// Handles `/fritzbox-dyn-dns`: updates the records with the IPs passed as query parameters.
///
/// Each provided address whose family is not allowed by `state.ip_type` is skipped with a
/// warning instead of being applied.
///
/// # Errors
///
/// Returns an error response in these cases:
/// - `400 BAD_REQUEST` when neither `ipv4` nor `ipv6` is given;
/// - `409 CONFLICT` when every provided address is filtered out;
/// - the error from `trigger_update` when an update fails.
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
async fn fritzbox_dyn_dns(
    State(state): State<AppState<'static>>,
    update_params: Query<FritzBoxUpdateParams>,
) -> axum::response::Result<&'static str> {
    info!("received params: {update_params:#?}");
    let FritzBoxUpdateParams {
        domain: _,
        ipv4,
        ipv6,
        ipv6prefix: _,
        dualstack: _,
    } = update_params.0;

    if ipv4.is_none() && ipv6.is_none() {
        return Err((
            StatusCode::BAD_REQUEST,
            "failed to provide an IP for the update",
        )
            .into());
    }

    let mut updated_any = false;
    for (family, ip) in [("IPv4", ipv4.map(IpAddr::V4)), ("IPv6", ipv6.map(IpAddr::V6))] {
        let Some(ip) = ip else { continue };
        if !state.ip_type.valid_for_type(ip) {
            // Skip (don't apply) addresses the configured filter disallows.
            tracing::warn!("requested update of {family} but we are {}", state.ip_type);
            continue;
        }
        trigger_update(ip, &state).await?;
        updated_any = true;
    }

    if !updated_any {
        return Err((
            StatusCode::CONFLICT,
            format!("running in {} mode", state.ip_type),
        )
            .into());
    }
    Ok("Successfully updated IP of records!\n")
}
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
2024-05-03 20:29:10 +02:00
async fn update_records(
State(state): State<AppState<'static>>,
SecureClientIp(ip): SecureClientIp,
2024-05-03 20:29:10 +02:00
) -> axum::response::Result<&'static str> {
2024-11-23 20:56:03 +01:00
info!("accepted update from {ip}");
if !state.ip_type.valid_for_type(ip) {
tracing::warn!(
"rejecting update from {ip} as we are running a {} filter",
state.ip_type
);
return Err((
StatusCode::CONFLICT,
format!("running in {} mode", state.ip_type),
)
.into());
}
trigger_update(ip, &state).await
}
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
async fn trigger_update(
ip: IpAddr,
state: &AppState<'static>,
) -> axum::response::Result<&'static str> {
2024-11-23 20:56:03 +01:00
match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) if status.success() => {
let ips = {
// Update state
let mut ips = state.last_ips.lock().await;
ips.update(ip);
ips.clone()
};
let ip_file = state.ip_file;
tokio::task::spawn_blocking(move || {
info!("updating last ips to {ips:?}");
let data = serde_json::to_vec(&ips).expect("invalid serialization impl");
if let Err(err) = std::fs::write(ip_file, data) {
error!("Failed to update last IP: {err}");
}
info!("updated last ips to {ips:?}");
});
Ok("Successfully updated IP of records!\n")
}
2024-05-03 20:29:10 +02:00
Ok(status) => {
error!("nsupdate failed with code {status}");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"nsupdate failed, check server logs",
)
.into())
2024-05-03 20:29:10 +02:00
}
Err(error) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to update records: {error}"),
)
.into()),
}
}