From 022c48a041cf1ea4dbd0dc0f4c41a5611bd2b940 Mon Sep 17 00:00:00 2001 From: psun256 Date: Wed, 10 Dec 2025 18:52:40 -0500 Subject: [PATCH] applied hot reload to health check logic --- Cargo.lock | 67 +++++++++-- config.yaml | 5 + src/backend/health.rs | 117 +++++++++++++++++++ src/backend/mod.rs | 2 +- src/balancer/adaptive_weight.rs | 129 ++++++++++++++++++++- src/balancer/ip_hashing.rs | 7 +- src/balancer/round_robin.rs | 2 +- src/config/loader.rs | 27 +++-- src/config/mod.rs | 16 +++ src/main.rs | 193 +++++++++++++------------------- 10 files changed, 421 insertions(+), 144 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5886925..4811635 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -400,6 +400,26 @@ dependencies = [ "cc", ] +[[package]] +name = "inotify" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f37dccff2791ab604f9babef0ba14fbe0be30bd368dc541e2b08d07c8aa908f3" +dependencies = [ + "bitflags 2.10.0", + "inotify-sys", + "libc", +] + +[[package]] +name = "inotify-sys" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" +dependencies = [ + "libc", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -446,6 +466,26 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "kqueue" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac30106d7dce88daf4a3fcb4879ea939476d5074a9b7ddd0fb97fa4bed5596a" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" +dependencies = [ + "bitflags 1.3.2", + "libc", +] + [[package]] name = "l4lb" version = "0.1.0" @@ -453,6 +493,8 @@ dependencies = [ "anywho", "arc-swap", "cidr", + "clap", + "notify", "rand 0.10.0-rc.5", "rperf3-rs", "serde", @@ -463,9 +505,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "lock_api" @@ -496,9 +538,9 @@ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "log", @@ -847,9 +889,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.6" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2a4719bff48cee6b39d12c020eeb490953ad2443b7055bd0b21fca26bd8c28b" +checksum = "7664a098b8e616bdfcc2dc0e9ac44eb231eedf41db4e9fe95d8d32ec728dedad" dependencies = [ "libc", ] @@ -884,9 +926,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "syn" -version = "2.0.110" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a99801b5bd34ede4cf3fc688c5919368fea4e4814a4664359503e6015b280aea" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -1045,6 +1087,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" diff --git a/config.yaml b/config.yaml index e2e1ce1..47fe30d 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,7 @@ +healthcheck_addr: "127.0.0.1:9000" + +iperf_addr: "0.0.0.0:5200" + backends: - id: "srv-1" ip: "127.0.0.1" @@ -24,6 +28,7 @@ rules: - clients: - "0.0.0.0/0:6767" + - "0.0.0.0/0:6969" targets: # no issues with duplicate servers or clusters - "priority-api" - "priority-api" diff --git a/src/backend/health.rs b/src/backend/health.rs index 0591438..3801020 100644 --- a/src/backend/health.rs +++ b/src/backend/health.rs @@ -1,3 +1,11 @@ +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr}; +use std::sync::{Arc, RwLock}; +use serde_json::Value; +use rperf3::{Config, Server}; +use tokio::net::{TcpListener, TcpSocket, TcpStream}; +use tokio::io::AsyncBufReadExt; + // Physical server health statistics, used for certain load balancing algorithms #[derive(Debug, Default)] pub struct ServerMetrics { @@ -15,3 +23,112 @@ impl ServerMetrics { self.io = io; } } + +pub async fn start_healthcheck_listener( + addr: &str, + healths: HashMap>>, +) -> std::io::Result<()> { + let addrs = tokio::net::lookup_host(addr).await?; + let mut listener = None; + + for a in addrs { + let socket = match a { + SocketAddr::V4(_) => TcpSocket::new_v4()?, + SocketAddr::V6(_) => TcpSocket::new_v6()?, + }; + + socket.set_reuseaddr(true)?; + + if socket.bind(a).is_ok() { + listener = Some(socket.listen(1024)?); + break; + } + } + + let listener = listener.ok_or_else(|| { + eprintln!("health listener could not bind to port"); + std::io::Error::new(std::io::ErrorKind::Other, "health listener failed") + })?; + + println!("healthcheck server listening on {}", addr); + loop { + let (stream, remote_addr) = match listener.accept().await { + Ok(v) => v, + Err(e) => { + continue; + } + }; + + if let Err(e) = handle_metrics_stream(stream, &healths).await { + eprintln!("connection handler error: {}", e); + } + } +} + +pub async fn start_iperf_server(addr: &str) -> Result<(), Box> { + let sock = addr.parse::()?; + let mut config = Config::server(sock.port()); + config.bind_addr = Some(sock.ip()); + let server = Server::new(config); + println!("iperf server listening on {}", addr); + server.run().await?; + Ok(()) +} + +async fn handle_metrics_stream( + stream: TcpStream, + healths: &HashMap>>, +) -> std::io::Result<()> { + let server_ip = stream.peer_addr()?.ip(); + let mut reader = tokio::io::BufReader::new(stream); + let mut line = String::new(); + + loop { + line.clear(); + + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + if let Err(e) = process_metrics(server_ip, &line, healths) { + eprintln!("skipping invalid packet: {}", e); + } + } + Err(e) => { + eprintln!("connection error: {}", e); + break; + } + } + } + Ok(()) +} + +fn process_metrics( + server_ip: IpAddr, + json_str: &str, + healths: &HashMap>>, +) -> Result<(), String> { + let parsed: Value = + serde_json::from_str(json_str).map_err(|e| format!("parse error: {}", e))?; + + let metrics_lock = healths + .get(&server_ip) + .ok_or_else(|| format!("unknown server: {}", server_ip))?; + + let get_f64 = |key: &str| -> Result { + parsed + .get(key) + .and_then(|v| v.as_f64()) + .ok_or_else(|| format!("invalid '{}'", key)) + }; + + if let Ok(mut guard) = metrics_lock.write() { + guard.update( + get_f64("cpu")?, + get_f64("mem")?, + get_f64("net")?, + get_f64("io")?, + ); + } + + Ok(()) +} \ No newline at end of file diff --git a/src/backend/mod.rs b/src/backend/mod.rs index d84aa76..445a292 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,11 +1,11 @@ pub mod health; +use crate::backend::health::ServerMetrics; use core::fmt; use std::net::SocketAddr; use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::{AtomicUsize, Ordering}; -use crate::backend::health::ServerMetrics; // A possible endpoint for a proxied connection. // Note that multiple may live on the same server, hence the Arc> diff --git a/src/balancer/adaptive_weight.rs b/src/balancer/adaptive_weight.rs index c23f904..dcef3dc 100644 --- a/src/balancer/adaptive_weight.rs +++ b/src/balancer/adaptive_weight.rs @@ -4,7 +4,8 @@ use crate::balancer::{Balancer, ConnectionInfo}; use rand::prelude::*; use rand::rngs::SmallRng; use std::fmt::Debug; -use std::sync::{Arc}; +use std::fs::Metadata; +use std::sync::{Arc, RwLock}; #[derive(Debug)] struct AdaptiveNode { @@ -27,7 +28,7 @@ impl AdaptiveWeightBalancer { .iter() .map(|b| AdaptiveNode { backend: b.clone(), - weight: 0f64, + weight: 1f64, }) .collect(); @@ -85,7 +86,6 @@ impl Balancer for AdaptiveWeightBalancer { }; let ratio = risk / node.weight; - if ratio <= threshold { return Some(node.backend.clone()); } @@ -142,3 +142,126 @@ impl Balancer for AdaptiveWeightBalancer { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::Backend; + use std::net::SocketAddr; + + fn backend_factory(id: &str, ip: &str, port: u16) -> Arc { + Arc::new(Backend::new( + id.to_string(), + SocketAddr::new(ip.parse().unwrap(), port), + Arc::new(RwLock::new(ServerMetrics::default())), + )) + } + + fn unused_ctx() -> ConnectionInfo { + ConnectionInfo { + client_ip: ("0.0.0.0".parse().unwrap()), + } + } + + #[test] + fn basic_weight_update_and_choose() { + let backends = BackendPool::new(vec![ + backend_factory("server-0", "127.0.0.1", 3000), + backend_factory("server-1", "127.0.0.1", 3001), + ]); + let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.5, 0.2, 0.2, 0.1], 0.5); + // initially equal weights + // update one backend to be heavily loaded + { + let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap(); + sm0_guard.update(90.0, 80.0, 10.0, 5.0); + } + { + let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap(); + sm1_guard.update(10.0, 5.0, 1.0, 1.0); + } + + // Choose backend: should pick the less loaded host server1 + let chosen = b + .choose_backend(unused_ctx()) + .expect("should choose a backend"); + + let sm0: &ServerMetrics = &backends.backends.get(0).unwrap().metrics.read().unwrap(); + let sm1: &ServerMetrics = &backends.backends.get(1).unwrap().metrics.read().unwrap(); + println!("{:?}, {:?}", sm0, sm1); + assert_eq!(chosen.id, "server-1"); + } + + #[test] + fn choose_none_when_empty() { + let mut b = + AdaptiveWeightBalancer::new(BackendPool::new(vec![]), [0.5, 0.2, 0.2, 0.1], 0.5); + assert!(b.choose_backend(unused_ctx()).is_none()); + } + + #[test] + fn ratio_triggers_immediate_selection() { + // Arrange two servers where server 1 has composite load 0 and server 2 has composite load 100. + // With alpha = 1.0 and two servers, threshold = 1.0 * (r_sum / w_sum) = 1.0 * (100 / 2) = 50. + // Server 0 ratio = 0 / 1 = 0 <= 50 so it should be chosen immediately. + let backends = BackendPool::new(vec![ + backend_factory("server-0", "127.0.0.1", 3000), + backend_factory("server-1", "127.0.0.1", 3001), + ]); + let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.25, 0.25, 0.25, 0.25], 1.0); + + { + let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap(); + sm0_guard.update(0.0, 0.0, 0.0, 0.0); + } + { + let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap(); + sm1_guard.update(100.0, 100.0, 100.0, 100.0); + } + + let chosen = b + .choose_backend(unused_ctx()) + .expect("should choose a backend"); + assert_eq!(chosen.id, "server-0"); + } + + #[test] + fn choose_min_current_load_when_no_ratio() { + // Arrange three servers with identical composite loads so no server satisfies Ri/Wi <= threshold + // (set alpha < 1 so threshold < ratio). The implementation then falls back to picking the + // server with minimum current_load + let backends = BackendPool::new(vec![ + backend_factory("server-0", "127.0.0.1", 3000), + backend_factory("server-1", "127.0.0.1", 3001), + backend_factory("server-2", "127.0.0.1", 3002), + ]); + + // set current_loads (field expected to be public) + + { + let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap(); + sm0_guard.update(10.0, 10.0, 10.0, 10.0); + } + { + let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap(); + sm1_guard.update(5.0, 5.0, 5.0, 5.0); + } + { + let mut sm2_guard = backends.backends.get(2).unwrap().metrics.write().unwrap(); + sm2_guard.update(20.0, 20.0, 20.0, 20.0); + } + + // Use coeffs that only consider CPU so composite load is easy to reason about. + let mut bal = AdaptiveWeightBalancer::new(backends.clone(), [1.0, 0.0, 0.0, 0.0], 0.5); + + // set identical composite loads > 0 for all so ratio = x and threshold = alpha * x < x + // you will have threshold = 25 for all 3 backend servers and ratio = 50 + // so that forces to choose the smallest current load backend + + let chosen = bal + .choose_backend(unused_ctx()) + .expect("should choose a backend"); + // expect server with smallest current_load server-1 + assert_eq!(chosen.id, "server-1"); + } +} \ No newline at end of file diff --git a/src/balancer/ip_hashing.rs b/src/balancer/ip_hashing.rs index 2cc5318..f48a776 100644 --- a/src/balancer/ip_hashing.rs +++ b/src/balancer/ip_hashing.rs @@ -1,7 +1,7 @@ use crate::backend::{Backend, BackendPool}; use crate::balancer::{Balancer, ConnectionInfo}; use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::{Arc}; +use std::sync::Arc; #[derive(Debug)] pub struct SourceIPHash { @@ -30,8 +30,9 @@ impl Balancer for SourceIPHash { #[cfg(test)] mod tests { use super::*; + use crate::backend::health::ServerMetrics; use std::net::IpAddr; - use crate::backend::ServerMetrics; + use std::sync::RwLock; fn create_dummy_backends(count: usize) -> BackendPool { let mut backends = Vec::new(); @@ -105,4 +106,4 @@ mod tests { assert!(distribution[1] > 0, "Backend 1 received no traffic"); assert!(distribution[2] > 0, "Backend 2 received no traffic"); } -} \ No newline at end of file +} diff --git a/src/balancer/round_robin.rs b/src/balancer/round_robin.rs index e0d60d1..83680ea 100644 --- a/src/balancer/round_robin.rs +++ b/src/balancer/round_robin.rs @@ -1,7 +1,7 @@ use crate::backend::{Backend, BackendPool}; use crate::balancer::{Balancer, ConnectionInfo}; use std::fmt::Debug; -use std::sync::{Arc}; +use std::sync::Arc; // only the main thread for receiving connections should be // doing the load balancing. alternatively, each thread diff --git a/src/config/loader.rs b/src/config/loader.rs index 1a6f458..f597d00 100644 --- a/src/config/loader.rs +++ b/src/config/loader.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use std::net::{IpAddr, SocketAddr}; use std::sync::{Arc, RwLock}; -use crate::backend::*; use crate::backend::health::*; +use crate::backend::*; use crate::balancer::Balancer; use crate::balancer::adaptive_weight::AdaptiveWeightBalancer; use crate::balancer::round_robin::RoundRobinBalancer; @@ -19,23 +19,26 @@ pub type PortListeners = HashMap; fn parse_client(s: &str) -> Result<(IpCidr, u16), String> { // just splits "0.0.0.0/0:80" into ("0.0.0.0/0", 80) - let (ip_part, port_part) = s.rsplit_once(':') + let (ip_part, port_part) = s + .rsplit_once(':') .ok_or_else(|| format!("badly formatted client: {}", s))?; - let port = port_part.parse() - .map_err(|_| format!("bad port: {}", s))?; - let cidr = ip_part.parse() - .map_err(|_| format!("bad ip/mask: {}", s))?; + let port = port_part.parse().map_err(|_| format!("bad port: {}", s))?; + let cidr = ip_part.parse().map_err(|_| format!("bad ip/mask: {}", s))?; Ok((cidr, port)) } -pub fn build_lb(config: AppConfig) -> Result<(PortListeners, HashMap>>), String> { +pub fn build_lb( + config: &AppConfig, +) -> Result<(PortListeners, HashMap>>), String> { let mut healths: HashMap>> = HashMap::new(); let mut backends: HashMap> = HashMap::new(); - for backend_cfg in config.backends { - let ip: IpAddr = backend_cfg.ip.parse() + for backend_cfg in &config.backends { + let ip: IpAddr = backend_cfg + .ip + .parse() .map_err(|_| format!("bad ip: {}", backend_cfg.ip))?; let addr = SocketAddr::new(ip, backend_cfg.port); @@ -46,12 +49,12 @@ pub fn build_lb(config: AppConfig) -> Result<(PortListeners, HashMap Result<(PortListeners, HashMap> = HashMap::new(); - for client_def in rule.clients { + for client_def in &rule.clients { let (cidr, port) = parse_client(&client_def)?; port_groups.entry(port).or_default().push(cidr); } diff --git a/src/config/mod.rs b/src/config/mod.rs index e65f437..4b63927 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -16,11 +16,27 @@ pub mod loader; use serde::Deserialize; use std::collections::HashMap; +fn default_healthcheck_addr() -> String { + "0.0.0.0:8080".to_string() +} + +fn default_iperf_addr() -> String { + "0.0.0.0:5201".to_string() +} + #[derive(Debug, Deserialize)] pub struct AppConfig { + #[serde(default = "default_healthcheck_addr")] + pub healthcheck_addr: String, + + #[serde(default = "default_iperf_addr")] + pub iperf_addr: String, + pub backends: Vec, + #[serde(default)] pub clusters: HashMap>, + pub rules: Vec, } diff --git a/src/main.rs b/src/main.rs index 7c90f86..9f6bc8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,115 +3,35 @@ mod balancer; mod config; mod proxy; -use std::collections::HashMap; -use crate::balancer::{ConnectionInfo}; +use crate::backend::health::{start_healthcheck_listener, start_iperf_server, ServerMetrics}; +use crate::balancer::ConnectionInfo; +use crate::config::loader::{build_lb, RoutingTable}; use crate::proxy::tcp::proxy_tcp_connection; -use std::fs::File; -use std::path::PathBuf; -use std::net::IpAddr; -use std::sync::atomic::{AtomicU64, Ordering}; -use tokio::net::{TcpListener, TcpStream}; -use tokio::io::{AsyncBufReadExt, AsyncReadExt}; -use serde_json::Value; +use anywho::Error; use std::collections::HashMap; -use std::net::{IpAddr}; -use std::sync::{Arc, RwLock}; -use crate::backend::health::ServerMetrics; -use rperf3::{Server, Config}; -use std::io::Read; -use std::io::{BufRead, BufReader}; - -static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1); - -async fn start_iperf_server() -> Result<(), Box> { - let config = Config::server(5001); - let server = Server::new(config); - server.run().await?; - Ok(()) -} - -async fn handle_metrics_stream(stream: TcpStream, healths: &HashMap>>) -> std::io::Result<()> { - let server_ip = stream.peer_addr()?.ip(); - let mut reader = tokio::io::BufReader::new(stream); - let mut line = String::new(); - - loop { - line.clear(); - - match reader.read_line(&mut line).await { - Ok(0) => break, - Ok(_) => { - if let Err(e) = process_metrics(server_ip, &line, healths) { - eprintln!("skipping invalid packet: {}", e); - } - } - Err(e) => { - eprintln!("connection error: {}", e); - break; - } - } - } - Ok(()) -} - -fn process_metrics(server_ip: IpAddr, json_str: &str, healths: &HashMap>>) -> Result<(), String> { - let parsed: Value = serde_json::from_str(json_str) - .map_err(|e| format!("parse error: {}", e))?; - - let metrics_lock = healths.get(&server_ip) - .ok_or_else(|| format!("unknown server: {}", server_ip))?; - - let get_f64 = |key: &str| -> Result { - parsed.get(key) - .and_then(|v| v.as_f64()) - .ok_or_else(|| format!("invalid '{}'", key)) - }; - - if let Ok(mut guard) = metrics_lock.write() { - guard.update( - get_f64("cpu")?, - get_f64("mem")?, - get_f64("net")?, - get_f64("io")?, - ); - } - - Ok(()) -} - -async fn start_healthcheck_listener(addr: &str, healths: HashMap>>) -> std::io::Result<()> { - let listener = TcpListener::bind(addr).await?; - println!("TCP server listening on {}", addr); - loop { - let (stream, remote_addr) = match listener.accept().await { - Ok(v) => v, - Err(e) => { - continue; - } - }; - - if let Err(e) = handle_metrics_stream(stream, &healths).await { - eprintln!("connection handler error: {}", e); - } - } - - Ok(()) +use std::fs::File; +use std::hash::Hash; +use std::net::{IpAddr, SocketAddr}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; -use anywho::Error; +use tokio::io::AsyncBufReadExt; use tokio::net::TcpListener; use tokio::sync::mpsc; -use crate::backend::ServerMetrics; -use crate::config::loader::{build_lb, RoutingTable}; - -use notify::{Watcher, RecursiveMode, Event}; use clap::Parser; +use notify::{Event, RecursiveMode, Watcher}; +use std::cmp; static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1); struct ProgramState { tx_rt_map: HashMap>, healths: HashMap>>, + health_listener: Option>, + iperf_server: Option>, + health_listener_addr: Option, + iperf_server_addr: Option, } #[derive(Parser, Debug)] @@ -135,6 +55,10 @@ async fn main() -> Result<(), Box> { let state = Arc::new(Mutex::new(ProgramState { tx_rt_map: HashMap::new(), healths: HashMap::new(), + health_listener: None, + iperf_server: None, + health_listener_addr: None, + iperf_server_addr: None, })); if let Err(e) = load_config(&args.config, state.clone()).await { @@ -143,48 +67,52 @@ async fn main() -> Result<(), Box> { let config_path = args.config.clone(); let state_clone = state.clone(); - - handles.push( - tokio::spawn(async { - start_healthcheck_listener("0.0.0.0:8080", healths).await.unwrap(); - }) - ); - - handles.push( - tokio::spawn(async { - start_iperf_server().await; - }) - ); tokio::spawn(async move { let (tx, mut rx) = mpsc::channel(1); - + let mut watcher = notify::recommended_watcher(move |res: Result| { if let Ok(event) = res { if event.kind.is_modify() { let _ = tx.blocking_send(()); } } - }).unwrap(); + }) + .unwrap(); - watcher.watch(&config_path, RecursiveMode::NonRecursive).unwrap(); + watcher + .watch(&config_path, RecursiveMode::NonRecursive) + .unwrap(); println!("watching for changes to {:?}", config_path); while rx.recv().await.is_some() { + // for some reason, saving on certain text editors fires several events, + // and this causes us to reload a lot. try to flush some events, add a tiny delay + // to mitigate this + + while rx.try_recv().is_ok() {} + tokio::time::sleep(Duration::from_millis(50)).await; + while rx.try_recv().is_ok() {} + if let Err(e) = load_config(&config_path, state_clone.clone()).await { eprintln!("loading config failed: {}", e); } } }); - loop { tokio::time::sleep(Duration::from_hours(1)).await; } + loop { + tokio::time::sleep(Duration::from_hours(1)).await; + } } async fn load_config(path: &PathBuf, state: Arc>) -> Result<(), Error> { let f = File::open(path)?; let app_config: config::AppConfig = match serde_saphyr::from_reader(f) { Ok(app_config) => app_config, - Err(e) => { eprintln!("error parsing config {}", e); return Ok(()); } + Err(e) => { + eprintln!("error parsing config {}", e); + return Ok(()); + } }; println!( @@ -193,7 +121,7 @@ async fn load_config(path: &PathBuf, state: Arc>) -> Result< app_config.rules.len() ); - let (mut listeners, health_monitors) = match build_lb(app_config) { + let (mut listeners, health_monitors) = match build_lb(&app_config) { Ok(v) => v, Err(e) => { eprintln!("config has logical errors: {}", e); @@ -202,7 +130,8 @@ async fn load_config(path: &PathBuf, state: Arc>) -> Result< }; let mut prog_state = state.lock().unwrap(); - let ports_to_remove: Vec = prog_state.tx_rt_map + let ports_to_remove: Vec = prog_state + .tx_rt_map .keys() .cloned() .filter(|port| !listeners.contains_key(port)) @@ -212,6 +141,38 @@ async fn load_config(path: &PathBuf, state: Arc>) -> Result< prog_state.tx_rt_map.remove(&port); } + if let Some(handle) = prog_state.health_listener.take() { + handle.abort(); + } + let health_map: HashMap>> = health_monitors.clone(); + let health_addr = app_config.healthcheck_addr.clone(); + let health_addr_c = health_addr.clone(); + let health_handle = tokio::spawn(async move { + if let Err(e) = start_healthcheck_listener(&health_addr, health_map).await { + eprintln!("health check listener failed: {}", e); + } + }); + prog_state.health_listener = Some(health_handle); + prog_state.health_listener_addr = Some(health_addr_c); + + // maybe restart iperf server + let iperf_addr = app_config.iperf_addr.clone(); + if prog_state.iperf_server_addr.as_ref() != Some(&iperf_addr) { + if let Some(handle) = prog_state.iperf_server.take() { + handle.abort(); + } + + let iperf_addr_c = iperf_addr.clone(); + let iperf_handle = tokio::spawn(async move { + if let Err(e) = start_iperf_server(iperf_addr.as_str()).await { + eprintln!("iperf server failed: {}", e); + } + }); + + prog_state.iperf_server = Some(iperf_handle); + prog_state.iperf_server_addr = Some(iperf_addr_c); + } + prog_state.healths = health_monitors; for (port, routing_table) in listeners.drain() { if let Some(x) = prog_state.tx_rt_map.get_mut(&port) { @@ -232,7 +193,7 @@ async fn load_config(path: &PathBuf, state: Arc>) -> Result< async fn run_listener( port: u16, mut rx_rt: mpsc::UnboundedReceiver, - mut current_table: RoutingTable + mut current_table: RoutingTable, ) { let addr = format!("0.0.0.0:{}", port); println!("Starting tcp listener on {}", addr);