12 Commits

Author SHA1 Message Date
nnhphong
6d2b8115f3 add some more tests for lb 2025-12-09 22:01:16 -05:00
nnhphong
08cb522f93 the algorithm is working, but will need more test 2025-12-07 23:04:29 -05:00
nnhphong
742827b16f prune some comment 2025-12-07 21:59:43 -05:00
nnhphong
e19efee895 part of the algorithm, waiting for paul s and jeremy to complete refactoring 2025-12-07 21:56:27 -05:00
nnhphong
393c35bdf8 code for docker infra image 2025-12-07 14:09:38 -05:00
Ning Qi (Paul) Sun
cd23bfdf5a Merge pull request #1 from psun256/merge
Merge & Refactor
2025-12-06 16:09:51 -05:00
4cdf2db0c9 feat: improved logging 2025-12-06 02:16:40 -05:00
606880f928 feat: merged repos 2025-12-06 01:31:33 -05:00
19cd5b7f2a feat: modularized proxy 2025-12-06 00:21:53 -05:00
Ning Qi (Paul) Sun
25c3eb9511 gh action
2025-12-03 22:07:40 -05:00
psun256
e27bd2aaf0 layer 4 load balancing (round robin, hardcoded backends) 2025-11-29 21:46:26 -05:00
Phong Nguyen
1235d3611d Update README with load balancing details
Added a note about load balancing algorithms from a referenced paper.
2025-12-03 12:47:46 -05:00
29 changed files with 371 additions and 2265 deletions

View File

Binary file not shown.

940
Cargo.lock generated
View File

File diff suppressed because it is too large.

View File

@@ -6,12 +6,4 @@ edition = "2024"
 [dependencies]
 anywho = "0.1.2"
 tokio = { version = "1.48.0", features = ["full"] }
-rand = "0.10.0-rc.5"
-serde = { version = "1.0.228", features = ["derive"] }
-serde_json = "1.0.145"
-rperf3-rs = "0.3.9"
-cidr = "0.3.1"
-serde-saphyr = "0.0.10"
-arc-swap = "1.7.1"
-clap = { version = "4.5.53", features = ["derive"] }
-notify = "8.2.0"
+rand = { version = "0.8", features = ["small_rng"] }

View File

@@ -33,7 +33,7 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
 # change to scratch and get comment the apk command for prod, i guess
 FROM alpine:latest AS runtime
-RUN apk add --no-cache ca-certificates curl netcat-openbsd bind-tools strace
+# RUN apk add --no-cache ca-certificates curl netcat-openbsd bind-tools strace
 WORKDIR /enginewhy
 COPY --from=builder /enginewhy/target/x86_64-unknown-linux-musl/release/l4lb /usr/bin/l4lb
 ENTRYPOINT ["l4lb"]

115
README.md Normal file
View File

@@ -0,0 +1,115 @@
# nginy
Production't graden't load balancer.
## Quick links
## Todo
- [ ] architecture astronauting
- [ ] stream / session handling (i think wrapper around tokio TcpStream)
- [ ] basic backend pooling
- [ ] layer 4 load balancing
- [ ] load balancing algorithm from the paper (https://www.wcse.org/WCSE_2018/W110.pdf)
## notes
TCP, for nginx (and haproxy, which is similar):
```c
// nginx
struct ngx_connection_s {
void *data;
ngx_event_t *read;
ngx_event_t *write;
ngx_socket_t fd;
ngx_recv_pt recv; // fn pointer to whatever recv fn is used (different for different platforms / protocols)
ngx_send_pt send; // ditto
ngx_recv_chain_pt recv_chain;
ngx_send_chain_pt send_chain;
ngx_listening_t *listening;
off_t sent;
ngx_log_t *log;
ngx_pool_t *pool;
int type;
struct sockaddr *sockaddr;
socklen_t socklen;
ngx_str_t addr_text;
ngx_proxy_protocol_t *proxy_protocol;
#if (NGX_QUIC || NGX_COMPAT)
ngx_quic_stream_t *quic;
#endif
#if (NGX_SSL || NGX_COMPAT)
ngx_ssl_connection_t *ssl;
#endif
ngx_udp_connection_t *udp; // additional stuff for UDP (which is technically connectionless, but they use timeouts and an rbtree to store "sessions")
struct sockaddr *local_sockaddr;
socklen_t local_socklen;
ngx_buf_t *buffer;
ngx_queue_t queue;
ngx_atomic_uint_t number;
ngx_msec_t start_time;
ngx_uint_t requests;
unsigned buffered:8;
unsigned log_error:3; /* ngx_connection_log_error_e */
unsigned timedout:1;
unsigned error:1;
unsigned destroyed:1;
unsigned pipeline:1;
unsigned idle:1;
unsigned reusable:1;
unsigned close:1;
unsigned shared:1;
unsigned sendfile:1;
unsigned sndlowat:1;
unsigned tcp_nodelay:2; /* ngx_connection_tcp_nodelay_e */
unsigned tcp_nopush:2; /* ngx_connection_tcp_nopush_e */
unsigned need_last_buf:1;
unsigned need_flush_buf:1;
#if (NGX_HAVE_SENDFILE_NODISKIO || NGX_COMPAT)
unsigned busy_count:2;
#endif
#if (NGX_THREADS || NGX_COMPAT)
ngx_thread_task_t *sendfile_task;
#endif
};
```
process to load balance (see the sketch after this list):
- accept incoming connection
- create some kind of stream / session object
  - nginx uses this to abstract over the tcp and udp layers
  - for us it probably doesn't need to be as detailed as theirs, since we have tokio::net, so it'll be a wrapper around TcpStream
- ask the load balancing algorithm which server in the pool to route to
- connect to the server
- proxy the data (copy_bidirectional? maybe we want some metrics or logging, so we might do it manually)
- clean up when someone leaves or something goes wrong (with TCP, the OS / tokio will tell us; with UDP it's probably timeout based, plus a periodic sweep of all sessions)
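
A minimal sketch of that flow, assuming tokio and a hardcoded round-robin stand-in for the real balancing algorithm (only `copy_bidirectional`, `TcpListener`, and `TcpStream` are real tokio APIs; everything else here is illustrative):
```rust
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpStream};

async fn serve(listener: TcpListener, backends: Vec<String>) -> std::io::Result<()> {
    let mut next = 0usize; // trivial round-robin stand-in for the real algorithm
    loop {
        // accept incoming connection
        let (mut client, addr) = listener.accept().await?;
        // ask the load balancing algorithm which server in the pool to route to
        let backend_addr = backends[next % backends.len()].clone();
        next += 1;
        tokio::spawn(async move {
            // connect to the server
            match TcpStream::connect(&backend_addr).await {
                Ok(mut backend) => {
                    // proxy the data both ways; returns (client->backend, backend->client) byte counts
                    if let Err(e) = copy_bidirectional(&mut client, &mut backend).await {
                        eprintln!("proxy error for {}: {}", addr, e);
                    }
                }
                Err(e) => eprintln!("connect to {} failed: {}", backend_addr, e),
            }
            // with TCP, the OS / tokio surfaces disconnects here, so cleanup is implicit
        });
    }
}
```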
## Load balancer algorithm
- Choose fixed weight coefficients for the resource parameters
- Spawn a thread on the load balancer to host the iperf server, used by newly onboarding servers that connect to the load balancer to measure their maximum bandwidth
- Spawn another thread that listens for resource updates from connected servers
- Update the comprehensive load sum from eq (1), and update the formulas in eqs (2) to (5)
- Choose alpha for eq (8), and run the algorithm to choose which server to route to (see the sketch after this list)
- Extract the server from the server id using ```get_backend()```
- Use ```tunnel()``` to proxy the packets
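
A rough, non-authoritative sketch of the weighting step above, assuming the cpu/mem/net/io percentages and the `[f64; 4]` coefficient layout used in this repo's config; it shows eq (1)'s fixed-coefficient composite load and the alpha threshold test, not the full weight-adjustment algorithm (`Metrics` and `pick` are illustrative names, not the repo's real types):
```rust
// percents, 0..100
struct Metrics { cpu: f64, mem: f64, net: f64, io: f64 }

// eq (1): fixed-coefficient linear blend of the four resource metrics
fn composite_load(m: &Metrics, k: &[f64; 4]) -> f64 {
    k[0] * m.cpu + k[1] * m.mem + k[2] * m.net + k[3] * m.io
}

// accept the first server whose load/weight ratio clears the alpha threshold
fn pick(servers: &[(Metrics, f64)], k: &[f64; 4], alpha: f64) -> Option<usize> {
    let r_sum: f64 = servers.iter().map(|(m, _)| composite_load(m, k)).sum();
    let w_sum: f64 = servers.iter().map(|(_, w)| w).sum::<f64>().max(1e-12);
    let threshold = alpha * (r_sum / w_sum);
    servers
        .iter()
        .position(|(m, w)| composite_load(m, k) / w.max(1e-12) <= threshold)
}
```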

BIN
W110.pdf Normal file
View File

Binary file not shown.

View File

@@ -1,27 +0,0 @@
healthcheck_addr: "0.0.0.0:8080"
iperf_addr: "0.0.0.0:5001"
backends:
- id: "srv-1"
ip: "192.67.67.2:8080"
- id: "srv-2"
ip: "192.67.67.3:8080"
clusters:
main-api:
- "srv-1"
- "srv-2"
priority-api:
- "srv-1"
rules:
- clients:
- "172.67.67.2/24:80"
targets:
- "main-api"
- "priority-api"
strategy:
type: "Adaptive"
coefficients: [ 1.5, 1.0, 0.5, 0.1 ]
alpha: 0.75

View File

@@ -1,88 +0,0 @@
services:
load-balancer:
image: neoslhp/enginewhy-lb
container_name: load-balancer
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 8G
cap_add:
- NET_ADMIN
- SYS_ADMIN
networks:
internal:
ipv4_address: 172.67.67.67
external:
ipv4_address: 192.67.67.67
server1-high-cpu:
image: neoslhp/enginewhy-server
container_name: server1
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 8G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
external:
ipv4_address: 192.67.67.2
server2-low-cpu:
image: neoslhp/enginewhy-server
container_name: server2
tty: true
deploy:
resources:
limits:
cpus: "2.0"
memory: 4G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
external:
ipv4_address: 192.67.67.3
client:
image: neoslhp/enginewhy-ubuntu22.04
container_name: client
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 4G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
internal:
ipv4_address: 172.67.67.2
networks:
internal:
driver: bridge
ipam:
config:
- subnet: 172.67.67.0/24
external:
driver: bridge
ipam:
config:
- subnet: 192.67.67.0/24
# Resources:
# https://networkgeekstuff.com/networking/basic-load-balancer-scenarios-explained/
# https://hub.docker.com/r/linuxserver/wireshark
# https://www.wcse.org/WCSE_2018/W110.pdf
# Deepseek

View File

@@ -1,38 +0,0 @@
healthcheck_addr: "10.0.1.10:9000"
iperf_addr: "10.0.1.10:5201"
backends:
- id: "srv-1"
ip: "10.0.1.11:8081"
- id: "srv-2"
ip: "10.0.1.12:8082"
- id: "srv-3"
ip: "10.0.1.13:8083"
- id: "srv-4"
ip: "10.0.1.14:8084"
clusters:
main-api:
- "srv-1"
- "srv-2"
priority-api:
- "srv-3"
- "srv-4"
rules:
- clients:
- "0.0.0.0/0:8080"
targets:
- "main-api"
strategy:
type: "RoundRobin"
- clients:
- "10.0.0.0/24:8080"
- "10.0.0.0/24:25565"
targets:
- "main-api"
- "priority-api"
strategy:
type: "RoundRobin"

View File

@@ -1,110 +0,0 @@
services:
load-balancer:
image: enginewhy
container_name: load-balancer
tty: true
cap_add:
- NET_ADMIN
- SYS_ADMIN
volumes:
- ./config.yaml:/enginewhy/config.yaml
networks:
net_1:
ipv4_address: 10.0.1.10
net_2:
ipv4_address: 10.0.0.10
net_3:
ipv4_address: 10.0.2.10
srv-1:
image: nicolaka/netshoot
container_name: srv-1
tty: true
command: ["python3", "-m", "http.server", "8081", "--directory", "/root/www"]
networks:
net_1:
ipv4_address: 10.0.1.11
ports:
- "8081:8081"
volumes:
- ./srv1:/root/www
cap_add: [ NET_ADMIN ]
srv-2:
image: nicolaka/netshoot
container_name: srv-2
tty: true
command: ["python3", "-m", "http.server", "8082", "--directory", "/root/www"]
networks:
net_1:
ipv4_address: 10.0.1.12
ports:
- "8082:8082"
volumes:
- ./srv2:/root/www
cap_add: [ NET_ADMIN ]
srv-3:
image: nicolaka/netshoot
container_name: srv-3
tty: true
command: ["python3", "-m", "http.server", "8083", "--directory", "/root/www"]
networks:
net_1:
ipv4_address: 10.0.1.13
ports:
- "8083:8083"
volumes:
- ./srv3:/root/www
cap_add: [ NET_ADMIN ]
srv-4:
image: nicolaka/netshoot
container_name: srv-4
tty: true
command: ["python3", "-m", "http.server", "8084", "--directory", "/root/www"]
networks:
net_1:
ipv4_address: 10.0.1.14
ports:
- "8084:8084"
volumes:
- ./srv4:/root/www
cap_add: [ NET_ADMIN ]
client-net2:
image: nicolaka/netshoot
container_name: client-net2
tty: true
networks:
net_2:
ipv4_address: 10.0.0.11
cap_add: [ NET_ADMIN ]
client-net3:
image: nicolaka/netshoot
container_name: client-net3
tty: true
networks:
net_3:
ipv4_address: 10.0.2.11
cap_add: [ NET_ADMIN ]
networks:
net_1:
driver: bridge
ipam:
config:
- subnet: 10.0.1.0/24
net_2:
driver: bridge
ipam:
config:
- subnet: 10.0.0.0/24
net_3:
driver: bridge
ipam:
config:
- subnet: 10.0.2.0/24

View File

@@ -1 +0,0 @@
Hello from server 1!

View File

@@ -1 +0,0 @@
Hello from server 2!

View File

@@ -1 +0,0 @@
Hello from server 3!

View File

@@ -1 +0,0 @@
Hello from server 4!

View File

@@ -146,14 +146,20 @@ async fn main() -> std::io::Result<()> {
         }
         println!();
+        // Identify this process (client) by the local socket address used to connect
+        let server_identifier = match stream.local_addr() {
+            Ok(addr) => addr.to_string(),
+            Err(_) => format!("localhost:{}", PORT),
+        };
         let mut packet: HashMap<String, Value> = HashMap::new();
+        packet.insert("server_ip".to_string(), Value::String(server_identifier));
         packet.insert("cpu".to_string(), Value::from(cpu_usage)); // %
         packet.insert("mem".to_string(), Value::from(mem_usage)); // %
         packet.insert("net".to_string(), Value::from(net_usage_pct));
         packet.insert("io".to_string(), Value::from(io_usage));
-        let serialized_packet = serde_json::to_string(&packet)?;
+        let mut serialized_packet = serde_json::to_string(&packet)?;
+        serialized_packet.push('\n');
         let _ = stream.write(serialized_packet.as_bytes());
         thread::sleep(Duration::from_secs(10));

View File

@@ -1,134 +0,0 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use serde_json::Value;
use rperf3::{Config, Server};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::io::AsyncBufReadExt;
// Physical server health statistics, used for certain load balancing algorithms
#[derive(Debug, Default)]
pub struct ServerMetrics {
pub cpu: f64,
pub mem: f64,
pub net: f64,
pub io: f64,
}
impl ServerMetrics {
pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
self.cpu = cpu;
self.mem = mem;
self.net = net;
self.io = io;
}
}
pub async fn start_healthcheck_listener(
addr: &str,
healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> std::io::Result<()> {
let addrs = tokio::net::lookup_host(addr).await?;
let mut listener = None;
for a in addrs {
let socket = match a {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
socket.set_reuseaddr(true)?;
if socket.bind(a).is_ok() {
listener = Some(socket.listen(1024)?);
break;
}
}
let listener = listener.ok_or_else(|| {
eprintln!("health listener could not bind to port");
std::io::Error::new(std::io::ErrorKind::Other, "health listener failed")
})?;
println!("healthcheck server listening on {}", addr);
loop {
let (stream, remote_addr) = match listener.accept().await {
Ok(v) => v,
Err(e) => {
continue;
}
};
if let Err(e) = handle_metrics_stream(stream, &healths).await {
eprintln!("connection handler error: {}", e);
}
}
}
pub async fn start_iperf_server(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
let sock = addr.parse::<SocketAddr>()?;
let mut config = Config::server(sock.port());
config.bind_addr = Some(sock.ip());
let server = Server::new(config);
println!("iperf server listening on {}", addr);
server.run().await?;
Ok(())
}
async fn handle_metrics_stream(
stream: TcpStream,
healths: &HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> std::io::Result<()> {
let server_ip = stream.peer_addr()?.ip();
let mut reader = tokio::io::BufReader::new(stream);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => break,
Ok(_) => {
if let Err(e) = process_metrics(server_ip, &line, healths) {
eprintln!("skipping invalid packet: {}", e);
}
}
Err(e) => {
eprintln!("connection error: {}", e);
break;
}
}
}
Ok(())
}
fn process_metrics(
server_ip: IpAddr,
json_str: &str,
healths: &HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> Result<(), String> {
let parsed: Value =
serde_json::from_str(json_str).map_err(|e| format!("parse error: {}", e))?;
let metrics_lock = healths
.get(&server_ip)
.ok_or_else(|| format!("unknown server: {}", server_ip))?;
let get_f64 = |key: &str| -> Result<f64, String> {
parsed
.get(key)
.and_then(|v| v.as_f64())
.ok_or_else(|| format!("invalid '{}'", key))
};
if let Ok(mut guard) = metrics_lock.write() {
guard.update(
get_f64("cpu")?,
get_f64("mem")?,
get_f64("net")?,
get_f64("io")?,
);
}
Ok(())
}

View File

@@ -1,76 +0,0 @@
pub mod health;
use crate::backend::health::ServerMetrics;
use core::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering};
// A possible endpoint for a proxied connection.
// Note that multiple may live on the same server, hence the Arc<RwLock<ServerMetric>>
#[derive(Debug)]
pub struct Backend {
pub id: String,
pub address: SocketAddr,
pub active_connections: AtomicUsize,
pub metrics: Arc<RwLock<ServerMetrics>>,
}
impl Backend {
pub fn new(
id: String,
address: SocketAddr,
server_metrics: Arc<RwLock<ServerMetrics>>,
) -> Self {
Self {
id: id.to_string(),
address,
active_connections: AtomicUsize::new(0),
metrics: server_metrics,
}
}
// Ordering::Relaxed means the ops could be in any order, but since this
// is just a metric, and we assume the underlying system is sane
// enough not to behave poorly, so SeqCst is probably overkill.
pub fn inc_connections(&self) {
self.active_connections.fetch_add(1, Ordering::Relaxed);
println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
}
pub fn dec_connections(&self) {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
}
}
impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} ({})", self.address, self.id)
}
}
// A set of endpoints that can be load balanced around.
// Each Balancer owns one of these. Backend instances may be shared
// with other Balancer instances, hence Arc<Backend>.
#[derive(Clone, Debug)]
pub struct BackendPool {
pub backends: Arc<Vec<Arc<Backend>>>,
}
impl BackendPool {
pub fn new(backends: Vec<Arc<Backend>>) -> Self {
BackendPool {
backends: Arc::new(backends),
}
}
}

View File

@@ -1,267 +1,211 @@
-use crate::backend::{Backend, BackendPool};
-use crate::backend::health::ServerMetrics;
-use crate::balancer::{Balancer, ConnectionInfo};
-use rand::prelude::*;
-use rand::rngs::SmallRng;
-use std::fmt::Debug;
-use std::fs::Metadata;
-use std::sync::{Arc, RwLock};
-#[derive(Debug)]
-struct AdaptiveNode {
-    backend: Arc<Backend>,
-    weight: f64,
-}
-#[derive(Debug)]
-pub struct AdaptiveWeightBalancer {
-    pool: Vec<AdaptiveNode>,
-    coefficients: [f64; 4],
-    alpha: f64,
-    rng: SmallRng,
-}
-impl AdaptiveWeightBalancer {
-    pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self {
-        let nodes = pool
-            .backends
-            .iter()
-            .map(|b| AdaptiveNode {
-                backend: b.clone(),
-                weight: 1f64,
-            })
-            .collect();
-        AdaptiveWeightBalancer {
-            pool: nodes,
-            coefficients,
-            alpha,
-            rng: SmallRng::from_rng(&mut rand::rng()),
-        }
-    }
-    pub fn metrics_to_weight(&self, metrics: &ServerMetrics) -> f64 {
-        self.coefficients[0] * metrics.cpu
-            + self.coefficients[1] * metrics.mem
-            + self.coefficients[2] * metrics.net
-            + self.coefficients[3] * metrics.io
-    }
-}
-impl Balancer for AdaptiveWeightBalancer {
-    fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
-        if self.pool.is_empty() {
-            return None;
-        }
-        // Compute remaining capacity R_i = 100 - composite_load
-        let mut r_sum = 0.0;
-        let mut w_sum = 0.0;
-        let mut l_sum = 0;
-        for node in &self.pool {
-            if let Ok(health) = node.backend.metrics.read() {
-                r_sum += self.metrics_to_weight(&health);
-            }
-            w_sum += node.weight;
-            l_sum += node
-                .backend
-                .active_connections
-                .load(std::sync::atomic::Ordering::Relaxed);
-        }
-        let safe_w_sum = w_sum.max(1e-12);
-        let threshold = self.alpha * (r_sum / safe_w_sum);
-        for idx in 0..self.pool.len() {
-            let node = &self.pool[idx];
-            if node.weight <= 0.001 {
-                continue;
-            }
-            let risk = match node.backend.metrics.read() {
-                Ok(h) => self.metrics_to_weight(&h),
-                Err(_) => f64::MAX,
-            };
-            let ratio = risk / node.weight;
-            if ratio <= threshold {
-                return Some(node.backend.clone());
-            }
-        }
-        // If any server satisfies Ri/Wi <= threshold, it means the server
-        // is relatively overloaded, and we must adjust its weight using
-        // formula (6).
-        let mut total_lwi = 0.0;
-        let l_sum_f64 = l_sum as f64;
-        for node in &self.pool {
-            let load = node
-                .backend
-                .active_connections
-                .load(std::sync::atomic::Ordering::Relaxed) as f64;
-            let weight = node.weight.max(1e-12);
-            let lwi = load * (safe_w_sum / weight) * l_sum_f64;
-            total_lwi += lwi;
-        }
-        let avg_lwi = (total_lwi / self.pool.len() as f64).max(1e-12);
-        // Compute Li = Wi / Ri and choose server minimizing Li.
-        let mut best_backend: Option<Arc<Backend>> = None;
-        let mut min_load = usize::MAX;
-        for node in &mut self.pool {
-            let load = node
-                .backend
-                .active_connections
-                .load(std::sync::atomic::Ordering::Relaxed);
-            let load_f64 = load as f64;
-            let weight = node.weight.max(1e-12);
-            let lwi = load_f64 * (safe_w_sum / weight) * l_sum_f64;
-            let adj = 1.0 - (lwi / avg_lwi);
-            node.weight += adj;
-            node.weight = node.weight.clamp(0.1, 100.0);
-            if load < min_load {
-                min_load = load;
-                best_backend = Some(node.backend.clone());
-            }
-        }
-        match best_backend {
-            Some(backend) => Some(backend),
-            None => {
-                let i = (self.rng.next_u32() as usize) % self.pool.len();
-                Some(self.pool[i].backend.clone())
-            }
-        }
-    }
-}
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::backend::Backend;
-    use std::net::SocketAddr;
-    fn backend_factory(id: &str, ip: &str, port: u16) -> Arc<Backend> {
-        Arc::new(Backend::new(
-            id.to_string(),
-            SocketAddr::new(ip.parse().unwrap(), port),
-            Arc::new(RwLock::new(ServerMetrics::default())),
-        ))
-    }
-    fn unused_ctx() -> ConnectionInfo {
-        ConnectionInfo {
-            client_ip: ("0.0.0.0".parse().unwrap()),
-        }
-    }
-    #[test]
-    fn basic_weight_update_and_choose() {
-        let backends = BackendPool::new(vec![
-            backend_factory("server-0", "127.0.0.1", 3000),
-            backend_factory("server-1", "127.0.0.1", 3001),
-        ]);
-        let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.5, 0.2, 0.2, 0.1], 0.5);
-        // initially equal weights
-        // update one backend to be heavily loaded
-        {
-            let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
-            sm0_guard.update(90.0, 80.0, 10.0, 5.0);
-        }
-        {
-            let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
-            sm1_guard.update(10.0, 5.0, 1.0, 1.0);
-        }
-        // Choose backend: should pick the less loaded host server1
-        let chosen = b
-            .choose_backend(unused_ctx())
-            .expect("should choose a backend");
-        let sm0: &ServerMetrics = &backends.backends.get(0).unwrap().metrics.read().unwrap();
-        let sm1: &ServerMetrics = &backends.backends.get(1).unwrap().metrics.read().unwrap();
-        println!("{:?}, {:?}", sm0, sm1);
-        assert_eq!(chosen.id, "server-1");
-    }
-    #[test]
-    fn choose_none_when_empty() {
-        let mut b =
-            AdaptiveWeightBalancer::new(BackendPool::new(vec![]), [0.5, 0.2, 0.2, 0.1], 0.5);
-        assert!(b.choose_backend(unused_ctx()).is_none());
-    }
-    #[test]
-    fn ratio_triggers_immediate_selection() {
-        // Arrange two servers where server 1 has composite load 0 and server 2 has composite load 100.
-        // With alpha = 1.0 and two servers, threshold = 1.0 * (r_sum / w_sum) = 1.0 * (100 / 2) = 50.
-        // Server 0 ratio = 0 / 1 = 0 <= 50 so it should be chosen immediately.
-        let backends = BackendPool::new(vec![
-            backend_factory("server-0", "127.0.0.1", 3000),
-            backend_factory("server-1", "127.0.0.1", 3001),
-        ]);
-        let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.25, 0.25, 0.25, 0.25], 1.0);
-        {
-            let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
-            sm0_guard.update(0.0, 0.0, 0.0, 0.0);
-        }
-        {
-            let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
-            sm1_guard.update(100.0, 100.0, 100.0, 100.0);
-        }
-        let chosen = b
-            .choose_backend(unused_ctx())
-            .expect("should choose a backend");
-        assert_eq!(chosen.id, "server-0");
-    }
-    #[test]
-    fn choose_min_current_load_when_no_ratio() {
-        // Arrange three servers with identical composite loads so no server satisfies Ri/Wi <= threshold
-        // (set alpha < 1 so threshold < ratio). The implementation then falls back to picking the
-        // server with minimum current_load
-        let backends = BackendPool::new(vec![
-            backend_factory("server-0", "127.0.0.1", 3000),
-            backend_factory("server-1", "127.0.0.1", 3001),
-            backend_factory("server-2", "127.0.0.1", 3002),
-        ]);
-        {
-            let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
-            sm0_guard.update(10.0, 10.0, 10.0, 10.0);
-        }
-        {
-            let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
-            sm1_guard.update(5.0, 5.0, 5.0, 5.0);
-        }
-        {
-            let mut sm2_guard = backends.backends.get(2).unwrap().metrics.write().unwrap();
-            sm2_guard.update(20.0, 20.0, 20.0, 20.0);
-        }
-        // Use coeffs that only consider CPU so composite load is easy to reason about.
-        let mut bal = AdaptiveWeightBalancer::new(backends.clone(), [1.0, 0.0, 0.0, 0.0], 0.5);
-        // set identical composite loads > 0 for all so ratio = x and threshold = alpha * x < x
-        // you will have threshold = 25 for all 3 backend servers and ratio = 50
-        // so that forces to choose the smallest current load backend
-        let chosen = bal
-            .choose_backend(unused_ctx())
-            .expect("should choose a backend");
-        // expect server with smallest current_load server-1
-        assert_eq!(chosen.id, "server-1");
-    }
-}
+use crate::netutils::Backend;
+use rand::prelude::*;
+use rand::rngs::SmallRng;
+use std::sync::Arc;
+#[derive(Debug, Clone)]
+pub struct ServerMetrics {
+    // metrics are percents (0..100)
+    pub cpu: f64,
+    pub mem: f64,
+    pub net: f64,
+    pub io: f64,
+}
+impl ServerMetrics {
+    pub fn new() -> Self {
+        ServerMetrics { cpu: 0.0, mem: 0.0, net: 0.0, io: 0.0 }
+    }
+    pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
+        self.cpu = cpu;
+        self.mem = mem;
+        self.net = net;
+        self.io = io;
+    }
+}
+#[derive(Debug, Clone)]
+pub struct ServerState {
+    pub backend: Backend,
+    pub metrics: ServerMetrics,
+    pub weight: f64,
+}
+impl ServerState {
+    pub fn new(backend: Backend) -> Self {
+        ServerState { backend, metrics: ServerMetrics::new(), weight: 1.0 }
+    }
+}
+pub struct AdaptiveBalancer {
+    servers: Vec<ServerState>,
+    // resource coefficients (cpu, mem, net, io) - sum to 1.0
+    coeffs: [f64; 4],
+    alpha: f64,
+    rng: SmallRng,
+}
+impl AdaptiveBalancer {
+    pub fn new(backends: Vec<Backend>, coeffs: [f64; 4], alpha: f64) -> Self {
+        let servers = backends.into_iter().map(ServerState::new).collect();
+        let rng = SmallRng::from_entropy();
+        AdaptiveBalancer { servers, coeffs, alpha, rng }
+    }
+    pub fn add_backend(&mut self, backend: Backend) {
+        self.servers.push(ServerState::new(backend));
+    }
+    /// Update metrics reported by a backend identified by its display/address.
+    /// If the backend isn't found this is a no-op.
+    pub fn update_metrics(&mut self, backend_addr: &str, cpu: f64, mem: f64, net: f64, io: f64) {
+        for s in &mut self.servers {
+            if s.backend.to_string() == backend_addr {
+                s.metrics.update(cpu, mem, net, io);
+                return;
+            }
+        }
+    }
+    fn metrics_to_weight(metrics: &ServerMetrics, coeffs: &[f64; 4]) -> f64 {
+        coeffs[0] * metrics.cpu + coeffs[1] * metrics.mem + coeffs[2] * metrics.net + coeffs[3] * metrics.io
+    }
+    /// Choose a backend using weighted random selection based on current weights.
+    /// Returns an Arc-wrapped Backend clone so callers can cheaply clone it.
+    pub fn choose_backend(&mut self) -> Option<Arc<Backend>> {
+        if self.servers.is_empty() {
+            return None;
+        }
+        // Compute remaining capacity R_i = 100 - composite_load
+        let rs: Vec<f64> = self.servers.iter().map(|s| {
+            Self::metrics_to_weight(&s.metrics, &self.coeffs)
+        }).collect();
+        let ws: Vec<f64> = self.servers.iter().map(|s| s.weight).collect();
+        let ls: Vec<u32> = self.servers.iter().map(|s| s.backend.current_load).collect();
+        let r_sum: f64 = rs.iter().copied().sum::<f64>();
+        let w_sum: f64 = ws.iter().copied().sum::<f64>().max(1e-12);
+        let l_sum: u32 = ls.iter().copied().sum::<u32>();
+        let threshold = self.alpha * (r_sum / w_sum);
+        for (i, s) in self.servers.iter_mut().enumerate() {
+            let ratio = if s.weight <= 0.0 { f64::INFINITY } else { rs[i] / s.weight };
+            if ratio <= threshold {
+                return Some(Arc::new(s.backend.clone()));
+            }
+        }
+        // If any server satisfies Ri/Wi <= threshold, it means the server
+        // is relatively overloaded and we must adjust its weight using
+        // formula (6).
+        let lwi: Vec<f64> = self.servers.iter().enumerate().map(|(i, s)| {
+            s.backend.current_load as f64 * w_sum / ws[i] * l_sum as f64
+        }).collect();
+        let a_lwi: f64 = lwi.iter().copied().sum::<f64>() / lwi.len() as f64;
+        for (i, s) in self.servers.iter_mut().enumerate() {
+            s.weight += 1 as f64 - lwi[i] / a_lwi;
+        }
+        // Compute Li = Wi / Ri and choose server minimizing Li.
+        let mut best_idx: Option<usize> = None;
+        let mut best_li = u32::MAX;
+        for (i, s) in self.servers.iter().enumerate() {
+            let li = s.backend.current_load;
+            if li < best_li {
+                best_li = li;
+                best_idx = Some(i);
+            }
+        }
+        // If nothing chosen, fall back to random selection
+        if best_idx.is_none() {
+            let i = (self.rng.next_u32() as usize) % self.servers.len();
+            return Some(Arc::new(self.servers[i].backend.clone()));
+        }
+        Some(Arc::new(self.servers[best_idx.unwrap()].backend.clone()))
+    }
+    // Expose a snapshot of server weights (for monitoring/testing)
+    pub fn snapshot_weights(&self) -> Vec<(String, f64)> {
+        self.servers.iter().map(|s| (s.backend.to_string(), s.weight)).collect()
+    }
+}
+#[cfg(test)]
+mod tests {
+    use super::*;
+    #[test]
+    fn basic_weight_update_and_choose() {
+        let backends = vec![Backend::new("127.0.0.1:1".to_string()), Backend::new("127.0.0.1:2".to_string())];
+        let mut b = AdaptiveBalancer::new(backends, [0.5, 0.2, 0.2, 0.1], 0.5);
+        // initially equal weights
+        let snaps = b.snapshot_weights();
+        assert_eq!(snaps.len(), 2);
+        // update one backend to be heavily loaded
+        b.update_metrics("127.0.0.1:1", 90.0, 80.0, 10.0, 5.0);
+        b.update_metrics("127.0.0.1:2", 10.0, 5.0, 1.0, 1.0);
+        // Choose backend: should pick the less loaded host (127.0.0.1:2)
+        let chosen = b.choose_backend().expect("should choose a backend");
+        let snaps2 = b.snapshot_weights();
+        println!("{:?}, {:?}", snaps, snaps2);
+        assert_eq!(chosen.to_string(), "127.0.0.1:2");
+    }
+    #[test]
+    fn choose_none_when_empty() {
+        let mut b = AdaptiveBalancer::new(vec![], [0.5, 0.2, 0.2, 0.1], 0.5);
+        assert!(b.choose_backend().is_none());
+    }
+    #[test]
+    fn ratio_triggers_immediate_selection() {
+        // Arrange two servers where server 1 has composite load 0 and server 2 has composite load 100.
+        // With alpha = 1.0 and two servers, threshold = 1.0 * (r_sum / w_sum) = 1.0 * (100 / 2) = 50.
+        // Server 1 ratio = 0 / 1 = 0 <= 50 so it should be chosen immediately.
+        let backends = vec![Backend::new("127.0.0.1:1".to_string()), Backend::new("127.0.0.1:2".to_string())];
+        let mut b = AdaptiveBalancer::new(backends, [0.25, 0.25, 0.25, 0.25], 1.0);
+        b.update_metrics("127.0.0.1:1", 0.0, 0.0, 0.0, 0.0);
+        b.update_metrics("127.0.0.1:2", 100.0, 100.0, 100.0, 100.0);
+        let chosen = b.choose_backend().expect("should choose a backend");
+        assert_eq!(chosen.to_string(), "127.0.0.1:1");
+    }
+    #[test]
+    fn choose_min_current_load_when_no_ratio() {
+        // Arrange three servers with identical composite loads so no server satisfies Ri/Wi <= threshold
+        // (set alpha < 1 so threshold < ratio). The implementation then falls back to picking the
+        // server with minimum current_load
+        let mut s1 = Backend::new("127.0.0.1:1".to_string());
+        let mut s2 = Backend::new("127.0.0.1:2".to_string());
+        let mut s3 = Backend::new("127.0.0.1:3".to_string());
+        // set current_loads (field expected to be public)
+        s1.current_load = 10;
+        s2.current_load = 5;
+        s3.current_load = 20;
+        // Use coeffs that only consider CPU so composite load is easy to reason about.
+        let mut bal = AdaptiveBalancer::new(vec![s1, s2, s3], [1.0, 0.0, 0.0, 0.0], 0.5);
+        // set identical composite loads > 0 for all so ratio = x and threshold = alpha * x < x
+        // you will have threshold = 25 for all 3 backend servers and ratio = 50
+        // so that forces to choose the smallest current load backend
+        bal.update_metrics("127.0.0.1:1", 50.0, 0.0, 0.0, 0.0);
+        bal.update_metrics("127.0.0.1:2", 50.0, 0.0, 0.0, 0.0);
+        bal.update_metrics("127.0.0.1:3", 50.0, 0.0, 0.0, 0.0);
+        let chosen = bal.choose_backend().expect("should choose a backend");
+        // expect server with smallest current_load (127.0.0.1:2)
+        assert_eq!(chosen.to_string(), "127.0.0.1:2");
+    }
+}

View File

@@ -1,109 +0,0 @@
use crate::backend::{Backend, BackendPool};
use crate::balancer::{Balancer, ConnectionInfo};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
#[derive(Debug)]
pub struct SourceIPHash {
pool: BackendPool,
}
impl SourceIPHash {
pub fn new(pool: BackendPool) -> SourceIPHash {
Self { pool }
}
}
impl Balancer for SourceIPHash {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let client_ip = ctx.client_ip;
let mut hasher = DefaultHasher::new();
client_ip.hash(&mut hasher);
let hash = hasher.finish();
let idx = (hash as usize) % self.pool.backends.len();
Some(self.pool.backends[idx].clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::health::ServerMetrics;
use std::net::IpAddr;
use std::sync::RwLock;
fn create_dummy_backends(count: usize) -> BackendPool {
let mut backends = Vec::new();
for i in 1..=count {
backends.push(Arc::new(Backend::new(
format!("backend {}", i),
format!("127.0.0.1:808{}", i).parse().unwrap(),
Arc::new(RwLock::new(ServerMetrics::default())),
)));
}
BackendPool::new(backends)
}
#[test]
fn test_same_ip_always_selects_same_backend() {
let backends = create_dummy_backends(3);
let mut balancer = SourceIPHash::new(backends);
let client_ip: IpAddr = "192.168.1.100".parse().unwrap();
let first_choice = balancer.choose_backend(ConnectionInfo { client_ip });
let second_choice = balancer.choose_backend(ConnectionInfo { client_ip });
assert!(first_choice.is_some());
assert!(second_choice.is_some());
let first = first_choice.unwrap();
let second = second_choice.unwrap();
assert_eq!(first.id, second.id);
}
#[test]
fn test_different_ips_may_select_different_backends() {
let backends = create_dummy_backends(2);
let mut balancer = SourceIPHash::new(backends);
let ip1: IpAddr = "192.168.1.100".parse().unwrap();
let choice1 = balancer.choose_backend(ConnectionInfo { client_ip: ip1 });
let ip2: IpAddr = "192.168.1.101".parse().unwrap();
let choice2 = balancer.choose_backend(ConnectionInfo { client_ip: ip2 });
assert!(choice1.is_some());
assert!(choice2.is_some());
}
#[test]
fn test_hash_distribution_across_backends() {
let pool = create_dummy_backends(3);
let backends_ref = pool.backends.clone();
let mut balancer = SourceIPHash::new(pool);
let mut distribution = [0, 0, 0];
// Test 30 different IPs
for i in 0..30 {
let client_ip: IpAddr = format!("192.168.1.{}", 100 + i).parse().unwrap();
if let Some(backend) = balancer.choose_backend(ConnectionInfo { client_ip }) {
for (idx, b) in backends_ref.iter().enumerate() {
if backend.id == b.id && backend.address == b.address {
distribution[idx] += 1;
break;
}
}
}
}
assert!(distribution[0] > 0, "Backend 0 received no traffic");
assert!(distribution[1] > 0, "Backend 1 received no traffic");
assert!(distribution[2] > 0, "Backend 2 received no traffic");
}
}

View File

@@ -1 +0,0 @@
// use super::*;

View File

@@ -1,18 +1,2 @@
 pub mod adaptive_weight;
-pub mod ip_hashing;
-pub mod least_connections;
-pub mod round_robin;
-use crate::backend::Backend;
-use std::fmt::Debug;
-use std::net::IpAddr;
-use std::sync::Arc;
-#[derive(Clone, Debug)]
-pub struct ConnectionInfo {
-    pub client_ip: IpAddr,
-}
-pub trait Balancer: Debug + Send + Sync + 'static {
-    fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>>;
-}
+pub use adaptive_weight::AdaptiveBalancer;

View File

@@ -1,32 +0,0 @@
use crate::backend::{Backend, BackendPool};
use crate::balancer::{Balancer, ConnectionInfo};
use std::fmt::Debug;
use std::sync::Arc;
// only the main thread for receiving connections should be
// doing the load balancing. alternatively, each thread
// that handles load balancing should get their own instance.
#[derive(Debug)]
pub struct RoundRobinBalancer {
pool: BackendPool,
index: usize,
}
impl RoundRobinBalancer {
pub fn new(pool: BackendPool) -> RoundRobinBalancer {
Self { pool, index: 0 }
}
}
impl Balancer for RoundRobinBalancer {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let backends = self.pool.backends.clone();
if backends.is_empty() {
return None;
}
let backend = backends[self.index % backends.len()].clone();
self.index = self.index.wrapping_add(1);
Some(backend)
}
}

View File

@@ -1,128 +0,0 @@
use cidr::IpCidr;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use crate::backend::health::*;
use crate::backend::*;
use crate::balancer::Balancer;
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::balancer::ip_hashing::SourceIPHash;
use crate::config::*;
pub struct RoutingTable {
pub balancers: Vec<Box<dyn Balancer + Send>>,
pub entries: Vec<(IpCidr, usize)>,
}
pub type PortListeners = HashMap<u16, RoutingTable>;
fn parse_client(s: &str) -> Result<(IpCidr, u16), String> {
// just splits "0.0.0.0/0:80" into ("0.0.0.0/0", 80)
let (ip_part, port_part) = s
.rsplit_once(':')
.ok_or_else(|| format!("badly formatted client: {}", s))?;
let port = port_part.parse().map_err(|_| format!("bad port: {}", s))?;
let cidr = ip_part.parse().map_err(|_| format!("bad ip/mask: {}", s))?;
Ok((cidr, port))
}
pub fn build_lb(
config: &AppConfig,
) -> Result<(PortListeners, HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>), String> {
let mut healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = HashMap::new();
let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
for backend_cfg in &config.backends {
let addr: SocketAddr = backend_cfg.ip.parse()
.map_err(|_| format!("bad ip: {}", backend_cfg.ip))?;
let ip = addr.ip();
let health = healths
.entry(ip)
.or_insert_with(|| Arc::new(RwLock::new(ServerMetrics::default())))
.clone();
let backend = Arc::new(Backend::new(backend_cfg.id.clone(), addr, health));
backends.insert(backend_cfg.id.clone(), backend);
}
let mut listeners: PortListeners = HashMap::new();
for rule in &config.rules {
let mut target_backends = Vec::new();
for target_name in &rule.targets {
if let Some(members) = config.clusters.get(target_name) {
for member_id in members {
if let Some(backend) = backends.get(member_id) {
target_backends.push(backend.clone());
}
}
} else if let Some(backend) = backends.get(target_name) {
target_backends.push(backend.clone());
} else {
eprintln!("warning: target {} not found", target_name);
}
}
// possible for multiple targets of the same rule to have common backends.
target_backends.sort_by(|a, b| a.id.cmp(&b.id));
target_backends.dedup_by(|a, b| a.id == b.id);
if target_backends.is_empty() {
eprintln!("warning: rule has no valid targets, skipping.");
continue;
}
// for each different client port on this rule, we unfortunately need to make a new
// Balancer, since Balancer is not thread safe, requires &mut self for the backend
// selection.
// a good enough compromise to make a new one for each port, avoids using Mutex, at the
// cost of minor penalty to load balancing "quality" when you have several client ports.
let mut port_groups: HashMap<u16, Vec<IpCidr>> = HashMap::new();
for client_def in &rule.clients {
let (cidr, port) = parse_client(&client_def)?;
port_groups.entry(port).or_default().push(cidr);
}
for (port, cidrs) in port_groups {
let table = listeners.entry(port).or_insert_with(|| RoutingTable {
balancers: Vec::new(),
entries: Vec::new(),
});
let pool = BackendPool::new(target_backends.clone());
let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
LoadBalancerStrategy::RoundRobin => Box::new(RoundRobinBalancer::new(pool)),
LoadBalancerStrategy::SourceIPHash => Box::new(SourceIPHash::new(pool)),
LoadBalancerStrategy::Adaptive {
coefficients,
alpha,
} => Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha)),
};
let balancer_idx = table.balancers.len();
table.balancers.push(balancer);
for cidr in cidrs {
table.entries.push((cidr, balancer_idx));
}
}
}
// sort to make most specific first, so that first match == longest prefix match
for table in listeners.values_mut() {
table
.entries
.sort_by(|(a, _), (b, _)| b.network_length().cmp(&a.network_length()));
}
Ok((listeners, healths))
}

View File

@@ -1,62 +0,0 @@
// config is written as a YAML file, the path will be passed to the program.
//
// the high level structure of the config is that we
// first define the individual backends (ip + port) we are going
// to load balance around.
//
// next we define some clusters, which are really more like a short
// alias for a group of backends.
//
// next we define the rules. these are written as a list of
// "ip/subnet:port" for the clients, and then a list of clusters
// for which backends these are balanced around. and of course
// specify which algorithm to use.
pub mod loader;
use serde::Deserialize;
use std::collections::HashMap;
fn default_healthcheck_addr() -> String {
"0.0.0.0:8080".to_string()
}
fn default_iperf_addr() -> String {
"0.0.0.0:5201".to_string()
}
#[derive(Debug, Deserialize)]
pub struct AppConfig {
#[serde(default = "default_healthcheck_addr")]
pub healthcheck_addr: String,
#[serde(default = "default_iperf_addr")]
pub iperf_addr: String,
pub backends: Vec<BackendConfig>,
#[serde(default)]
pub clusters: HashMap<String, Vec<String>>,
pub rules: Vec<RuleConfig>,
}
#[derive(Debug, Deserialize)]
pub struct BackendConfig {
pub id: String,
pub ip: String,
}
#[derive(Debug, Deserialize)]
pub struct RuleConfig {
pub clients: Vec<String>,
pub targets: Vec<String>,
pub strategy: LoadBalancerStrategy,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum LoadBalancerStrategy {
RoundRobin,
SourceIPHash,
Adaptive { coefficients: [f64; 4], alpha: f64 },
}

View File

@@ -1,252 +1,56 @@
-mod backend;
-mod balancer;
-mod config;
-mod proxy;
-use crate::backend::health::{start_healthcheck_listener, start_iperf_server, ServerMetrics};
-use crate::balancer::ConnectionInfo;
-use crate::config::loader::{build_lb, RoutingTable};
-use crate::proxy::tcp::proxy_tcp_connection;
-use anywho::Error;
-use std::collections::HashMap;
-use std::fs::File;
-use std::hash::Hash;
-use std::net::{IpAddr, SocketAddr};
-use std::path::PathBuf;
-use std::sync::atomic::{AtomicU64, Ordering};
-use std::sync::{Arc, Mutex, RwLock};
-use std::time::Duration;
-use tokio::io::AsyncBufReadExt;
-use tokio::net::TcpListener;
-use tokio::sync::mpsc;
-use clap::Parser;
-use notify::{Event, RecursiveMode, Watcher};
-use std::cmp;
-static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
-struct ProgramState {
-    tx_rt_map: HashMap<u16, mpsc::UnboundedSender<RoutingTable>>,
-    healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
-    health_listener: Option<tokio::task::JoinHandle<()>>,
-    iperf_server: Option<tokio::task::JoinHandle<()>>,
-    health_listener_addr: Option<String>,
-    iperf_server_addr: Option<String>,
-}
-#[derive(Parser, Debug)]
-#[command(author, version, about, long_about = None)]
-struct Args {
-    #[arg(short, long, default_value = "config.yaml")]
-    config: PathBuf,
-}
-#[tokio::main]
-async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let args = Args::parse();
-    if !args.config.is_file() {
-        eprintln!("config file not found or not accessible");
-        std::process::exit(1);
-    }
-    println!("reading config from {:?}", args.config);
-    let state = Arc::new(Mutex::new(ProgramState {
-        tx_rt_map: HashMap::new(),
-        healths: HashMap::new(),
-        health_listener: None,
-        iperf_server: None,
-        health_listener_addr: None,
-        iperf_server_addr: None,
-    }));
-    if let Err(e) = load_config(&args.config, state.clone()).await {
-        eprintln!("config file loading failed: {}", e);
-    }
-    let config_path = args.config.clone();
-    let state_clone = state.clone();
-    tokio::spawn(async move {
-        let (tx, mut rx) = mpsc::channel(1);
-        let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
-            if let Ok(event) = res {
-                if event.kind.is_modify() {
-                    let _ = tx.blocking_send(());
-                }
-            }
-        })
-        .unwrap();
-        watcher
-            .watch(&config_path, RecursiveMode::NonRecursive)
-            .unwrap();
-        println!("watching for changes to {:?}", config_path);
-        while rx.recv().await.is_some() {
-            // for some reason, saving on certain text editors fires several events,
-            // and this causes us to reload a lot. try to flush some events, add a tiny delay
-            // to mitigate this
-            while rx.try_recv().is_ok() {}
-            tokio::time::sleep(Duration::from_millis(50)).await;
-            while rx.try_recv().is_ok() {}
-            if let Err(e) = load_config(&config_path, state_clone.clone()).await {
-                eprintln!("loading config failed: {}", e);
-            }
-        }
-    });
-    loop {
-        tokio::time::sleep(Duration::from_hours(1)).await;
-    }
-}
-async fn load_config(path: &PathBuf, state: Arc<Mutex<ProgramState>>) -> Result<(), Error> {
-    let f = File::open(path)?;
-    let app_config: config::AppConfig = match serde_saphyr::from_reader(f) {
-        Ok(app_config) => app_config,
-        Err(e) => {
-            eprintln!("error parsing config {}", e);
-            return Ok(());
-        }
-    };
-    println!(
-        "Loaded config, with {} backends, {} rules.",
-        app_config.backends.len(),
-        app_config.rules.len()
-    );
-    let (mut listeners, health_monitors) = match build_lb(&app_config) {
-        Ok(v) => v,
-        Err(e) => {
-            eprintln!("config has logical errors: {}", e);
-            return Ok(());
-        }
-    };
-    let mut prog_state = state.lock().unwrap();
-    let ports_to_remove: Vec<u16> = prog_state
-        .tx_rt_map
-        .keys()
-        .cloned()
-        .filter(|port| !listeners.contains_key(port))
-        .collect();
-    for port in ports_to_remove {
-        prog_state.tx_rt_map.remove(&port);
-    }
-    if let Some(handle) = prog_state.health_listener.take() {
-        handle.abort();
-    }
-    let health_map: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = health_monitors.clone();
-    let health_addr = app_config.healthcheck_addr.clone();
-    let health_addr_c = health_addr.clone();
-    let health_handle = tokio::spawn(async move {
-        if let Err(e) = start_healthcheck_listener(&health_addr, health_map).await {
-            eprintln!("health check listener failed: {}", e);
-        }
-    });
-    prog_state.health_listener = Some(health_handle);
-    prog_state.health_listener_addr = Some(health_addr_c);
-    // maybe restart iperf server
-    let iperf_addr = app_config.iperf_addr.clone();
-    if prog_state.iperf_server_addr.as_ref() != Some(&iperf_addr) {
-        if let Some(handle) = prog_state.iperf_server.take() {
-            handle.abort();
-        }
-        let iperf_addr_c = iperf_addr.clone();
-        let iperf_handle = tokio::spawn(async move {
-            if let Err(e) = start_iperf_server(iperf_addr.as_str()).await {
-                eprintln!("iperf server failed: {}", e);
-            }
-        });
-        prog_state.iperf_server = Some(iperf_handle);
-        prog_state.iperf_server_addr = Some(iperf_addr_c);
-    }
-    prog_state.healths = health_monitors;
-    for (port, routing_table) in listeners.drain() {
-        if let Some(x) = prog_state.tx_rt_map.get_mut(&port) {
-            x.send(routing_table)?;
-            println!("updated rules on port {}", port);
-        } else {
-            let (tx_rt, rx_rt) = mpsc::unbounded_channel();
-            prog_state.tx_rt_map.insert(port, tx_rt);
-            tokio::spawn(run_listener(port, rx_rt, routing_table));
-        }
-    }
-    println!("reload complete");
-    Ok(())
-}
-async fn run_listener(
-    port: u16,
-    mut rx_rt: mpsc::UnboundedReceiver<RoutingTable>,
-    mut current_table: RoutingTable,
-) {
-    let addr = format!("0.0.0.0:{}", port);
-    println!("Starting tcp listener on {}", addr);
-    let listener = TcpListener::bind(&addr).await.expect("Failed to bind port");
-    loop {
-        tokio::select! {
-            msg = rx_rt.recv() => {
-                match msg {
-                    Some(new_table) => {
-                        current_table = new_table;
-                    }
-                    None => {
-                        println!("Unbinding listener on port {}", port);
-                        break;
-                    }
-                }
-            }
-            accept_result = listener.accept() => {
-                match accept_result {
-                    Ok((socket, remote_addr)) => {
-                        let remote_ip = remote_addr.ip();
-                        let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
-                        let mut chosen_backend = None;
-                        for (cidr, balancer_idx) in &mut current_table.entries {
-                            if cidr.contains(&remote_ip) {
-                                let balancer = &mut current_table.balancers[*balancer_idx];
-                                chosen_backend = balancer.choose_backend(ConnectionInfo {
-                                    client_ip: remote_ip,
-                                });
-                                break;
-                            }
-                        }
-                        if let Some(backend) = chosen_backend {
-                            tokio::spawn(async move {
-                                if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await {
-                                    eprintln!("error: conn_id={} proxy failed: {}", conn_id, e);
-                                }
-                            });
-                        } else {
-                            println!("error: no matching rule for {} on port {}", remote_ip, port);
-                        }
-                    }
-                    Err(e) => {
-                        eprintln!("error: listener port {}: {}", port, e);
-                        continue;
-                    }
-                }
-            }
-        }
-    }
-}
+macro_rules! info {
+    ($($arg:tt)*) => {{
+        print!("info: ");
+        println!($($arg)*);
+    }};
+}
+macro_rules! error {
+    ($($arg:tt)*) => {
+        eprint!("error: ");
+        eprintln!($($arg)*);
+    };
+}
+mod netutils;
+mod balancer;
+use anywho::Error;
+use netutils::{Backend, tunnel};
+use std::sync::Arc;
+use tokio::net::TcpListener;
+use tokio::sync::Mutex;
+#[tokio::main]
+async fn main() -> Result<(), Error> {
+    let backends = Arc::new(vec![
+        Backend::new("127.0.0.1:8081".to_string()),
+        Backend::new("127.0.0.1:8082".to_string()),
+    ]);
+    let current_index = Arc::new(Mutex::new(0));
+    info!("enginewhy starting on 0.0.0.0:8080");
+    info!("backends: {:?}", backends);
+    let listener = TcpListener::bind("0.0.0.0:8080").await?;
+    loop {
+        let (client, addr) = listener.accept().await?;
+        info!("new connection from {}", addr);
+        let backend = {
+            let mut index = current_index.lock().await;
+            let selected_backend = backends[*index].clone();
+            *index = (*index + 1) % backends.len();
+            selected_backend
+        };
+        info!("routing client {} to backend {}", addr, backend);
+        if let Err(e) = tunnel(client, backend).await {
+            error!("proxy failed for {}: {}", addr, e);
+        }
+    }
+}

56
src/netutils.rs Normal file
View File

@@ -0,0 +1,56 @@
use std::fmt;
use tokio::io;
use tokio::net::TcpStream;
use std::error::Error;
#[derive(Clone, Debug)]
pub struct Backend {
address: String,
pub current_load : u32
}
impl Backend {
pub fn new(address: String) -> Self {
Backend {
address,
current_load : 0
}
}
}
impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.address)
}
}
pub async fn tunnel(client_stream: TcpStream, backend: Backend) -> Result<(), Box<dyn Error>> {
let backend_address: String = backend.address.clone();
tokio::spawn(async move {
let backend_stream: TcpStream = match TcpStream::connect(&backend_address).await {
Ok(s) => {
info!("connected to backend {backend_address}");
s
}
Err(e) => {
error!("failed connecting to backend {backend_address}: {e}");
return;
}
};
let (mut read_client, mut write_client) = client_stream.into_split();
let (mut read_backend, mut write_backend) = backend_stream.into_split();
let client_to_backend =
tokio::spawn(async move { io::copy(&mut read_client, &mut write_backend).await });
let backend_to_client =
tokio::spawn(async move { io::copy(&mut read_backend, &mut write_client).await });
let _ = tokio::join!(client_to_backend, backend_to_client);
});
Ok(())
}

View File

@@ -1,44 +0,0 @@
use crate::backend::Backend;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
pub mod tcp;
pub struct ConnectionContext {
pub id: u64,
pub client_addr: SocketAddr,
pub start_time: Instant,
pub backend: Arc<Backend>,
pub bytes_transferred: u64,
}
impl ConnectionContext {
pub fn new(id: u64, client_addr: SocketAddr, backend: Arc<Backend>) -> Self {
backend.inc_connections();
Self {
id,
client_addr,
start_time: Instant::now(),
backend,
bytes_transferred: 0,
}
}
}
impl Drop for ConnectionContext {
fn drop(&mut self) {
self.backend.dec_connections();
let duration = self.start_time.elapsed();
println!(
"info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
self.id,
self.client_addr,
self.backend.address,
self.bytes_transferred,
duration.as_secs_f64()
);
}
}

View File

@@ -1,30 +0,0 @@
use crate::backend::Backend;
use crate::proxy::ConnectionContext;
use anywho::Error;
use std::sync::Arc;
use tokio::io;
use tokio::net::TcpStream;
pub async fn proxy_tcp_connection(
connection_id: u64,
mut client_stream: TcpStream,
backend: Arc<Backend>,
) -> Result<(), Error> {
let client_addr = client_stream.peer_addr()?;
let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
#[cfg(debug_assertions)]
println!(
"info: conn_id={} connecting to {}",
connection_id, ctx.backend.id
);
let mut backend_stream = TcpStream::connect(&backend.address).await?;
let (tx, rx) = io::copy_bidirectional(&mut client_stream, &mut backend_stream).await?;
ctx.bytes_transferred = tx + rx;
Ok(())
}