diff --git a/Cargo.lock b/Cargo.lock index 97a263a..ecbb0a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -304,7 +304,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" dependencies = [ "nonempty", - "thiserror", + "thiserror 1.0.69", ] [[package]] @@ -454,6 +454,7 @@ dependencies = [ "linked-hash-map", "once_cell", "pin-project", + "serde", "similar", ] @@ -542,7 +543,7 @@ dependencies = [ "supports-unicode", "terminal_size", "textwrap", - "thiserror", + "thiserror 1.0.69", "unicode-width", ] @@ -948,7 +949,16 @@ version = "1.0.69" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" dependencies = [ - "thiserror-impl", + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +dependencies = [ + "thiserror-impl 2.0.11", ] [[package]] @@ -962,6 +972,17 @@ dependencies = [ "syn", ] +[[package]] +name = "thiserror-impl" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "thread_local" version = "1.1.8" @@ -1162,6 +1183,7 @@ dependencies = [ "ring", "serde", "serde_json", + "thiserror 2.0.11", "tokio", "tower-http", "tracing", diff --git a/Cargo.toml b/Cargo.toml index dd4bde5..1022360 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,3 @@ -cargo-features = ["codegen-backend"] - [package] description = "An HTTP server using HTTP basic auth to make secure calls to nsupdate" name = "webnsupdate" @@ -29,13 +27,14 @@ miette = { version = "7", features = ["fancy"] } ring = { version = "0.17", features = ["std"] } serde = { version = "1", features = ["derive"] } serde_json = "1" +thiserror = "2" tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } tower-http = { version = "0.6", features = ["validate-request"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] -insta = "1" +insta = { version = "1", features = ["json"] } [profile.release] opt-level = "s" @@ -46,4 +45,3 @@ codegen-units = 1 [profile.dev] debug = 0 -codegen-backend = "cranelift" diff --git a/flake-modules/module.nix b/flake-modules/module.nix index e70ae35..ddfa934 100644 --- a/flake-modules/module.nix +++ b/flake-modules/module.nix @@ -14,8 +14,38 @@ let mkPackageOption types ; + format = pkgs.formats.json { }; in { + # imports = [ + # (lib.mkRenamedOptionModule + # [ "services" "webnsupdate" "passwordFile" ] + # [ "services" "webnsupdate" "settings" "password_file" ] + # ) + # (lib.mkRenamedOptionModule + # [ "services" "webnsupdate" "keyFile" ] + # [ "services" "webnsupdate" "settings" "key_file" ] + # ) + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "allowedIPVersion" ] '' + # This option was replaced with 'services.webnsupdate.settings.ip_type' which defaults to Both. + # '') + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "bindIp" ] '' + # This option was replaced with 'services.webnsupdate.settings.address' which defaults to 127.0.0.1:5353. + # '') + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "bindPort" ] '' + # This option was replaced with 'services.webnsupdate.settings.address' which defaults to 127.0.0.1:5353. + # '') + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "records" ] '' + # This option was replaced with 'services.webnsupdate.settings.records' which defaults to []. + # '') + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "recordsFile" ] '' + # This option was replaced with 'services.webnsupdate.settings.records' which defaults to []. + # '') + # (lib.mkRemovedOptionModule [ "services" "webnsupdate" "ttl" ] '' + # This option was replaced with 'services.webnsupdate.settings.ttl' which defaults to 600s. + # '') + # ]; + options.services.webnsupdate = mkOption { description = "An HTTP server for nsupdate."; default = { }; @@ -31,82 +61,92 @@ let example = [ "--ip-source" ]; }; package = mkPackageOption pkgs "webnsupdate" { }; - bindIp = mkOption { - description = '' - IP address to bind to. + settings = mkOption { + description = "The webnsupdate JSON configuration"; + default = { }; + type = types.submodule { + freeformType = format.type; + options = { + address = mkOption { + description = '' + IP address and port to bind to. - Setting it to anything other than localhost is very insecure as - `webnsupdate` only supports plain HTTP and should always be behind a - reverse proxy. - ''; - type = types.str; - default = "localhost"; - example = "0.0.0.0"; - }; - bindPort = mkOption { - description = "Port to bind to."; - type = types.port; - default = 5353; - }; - allowedIPVersion = mkOption { - description = ''The allowed IP versions to accept updates from.''; - type = types.enum [ - "both" - "ipv4-only" - "ipv6-only" - ]; - default = "both"; - example = "ipv4-only"; - }; - passwordFile = mkOption { - description = '' - The file where the password is stored. + Setting it to anything other than localhost is very + insecure as `webnsupdate` only supports plain HTTP and + should always be behind a reverse proxy. + ''; + type = types.str; + default = "127.0.0.1:5353"; + example = "[::1]:5353"; + }; + ip_type = mkOption { + description = ''The allowed IP versions to accept updates from.''; + type = types.enum [ + "Both" + "Ipv4Only" + "Ipv6Only" + ]; + default = "Both"; + example = "Ipv4Only"; + }; + password_file = mkOption { + description = '' + The file where the password is stored. - This file can be created by running `webnsupdate mkpasswd $USERNAME $PASSWORD`. - ''; - type = types.path; - example = "/secrets/webnsupdate.pass"; - }; - keyFile = mkOption { - description = '' - The TSIG key that `nsupdate` should use. + This file can be created by running `webnsupdate mkpasswd $USERNAME $PASSWORD`. + ''; + type = types.path; + example = "/secrets/webnsupdate.pass"; + }; + key_file = mkOption { + description = '' + The TSIG key that `nsupdate` should use. - This file will be passed to `nsupdate` through the `-k` option, so look - at `man 8 nsupdate` for information on the key's format. - ''; - type = types.path; - example = "/secrets/webnsupdate.key"; - }; - ttl = mkOption { - description = "The TTL that should be set on the zone records created by `nsupdate`."; - type = types.ints.positive; - default = 60; - example = 3600; - }; - records = mkOption { - description = '' - The fqdn of records that should be updated. + This file will be passed to `nsupdate` through the `-k` option, so look + at `man 8 nsupdate` for information on the key's format. + ''; + type = types.path; + example = "/secrets/webnsupdate.key"; + }; + ttl = mkOption { + description = "The TTL that should be set on the zone records created by `nsupdate`."; + default = { + secs = 600; + }; + example = { + secs = 600; + nanos = 50000; + }; + type = types.submodule { + options = { + secs = mkOption { + description = "The TTL (in seconds) that should be set on the zone records created by `nsupdate`."; + example = 3600; + }; + nanos = mkOption { + description = "The TTL (in nanoseconds) that should be set on the zone records created by `nsupdate`."; + default = 0; + example = 50000; + }; + }; + }; + }; + records = mkOption { + description = '' + The fqdn of records that should be updated. - Empty lines will be ignored, but whitespace will not be. - ''; - type = types.nullOr types.lines; - default = null; - example = '' - example.com. - - example.org. - ci.example.org. - ''; - }; - recordsFile = mkOption { - description = '' - The fqdn of records that should be updated. - - Empty lines will be ignored, but whitespace will not be. - ''; - type = types.nullOr types.path; - default = null; - example = "/secrets/webnsupdate.records"; + Empty lines will be ignored, but whitespace will not be. + ''; + type = types.listOf types.str; + default = [ ]; + example = [ + "example.com." + "example.org." + "ci.example.org." + ]; + }; + }; + }; }; user = mkOption { description = "The user to run as."; @@ -124,41 +164,14 @@ let config = let - recordsFile = - if cfg.recordsFile != null then cfg.recordsFile else pkgs.writeText "webnsrecords" cfg.records; - args = lib.strings.escapeShellArgs ( - [ - "--records" - recordsFile - "--key-file" - cfg.keyFile - "--password-file" - cfg.passwordFile - "--address" - cfg.bindIp - "--ip-type" - cfg.allowedIPVersion - "--port" - (builtins.toString cfg.bindPort) - "--ttl" - (builtins.toString cfg.ttl) - "--data-dir=%S/webnsupdate" - ] - ++ cfg.extraArgs - ); + configFile = format.generate "webnsupdate.json" cfg.settings; + args = lib.strings.escapeShellArgs ([ "--config=${configFile}" ] ++ cfg.extraArgs); cmd = "${lib.getExe cfg.package} ${args}"; in lib.mkIf cfg.enable { + # FIXME: re-enable once I stop using the patched version of bind # warnings = # lib.optional (!config.services.bind.enable) "`webnsupdate` is expected to be used alongside `bind`. This is an unsupported configuration."; - assertions = [ - { - assertion = - (cfg.records != null || cfg.recordsFile != null) - && !(cfg.records != null && cfg.recordsFile != null); - message = "Exactly one of `services.webnsupdate.records` and `services.webnsupdate.recordsFile` must be set."; - } - ]; systemd.services.webnsupdate = { description = "Web interface for nsupdate."; @@ -167,9 +180,10 @@ let "network.target" "bind.service" ]; - preStart = "${cmd} verify"; + preStart = "${lib.getExe cfg.package} verify ${configFile}"; path = [ pkgs.dig ]; startLimitIntervalSec = 60; + environment.DATA_DIR = "%S/webnsupdate"; serviceConfig = { ExecStart = [ cmd ]; Type = "exec"; diff --git a/flake-modules/tests.nix b/flake-modules/tests.nix index 45fc5ff..8257326 100644 --- a/flake-modules/tests.nix +++ b/flake-modules/tests.nix @@ -9,7 +9,7 @@ lastIPPath = "/var/lib/webnsupdate/last-ip.json"; zoneFile = pkgs.writeText "${testDomain}.zoneinfo" '' - $TTL 60 ; 1 minute + $TTL 600 ; 10 minutes $ORIGIN ${testDomain}. @ IN SOA ns1.${testDomain}. admin.${testDomain}. ( 1 ; serial @@ -73,20 +73,19 @@ webnsupdate = { enable = true; - bindIp = lib.mkDefault "127.0.0.1"; - keyFile = "/etc/bind/rndc.key"; - # test:test (user:password) - passwordFile = pkgs.writeText "webnsupdate.pass" "FQoNmuU1BKfg8qsU96F6bK5ykp2b0SLe3ZpB3nbtfZA"; package = self'.packages.webnsupdate; - extraArgs = [ - "-vvv" # debug messages - "--ip-source=ConnectInfo" - ]; - records = '' - test1.${testDomain}. - test2.${testDomain}. - test3.${testDomain}. - ''; + extraArgs = [ "-vvv" ]; # debug messages + settings = { + address = lib.mkDefault "127.0.0.1:5353"; + key_file = "/etc/bind/rndc.key"; + password_file = pkgs.writeText "webnsupdate.pass" "FQoNmuU1BKfg8qsU96F6bK5ykp2b0SLe3ZpB3nbtfZA"; # test:test + ip_source = lib.mkDefault "ConnectInfo"; + records = [ + "test1.${testDomain}." + "test2.${testDomain}." + "test3.${testDomain}." + ]; + }; }; }; }; @@ -97,7 +96,7 @@ webnsupdate-ipv4-machine ]; - config.services.webnsupdate.bindIp = "::1"; + config.services.webnsupdate.settings.address = "[::1]:5353"; }; webnsupdate-nginx-machine = @@ -109,26 +108,26 @@ config.services = { # Use default IP Source - webnsupdate.extraArgs = lib.mkForce [ "-vvv" ]; # debug messages + webnsupdate.settings.ip_source = "RightmostXForwardedFor"; nginx = { enable = true; recommendedProxySettings = true; virtualHosts.webnsupdate.locations."/".proxyPass = - "http://${config.services.webnsupdate.bindIp}:${builtins.toString config.services.webnsupdate.bindPort}"; + "http://${config.services.webnsupdate.settings.address}"; }; }; }; webnsupdate-ipv4-only-machine = { imports = [ webnsupdate-nginx-machine ]; - config.services.webnsupdate.allowedIPVersion = "ipv4-only"; + config.services.webnsupdate.settings.ip_type = "Ipv4Only"; }; webnsupdate-ipv6-only-machine = { imports = [ webnsupdate-nginx-machine ]; - config.services.webnsupdate.allowedIPVersion = "ipv6-only"; + config.services.webnsupdate.settings.ip_type = "Ipv6Only"; }; # "A" for IPv4, "AAAA" for IPv6, "ANY" for any @@ -158,9 +157,9 @@ STATIC_DOMAINS: list[str] = ["${testDomain}", "ns1.${testDomain}", "nsupdate.${testDomain}"] DYNAMIC_DOMAINS: list[str] = ["test1.${testDomain}", "test2.${testDomain}", "test3.${testDomain}"] - def dig_cmd(domain: str, record: str, ip: str | None) -> str: - match_ip = "" if ip is None else f"\\s\\+60\\s\\+IN\\s\\+{record}\\s\\+{ip}$" - return f"dig @localhost {record} {domain} +noall +answer | grep '^{domain}.{match_ip}'" + def dig_cmd(domain: str, record: str, ip: str | None) -> tuple[str, str]: + match_ip = "" if ip is None else f"\\s\\+600\\s\\+IN\\s\\+{record}\\s\\+{ip}$" + return f"dig @localhost {record} {domain} +noall +answer", f"grep '^{domain}.{match_ip}'" def curl_cmd(domain: str, identity: str, path: str, query: dict[str, str]) -> str: from urllib.parse import urlencode @@ -168,10 +167,16 @@ return f"{CURL} -u {identity} -X GET 'http://{domain}{"" if NGINX else ":5353"}/{path}{q}'" def domain_available(domain: str, record: str, ip: str | None=None): - machine.succeed(dig_cmd(domain, record, ip)) + dig, grep = dig_cmd(domain, record, ip) + rc, output = machine.execute(dig) + print(f"{dig}[{rc}]: {output}") + machine.succeed(f"{dig} | {grep}") def domain_missing(domain: str, record: str, ip: str | None=None): - machine.fail(dig_cmd(domain, record, ip)) + dig, grep = dig_cmd(domain, record, ip) + rc, output = machine.execute(dig) + print(f"{dig}[{rc}]: {output}") + machine.fail(f"{dig} | {grep}") def update_records(domain: str="localhost", /, *, path: str="update", **kwargs): machine.succeed(curl_cmd(domain, "test:test", path, kwargs)) diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..e798661 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,217 @@ +use std::{ + fs::File, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + path::PathBuf, + time::Duration, +}; + +use axum_client_ip::SecureClientIpSource; +use miette::{Context, IntoDiagnostic}; + +#[derive(Debug, Default, Clone, Copy, serde::Deserialize, serde::Serialize)] +pub enum IpType { + #[default] + Both, + Ipv4Only, + Ipv6Only, +} + +impl IpType { + pub fn valid_for_type(self, ip: IpAddr) -> bool { + match self { + IpType::Both => true, + IpType::Ipv4Only => ip.is_ipv4(), + IpType::Ipv6Only => ip.is_ipv6(), + } + } +} + +impl std::fmt::Display for IpType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IpType::Both => f.write_str("both"), + IpType::Ipv4Only => f.write_str("ipv4-only"), + IpType::Ipv6Only => f.write_str("ipv6-only"), + } + } +} + +impl std::str::FromStr for IpType { + type Err = miette::Error; + + fn from_str(s: &str) -> std::result::Result { + match s { + "both" => Ok(Self::Both), + "ipv4-only" => Ok(Self::Ipv4Only), + "ipv6-only" => Ok(Self::Ipv6Only), + _ => miette::bail!("expected one of 'ipv4-only', 'ipv6-only' or 'both', got '{s}'"), + } + } +} + +/// Webserver settings +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct Server { + /// Ip address and port of the server + #[serde(default = "default_address")] + pub address: SocketAddr, +} + +/// Password settings +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct Password { + /// File containing password to match against + /// + /// Should be of the format `username:password` and contain a single password + #[serde(default, skip_serializing_if = "Option::is_none")] + pub password_file: Option, + + /// Salt to get more unique hashed passwords and prevent table based attacks + #[serde(default = "default_salt")] + pub salt: Box, +} + +/// Records settings +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct Records { + /// Time To Live (in seconds) to set on the DNS records + #[serde(default = "default_ttl")] + pub ttl: Duration, + + /// List of domain names for which to update the IP when an update is requested + #[serde(default, skip_serializing_if = "Vec::is_empty")] + #[allow(clippy::struct_field_names)] + pub records: Vec>, + + /// If provided, when an IPv6 prefix is provided with an update, this will be used to derive + /// the full IPv6 address of the client + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_id: Option, + + /// If a client id is provided the ipv6 update will be ignored (only the prefix will be used). + /// This domain will point to the ipv6 address instead of the address derived from the client + /// id (usually this is the router). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub router_domain: Option>, + + /// Set client IP source + /// + /// see: + #[serde(default = "default_ip_source")] + pub ip_source: SecureClientIpSource, + + /// Set which IPs to allow updating (ipv4, ipv6 or both) + #[serde(default = "default_ip_type")] + pub ip_type: IpType, + + /// Keyfile `nsupdate` should use + /// + /// If specified, then `webnsupdate` must have read access to the file + #[serde(default, skip_serializing_if = "Option::is_none")] + pub key_file: Option, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub struct Config { + /// Server Configuration + #[serde(flatten)] + pub server: Server, + + /// Password Configuration + #[serde(flatten)] + pub password: Password, + + /// Records Configuration + #[serde(flatten)] + pub records: Records, +} + +impl Config { + /// Load the configuration without verifying it + pub fn load(path: &std::path::Path) -> miette::Result { + serde_json::from_reader::( + File::open(path) + .into_diagnostic() + .wrap_err_with(|| format!("failed open {}", path.display()))?, + ) + .into_diagnostic() + .wrap_err_with(|| format!("failed to load configuration from {}", path.display())) + } + + /// Ensure only a verified configuration is returned + pub fn verified(self) -> miette::Result { + self.verify()?; + Ok(self) + } + + /// Verify the configuration + pub fn verify(&self) -> Result<(), Invalid> { + let mut invalid_records: Vec = self + .records + .records + .iter() + .filter_map(|record| crate::records::validate_record_str(record).err()) + .collect(); + + invalid_records.extend( + self.records + .router_domain + .as_ref() + .and_then(|domain| crate::records::validate_record_str(domain).err()), + ); + + let err = Invalid { invalid_records }; + + if err.invalid_records.is_empty() { + Ok(()) + } else { + Err(err) + } + } +} + +#[derive(Debug, miette::Diagnostic, thiserror::Error)] +#[error("the configuration was invalid")] +pub struct Invalid { + #[related] + pub invalid_records: Vec, +} + +// --- Default Values (sadly serde doesn't have a way to specify a constant as a default value) --- + +fn default_ttl() -> Duration { + super::DEFAULT_TTL +} + +fn default_salt() -> Box { + super::DEFAULT_SALT.into() +} + +fn default_address() -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 5353) +} + +fn default_ip_source() -> SecureClientIpSource { + SecureClientIpSource::RightmostXForwardedFor +} + +fn default_ip_type() -> IpType { + IpType::Both +} + +#[test] +fn default_values_config_snapshot() { + let config: Config = serde_json::from_str("{}").unwrap(); + insta::assert_json_snapshot!(config, @r#" + { + "address": "127.0.0.1:5353", + "salt": "UpdateMyDNS", + "ttl": { + "secs": 60, + "nanos": 0 + }, + "ip_source": "RightmostXForwardedFor", + "ip_type": "Both" + } + "#); +} diff --git a/src/main.rs b/src/main.rs index cf37707..9b24e99 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,16 +10,18 @@ use axum::{ routing::get, Router, }; -use axum_client_ip::{SecureClientIp, SecureClientIpSource}; +use axum_client_ip::SecureClientIp; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{Parser, Subcommand}; use clap_verbosity_flag::Verbosity; +use config::Config; use http::StatusCode; use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use tracing::{debug, error, info}; use tracing_subscriber::EnvFilter; mod auth; +mod config; mod nsupdate; mod password; mod records; @@ -32,120 +34,52 @@ struct Opts { #[command(flatten)] verbosity: Verbosity, - /// Ip address of the server - #[arg(long, default_value = "127.0.0.1")] - address: IpAddr, - - /// Port of the server - #[arg(long, default_value_t = 5353)] - port: u16, - - /// File containing password to match against - /// - /// Should be of the format `username:password` and contain a single password - #[arg(long)] - password_file: Option, - - /// Salt to get more unique hashed passwords and prevent table based attacks - #[arg(long, default_value = DEFAULT_SALT)] - salt: String, - - /// Time To Live (in seconds) to set on the DNS records - #[arg(long, default_value_t = DEFAULT_TTL.as_secs())] - ttl: u64, - /// Data directory - #[arg(long, default_value = ".")] + #[arg(long, env, default_value = ".")] data_dir: PathBuf, - /// File containing the records that should be updated when an update request is made - /// - /// There should be one record per line: - /// - /// ```text - /// example.com. - /// mail.example.com. - /// ``` - #[arg(long)] - records: PathBuf, - - /// Keyfile `nsupdate` should use - /// - /// If specified, then `webnsupdate` must have read access to the file - #[arg(long)] - key_file: Option, - /// Allow not setting a password #[arg(long)] insecure: bool, - /// Set client IP source - /// - /// see: - #[clap(long, default_value = "RightmostXForwardedFor")] - ip_source: SecureClientIpSource, + #[clap(flatten)] + config_or_command: ConfigOrCommand, +} - /// Set which IPs to allow updating - #[clap(long, default_value_t = IpType::Both)] - ip_type: IpType, +#[derive(clap::Args, Debug)] +#[group(multiple = false)] +struct ConfigOrCommand { + /// Path to the configuration file + #[arg(long, short)] + config: Option, #[clap(subcommand)] subcommand: Option, } -#[derive(Debug, Default, Clone, Copy)] -enum IpType { - #[default] - Both, - IPv4Only, - IPv6Only, -} - -impl IpType { - fn valid_for_type(self, ip: IpAddr) -> bool { - match self { - IpType::Both => true, - IpType::IPv4Only => ip.is_ipv4(), - IpType::IPv6Only => ip.is_ipv6(), - } - } -} - -impl std::fmt::Display for IpType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - IpType::Both => f.write_str("both"), - IpType::IPv4Only => f.write_str("ipv4-only"), - IpType::IPv6Only => f.write_str("ipv6-only"), - } - } -} - -impl std::str::FromStr for IpType { - type Err = miette::Error; - - fn from_str(s: &str) -> std::result::Result { - match s { - "both" => Ok(Self::Both), - "ipv4-only" => Ok(Self::IPv4Only), - "ipv6-only" => Ok(Self::IPv6Only), - _ => bail!("expected one of 'ipv4-only', 'ipv6-only' or 'both', got '{s}'"), - } +impl ConfigOrCommand { + pub fn take(&mut self) -> (Option, Option) { + (self.config.take(), self.subcommand.take()) } } #[derive(Debug, Subcommand)] enum Cmd { Mkpasswd(password::Mkpasswd), - /// Verify the records file - Verify, + /// Verify the configuration file + Verify { + /// Path to the configuration file + config: PathBuf, + }, } impl Cmd { pub fn process(self, args: &Opts) -> Result<()> { match self { Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args), - Cmd::Verify => records::load(&args.records).map(drop), + Cmd::Verify { config } => config::Config::load(&config) // load config + .and_then(Config::verified) // verify config + .map(drop), // ignore config data } } } @@ -168,7 +102,7 @@ struct AppState<'a> { last_ips: std::sync::Arc>, /// The IP type for which to allow updates - ip_type: IpType, + ip_type: config::IpType, } #[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] @@ -211,33 +145,38 @@ impl SavedIPs { } impl AppState<'static> { - fn from_args(args: &Opts) -> miette::Result { + fn from_args(args: &Opts, config: &config::Config) -> miette::Result { let Opts { verbosity: _, - address: _, - port: _, - password_file: _, data_dir, - key_file, insecure, - subcommand: _, - records, - salt: _, - ttl, - ip_source: _, - ip_type, + config_or_command: _, } = args; - // Set state - let ttl = Duration::from_secs(*ttl); + let config::Records { + ttl, + records, + client_id: _, + router_domain: _, + ip_source: _, + ip_type, + key_file, + } = &config.records; // Use last registered IP address if available let ip_file = Box::leak(data_dir.join("last-ip.json").into_boxed_path()); + // Leak DNS records + let records: &[&str] = &*Vec::leak( + records + .iter() + .map(|record| &*Box::leak(record.clone())) + .collect(), + ); + let state = AppState { - ttl, - // Load DNS records - records: records::load_no_verify(records)?, + ttl: *ttl, + records, // Load keyfile key_file: key_file .as_deref() @@ -340,34 +279,37 @@ fn main() -> Result<()> { debug!("{args:?}"); - // process subcommand - if let Some(cmd) = args.subcommand.take() { - return cmd.process(&args); - } + let config = match args.config_or_command.take() { + // process subcommand + (None, Some(cmd)) => return cmd.process(&args), + (Some(path), None) => { + let config = config::Config::load(&path)?; + if let Err(err) = config.verify() { + error!("failed to verify configuration: {err}"); + } + config + } + (None, None) | (Some(_), Some(_)) => unreachable!( + "bad state, one of config or subcommand should be available (clap should enforce this)" + ), + }; // Initialize state - let state = AppState::from_args(&args)?; + let state = AppState::from_args(&args, &config)?; let Opts { verbosity: _, - address: ip, - port, - password_file, data_dir: _, - key_file: _, insecure, - subcommand: _, - records: _, - salt, - ttl: _, - ip_source, - ip_type, + config_or_command: _, } = args; info!("checking environment"); // Load password hash - let password_hash = password_file + let password_hash = config + .password + .password_file .map(|path| -> miette::Result<_> { let path = path.as_path(); let pass = std::fs::read_to_string(path).into_diagnostic()?; @@ -398,23 +340,26 @@ fn main() -> Result<()> { // Update DNS record with previous IPs (if available) let ips = state.last_ips.lock().await.clone(); - let actions = ips + let mut actions = ips .ips() - .filter(|ip| ip_type.valid_for_type(*ip)) - .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)); + .filter(|ip| config.records.ip_type.valid_for_type(*ip)) + .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)) + .peekable(); - match nsupdate::nsupdate(state.key_file, actions).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); + if actions.peek().is_some() { + match nsupdate::nsupdate(state.key_file, actions).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); + } + } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); } - } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); } } @@ -422,19 +367,24 @@ fn main() -> Result<()> { let app = Router::new().route("/update", get(update_records)); // if a password is provided, validate it let app = if let Some(pass) = password_hash { - app.layer(auth::layer(Box::leak(pass), String::leak(salt))) + app.layer(auth::layer( + Box::leak(pass), + Box::leak(config.password.salt), + )) } else { app } - .layer(ip_source.into_extension()) + .layer(config.records.ip_source.into_extension()) .with_state(state); + let config::Server { address } = config.server; + // Start services - info!("starting listener on {ip}:{port}"); - let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) + info!("starting listener on {address}"); + let listener = tokio::net::TcpListener::bind(address) .await .into_diagnostic()?; - info!("listening on {ip}:{port}"); + info!("listening on {address}"); axum::serve( listener, app.into_make_service_with_connect_info::(), @@ -573,6 +523,15 @@ async fn trigger_update( state: &AppState<'static>, ) -> axum::response::Result<&'static str> { let actions = nsupdate::Action::from_records(ip, state.ttl, state.records); + + if actions.len() == 0 { + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "Nothing to do (e.g. we are ipv4-only but an ipv6 update was requested)", + ) + .into()); + } + match nsupdate::nsupdate(state.key_file, actions).await { Ok(status) if status.success() => { let ips = { diff --git a/src/nsupdate.rs b/src/nsupdate.rs index 74397fa..debd6ff 100644 --- a/src/nsupdate.rs +++ b/src/nsupdate.rs @@ -25,7 +25,7 @@ impl<'a> Action<'a> { to: IpAddr, ttl: Duration, records: &'a [&'a str], - ) -> impl IntoIterator + 'a { + ) -> impl IntoIterator + std::iter::ExactSizeIterator + 'a { records .iter() .map(move |&domain| Action::Reassign { domain, to, ttl }) @@ -91,7 +91,7 @@ fn update_ns_records<'a>( ) -> std::io::Result<()> { writeln!(buf, "server 127.0.0.1")?; for action in actions { - writeln!(buf, "{action}")?; + write!(buf, "{action}")?; } writeln!(buf, "send")?; writeln!(buf, "quit") diff --git a/src/password.rs b/src/password.rs index 8d965ba..d99a93b 100644 --- a/src/password.rs +++ b/src/password.rs @@ -4,7 +4,7 @@ //! records use std::io::Write; use std::os::unix::fs::OpenOptionsExt; -use std::path::Path; +use std::path::PathBuf; use base64::prelude::*; use miette::{Context, IntoDiagnostic, Result}; @@ -20,11 +20,18 @@ pub struct Mkpasswd { /// The password password: String, + + /// An application specific value + #[arg(long, default_value = crate::DEFAULT_SALT)] + salt: String, + + /// The file to write the password to + password_file: Option, } impl Mkpasswd { - pub fn process(self, args: &crate::Opts) -> Result<()> { - mkpasswd(self, args.password_file.as_deref(), &args.salt) + pub fn process(self, _args: &crate::Opts) -> Result<()> { + mkpasswd(self) } } @@ -45,13 +52,16 @@ pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest { } pub fn mkpasswd( - Mkpasswd { username, password }: Mkpasswd, - password_file: Option<&Path>, - salt: &str, + Mkpasswd { + username, + password, + salt, + password_file, + }: Mkpasswd, ) -> miette::Result<()> { - let hash = hash_identity(&username, &password, salt); + let hash = hash_identity(&username, &password, &salt); let encoded = BASE64_URL_SAFE_NO_PAD.encode(hash.as_ref()); - let Some(path) = password_file else { + let Some(path) = password_file.as_deref() else { println!("{encoded}"); return Ok(()); }; diff --git a/src/records.rs b/src/records.rs index 860f719..9c5158c 100644 --- a/src/records.rs +++ b/src/records.rs @@ -1,52 +1,9 @@ //! Deal with the DNS records -use std::path::Path; +use miette::{ensure, miette, LabeledSpan, Result}; -use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result}; - -/// Loads and verifies the records from a file -pub fn load(path: &Path) -> Result<()> { - let records = std::fs::read_to_string(path) - .into_diagnostic() - .wrap_err_with(|| format!("failed to read records from {}", path.display()))?; - - verify(&records, path)?; - - Ok(()) -} - -/// Load records without verifying them -pub fn load_no_verify(path: &Path) -> Result<&'static [&'static str]> { - let records = std::fs::read_to_string(path) - .into_diagnostic() - .wrap_err_with(|| format!("failed to read records from {}", path.display()))?; - - if let Err(err) = verify(&records, path) { - tracing::error!("Failed to verify records: {err}"); - } - - // leak memory: we only do this here and it prevents a bunch of allocations - let records: &str = records.leak(); - let records: Box<[&str]> = records.lines().collect(); - - Ok(Box::leak(records)) -} - -/// Verifies that a list of records is valid -pub fn verify(data: &str, path: &Path) -> Result<()> { - let mut offset = 0usize; - for line in data.lines() { - validate_line(offset, line).map_err(|err| { - err.with_source_code(NamedSource::new( - path.display().to_string(), - data.to_string(), - )) - })?; - - offset += line.len() + 1; - } - - Ok(()) +pub fn validate_record_str(record: &str) -> Result<()> { + validate_line(0, record).map_err(|err| err.with_source_code(String::from(record))) } fn validate_line(offset: usize, line: &str) -> Result<()> { @@ -156,7 +113,7 @@ fn validate_octet(offset: usize, octet: u8) -> Result<()> { #[cfg(test)] mod test { - use crate::records::verify; + use crate::records::validate_record_str; macro_rules! assert_miette_snapshot { ($diag:expr) => {{ @@ -180,104 +137,51 @@ mod test { #[test] fn valid_records() -> miette::Result<()> { - verify( - "\ - example.com.\n\ - example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_valid"), - ) + for record in [ + "example.com.", + "example.org.", + "example.net.", + "subdomain.example.com.", + ] { + validate_record_str(record)?; + } + Ok(()) } #[test] fn hostname_too_long() { - let err = verify( - "\ - example.com.\n\ - example.org.\n\ - example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = validate_record_str("example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net.").unwrap_err(); assert_miette_snapshot!(err); } #[test] fn not_fqd() { - let err = verify( - "\ - example.com.\n\ - example.org.\n\ - example.net\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = validate_record_str("example.net").unwrap_err(); assert_miette_snapshot!(err); } #[test] fn empty_label() { - let err = verify( - "\ - example.com.\n\ - name..example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = validate_record_str("name..example.org.").unwrap_err(); assert_miette_snapshot!(err); } #[test] fn label_too_long() { - let err = verify( - "\ - example.com.\n\ - name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = validate_record_str("name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org.").unwrap_err(); assert_miette_snapshot!(err); } #[test] fn invalid_ascii() { - let err = verify( - "\ - example.com.\n\ - name.this-is-not-ascii-ß.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = validate_record_str("name.this-is-not-ascii-ß.example.org.").unwrap_err(); assert_miette_snapshot!(err); } #[test] fn invalid_octet() { - let err = verify( - "\ - example.com.\n\ - name.this-character:-is-not-allowed.example.org.\n\ - example.net.\n\ - subdomain.example.com.\n\ - ", - std::path::Path::new("test_records_invalid"), - ) - .unwrap_err(); + let err = + validate_record_str("name.this-character:-is-not-allowed.example.org.").unwrap_err(); assert_miette_snapshot!(err); } } diff --git a/src/snapshots/webnsupdate__records__test__empty_label.snap b/src/snapshots/webnsupdate__records__test__empty_label.snap index e4d227b..d6fb7fa 100644 --- a/src/snapshots/webnsupdate__records__test__empty_label.snap +++ b/src/snapshots/webnsupdate__records__test__empty_label.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ × empty label - ╭─[test_records_invalid:2:6] - 1 │ example.com. - 2 │ name..example.org. + ╭──── + 1 │ name..example.org. · ▲ · ╰── label - 3 │ example.net. ╰──── help: each label should have at least one character diff --git a/src/snapshots/webnsupdate__records__test__hostname_too_long.snap b/src/snapshots/webnsupdate__records__test__hostname_too_long.snap index 051d8ce..5c48b16 100644 --- a/src/snapshots/webnsupdate__records__test__hostname_too_long.snap +++ b/src/snapshots/webnsupdate__records__test__hostname_too_long.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ × hostname too long (260 octets) - ╭─[test_records_invalid:3:1] - 2 │ example.org. - 3 │ example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net. + ╭──── + 1 │ example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.example.net. · ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┬───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── · ╰── this line - 4 │ subdomain.example.com. ╰──── help: fully qualified domain names can be at most 255 characters long diff --git a/src/snapshots/webnsupdate__records__test__invalid_ascii.snap b/src/snapshots/webnsupdate__records__test__invalid_ascii.snap index eb8102b..6ef64e3 100644 --- a/src/snapshots/webnsupdate__records__test__invalid_ascii.snap +++ b/src/snapshots/webnsupdate__records__test__invalid_ascii.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\ × invalid octet: '\xc3' - ╭─[test_records_invalid:2:24] - 1 │ example.com. - 2 │ name.this-is-not-ascii-ß.example.org. + ╭──── + 1 │ name.this-is-not-ascii-ß.example.org. · ┬ · ╰── octet - 3 │ example.net. ╰──── help: we only accept ascii characters diff --git a/src/snapshots/webnsupdate__records__test__invalid_octet.snap b/src/snapshots/webnsupdate__records__test__invalid_octet.snap index 2da284e..ed8a44c 100644 --- a/src/snapshots/webnsupdate__records__test__invalid_octet.snap +++ b/src/snapshots/webnsupdate__records__test__invalid_octet.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Hostname#Syntax\(link)]8;;\ × invalid octet: ':' - ╭─[test_records_invalid:2:20] - 1 │ example.com. - 2 │ name.this-character:-is-not-allowed.example.org. + ╭──── + 1 │ name.this-character:-is-not-allowed.example.org. · ┬ · ╰── octet - 3 │ example.net. ╰──── help: hostnames are only allowed to contain characters in [a-zA-Z0-9_-] diff --git a/src/snapshots/webnsupdate__records__test__label_too_long.snap b/src/snapshots/webnsupdate__records__test__label_too_long.snap index b529a1f..f1561ae 100644 --- a/src/snapshots/webnsupdate__records__test__label_too_long.snap +++ b/src/snapshots/webnsupdate__records__test__label_too_long.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ × label too long (78 octets) - ╭─[test_records_invalid:2:6] - 1 │ example.com. - 2 │ name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org. + ╭──── + 1 │ name.an-entremely-long-label-that-should-not-exist-because-it-goes-against-the-spec.example.org. · ───────────────────────────────────────┬────────────────────────────────────── · ╰── label - 3 │ example.net. ╰──── help: labels should be at most 63 octets diff --git a/src/snapshots/webnsupdate__records__test__not_fqd.snap b/src/snapshots/webnsupdate__records__test__not_fqd.snap index bc8270d..ccf6746 100644 --- a/src/snapshots/webnsupdate__records__test__not_fqd.snap +++ b/src/snapshots/webnsupdate__records__test__not_fqd.snap @@ -6,11 +6,9 @@ expression: out ]8;;https://en.wikipedia.org/wiki/Fully_qualified_domain_name\(link)]8;;\ × not a fully qualified domain name - ╭─[test_records_invalid:3:11] - 2 │ example.org. - 3 │ example.net + ╭──── + 1 │ example.net · ┬ · ╰── last character - 4 │ subdomain.example.com. ╰──── help: hostname should be a fully qualified domain name (end with a '.')