7 Commits

Author SHA1 Message Date
Ning Qi (Paul) Sun
cd23bfdf5a Merge pull request #1 from psun256/merge
Merge & Refactor
2025-12-06 16:09:51 -05:00
4cdf2db0c9 feat: improved logging 2025-12-06 02:16:40 -05:00
606880f928 feat: merged repos 2025-12-06 01:31:33 -05:00
19cd5b7f2a feat: modularized proxy 2025-12-06 00:21:53 -05:00
Ning Qi (Paul) Sun
25c3eb9511 gh action
gh action
2025-12-03 22:07:40 -05:00
psun256
e27bd2aaf0 layer 4 load balancing (round robin, hardcoded backends) 2025-11-29 21:46:26 -05:00
Phong Nguyen
1235d3611d Update README with load balancing details
Added a note about load balancing algorithms from a referenced paper.
2025-12-03 12:47:46 -05:00
23 changed files with 209 additions and 2486 deletions


Binary file not shown.

1003
Cargo.lock generated
File diff suppressed because it is too large.


@@ -6,12 +6,3 @@ edition = "2024"
[dependencies]
anywho = "0.1.2"
tokio = { version = "1.48.0", features = ["full"] }
rand = "0.10.0-rc.5"
serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"
rperf3-rs = "0.3.9"
cidr = "0.3.1"
serde-saphyr = "0.0.10"
arc-swap = "1.7.1"
clap = { version = "4.5.53", features = ["derive"] }
notify = "8.2.0"


@@ -33,7 +33,7 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
# change to scratch and comment out the apk command for prod, i guess
FROM alpine:latest AS runtime
RUN apk add --no-cache ca-certificates curl netcat-openbsd bind-tools strace
# RUN apk add --no-cache ca-certificates curl netcat-openbsd bind-tools strace
WORKDIR /enginewhy
COPY --from=builder /enginewhy/target/x86_64-unknown-linux-musl/release/l4lb /usr/bin/l4lb
ENTRYPOINT ["l4lb"]

106
README.md Normal file

@@ -0,0 +1,106 @@
# nginy
Production't graden't load balancer.
## Quick links
## Todo
- [ ] architecture astronauting
- [ ] stream / session handling (i think wrapper around tokio TcpStream)
- [ ] basic backend pooling
- [ ] layer 4 load balancing
- [ ] load balancing algorithm from the paper (https://www.wcse.org/WCSE_2018/W110.pdf)
## notes
tcp, for nginx (and haproxy, it's similar):
```c
// nginx
struct ngx_connection_s {
void *data;
ngx_event_t *read;
ngx_event_t *write;
ngx_socket_t fd;
ngx_recv_pt recv; // fn pointer to whatever recv fn is used (different for different platforms / protocols)
ngx_send_pt send; // ditto
ngx_recv_chain_pt recv_chain;
ngx_send_chain_pt send_chain;
ngx_listening_t *listening;
off_t sent;
ngx_log_t *log;
ngx_pool_t *pool;
int type;
struct sockaddr *sockaddr;
socklen_t socklen;
ngx_str_t addr_text;
ngx_proxy_protocol_t *proxy_protocol;
#if (NGX_QUIC || NGX_COMPAT)
ngx_quic_stream_t *quic;
#endif
#if (NGX_SSL || NGX_COMPAT)
ngx_ssl_connection_t *ssl;
#endif
ngx_udp_connection_t *udp; // additional stuff for UDP (which is technically connectionless, but they use timeouts and a rbtree to store "sessions")
struct sockaddr *local_sockaddr;
socklen_t local_socklen;
ngx_buf_t *buffer;
ngx_queue_t queue;
ngx_atomic_uint_t number;
ngx_msec_t start_time;
ngx_uint_t requests;
unsigned buffered:8;
unsigned log_error:3; /* ngx_connection_log_error_e */
unsigned timedout:1;
unsigned error:1;
unsigned destroyed:1;
unsigned pipeline:1;
unsigned idle:1;
unsigned reusable:1;
unsigned close:1;
unsigned shared:1;
unsigned sendfile:1;
unsigned sndlowat:1;
unsigned tcp_nodelay:2; /* ngx_connection_tcp_nodelay_e */
unsigned tcp_nopush:2; /* ngx_connection_tcp_nopush_e */
unsigned need_last_buf:1;
unsigned need_flush_buf:1;
#if (NGX_HAVE_SENDFILE_NODISKIO || NGX_COMPAT)
unsigned busy_count:2;
#endif
#if (NGX_THREADS || NGX_COMPAT)
ngx_thread_task_t *sendfile_task;
#endif
};
```
process to load balance:
- accept incoming connection
- create some kind of stream / session object
- nginx uses this to abstract over the tcp and udp layers
- for us it probably doesn't need to be as detailed, since we have tokio::net, so it'll be a thin wrapper around TcpStream
- ask the load balancing algorithm which server in the pool to route to
- connect to the server
- proxy the data (copy_bidirectional? maybe we want some metrics or logging, so we might do it manually)
- cleanup when someone leaves or something goes wrong (with TCP, the OS / tokio will tell us; with UDP probably just timeout based, plus a periodic sweep of all sessions), see the sketch below
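
a rough sketch of what that could look like with tokio (the listen address and backend here are just placeholders, a real version would pick from the pool and track metrics):
```rust
use tokio::io::copy_bidirectional;
use tokio::net::{TcpListener, TcpStream};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // accept incoming connections
    let listener = TcpListener::bind("0.0.0.0:8080").await?;
    loop {
        let (mut client, addr) = listener.accept().await?;
        // ask the load balancing algorithm which server in the pool to route to
        // (hardcoded here for the sketch)
        let backend_addr = "127.0.0.1:8081";
        tokio::spawn(async move {
            // connect to the server, then proxy the data in both directions
            match TcpStream::connect(backend_addr).await {
                Ok(mut backend) => {
                    if let Err(e) = copy_bidirectional(&mut client, &mut backend).await {
                        eprintln!("proxy error for {}: {}", addr, e);
                    }
                }
                Err(e) => eprintln!("could not reach backend {}: {}", backend_addr, e),
            }
            // when either side closes or errors, the task ends and the session is cleaned up
        });
    }
}
```
copy_bidirectional keeps the sketch short; copying manually instead would let us count bytes per direction for logging.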


@@ -1,27 +0,0 @@
healthcheck_addr: "0.0.0.0:8080"
iperf_addr: "0.0.0.0:5001"
backends:
- id: "srv-1"
ip: "192.67.67.2:8080"
- id: "srv-2"
ip: "192.67.67.3:8080"
clusters:
main-api:
- "srv-1"
- "srv-2"
priority-api:
- "srv-1"
rules:
- clients:
- "172.67.67.2/24:80"
targets:
- "main-api"
- "priority-api"
strategy:
type: "Adaptive"
coefficients: [ 1.5, 1.0, 0.5, 0.1 ]
alpha: 0.75


@@ -1,88 +0,0 @@
services:
load-balancer:
image: neoslhp/enginewhy-lb
container_name: load-balancer
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 8G
cap_add:
- NET_ADMIN
- SYS_ADMIN
networks:
internal:
ipv4_address: 172.67.67.67
external:
ipv4_address: 192.67.67.67
server1-high-cpu:
image: neoslhp/enginewhy-server
container_name: server1
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 8G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
external:
ipv4_address: 192.67.67.2
server2-low-cpu:
image: neoslhp/enginewhy-server
container_name: server2
tty: true
deploy:
resources:
limits:
cpus: "2.0"
memory: 4G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
external:
ipv4_address: 192.67.67.3
client:
image: neoslhp/enginewhy-ubuntu22.04
container_name: client
tty: true
deploy:
resources:
limits:
cpus: "4.0"
memory: 4G
depends_on:
- load-balancer
cap_add:
- NET_ADMIN
networks:
internal:
ipv4_address: 172.67.67.2
networks:
internal:
driver: bridge
ipam:
config:
- subnet: 172.67.67.0/24
external:
driver: bridge
ipam:
config:
- subnet: 192.67.67.0/24
# Resources:
# https://networkgeekstuff.com/networking/basic-load-balancer-scenarios-explained/
# https://hub.docker.com/r/linuxserver/wireshark
# https://www.wcse.org/WCSE_2018/W110.pdf
# Deepseek


@@ -1,68 +0,0 @@
use rperf3::{Server, Config};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::io::{Read, Write};
use std::env;
use tokio::task;
async fn start_iperf_server() -> Result<(), Box<dyn std::error::Error>> {
let config = Config::server(5001);
let server = Server::new(config);
server.run().await?;
Ok(())
}
fn handle_connection(mut stream: TcpStream) -> std::io::Result<()> {
loop {
let mut buffer = [0u8; 512];
let bytes_read = stream.read(&mut buffer)?;
if bytes_read == 0 {
break; // peer closed the connection
}
let received = String::from_utf8_lossy(&buffer[..bytes_read]);
println!("Received: {}", received);
}
Ok(())
}
fn start_tcp_server(addr: &str) -> std::io::Result<()> {
let listener = TcpListener::bind(addr)?;
println!("TCP server listening on {}", addr);
let mut handles = Vec::new();
for stream in listener.incoming() {
match stream {
Ok(stream) => {
let handle = thread::spawn(move || {
if let Err(e) = handle_connection(stream) {
eprintln!("connection handler error: {}", e);
}
});
handles.push(handle);
}
Err(e) => eprintln!("incoming connection failed: {}", e),
}
}
// When the incoming stream iterator ends (listener closed), join all handlers.
for h in handles {
let _ = h.join();
}
Ok(())
}
#[tokio::main]
async fn main() {
// Choose IP based on `--localhost` flag for debugging
let use_localhost = env::args().any(|a| a == "--localhost");
let ip = if use_localhost { "127.0.0.1" } else { "192.67.67.67" };
let tcp_addr = format!("{}:8080", ip);
let iperf_server = task::spawn(async {
start_iperf_server().await;
});
let tcp_ip = tcp_addr.clone();
let tcp_server = thread::spawn(move || {
start_tcp_server(&tcp_ip).unwrap();
});
iperf_server.await.unwrap();
tcp_server.join().unwrap();
}


@@ -1,162 +0,0 @@
use sysinfo::{CpuRefreshKind, RefreshKind, System};
use sysinfo::{Networks};
use sysinfo::{Disks};
use std::thread;
use std::time::Duration;
use std::net::{TcpStream};
use std::env;
use std::collections::HashMap;
use std::io::Write;
use serde_json::Value;
use rperf3::{Client, Config, Protocol};
// Default server addresses
const DEFAULT_REMOTE_IP: &str = "192.67.67.67";
const DEFAULT_LOCAL_IP: &str = "127.0.0.1";
const PORT: u16 = 8080;
const IPERF_PORT: u16 = 5001;
fn get_io_usage_percentage() -> Result<f64, String> {
let mut sys = Disks::new_with_refreshed_list();
// Refresh disk information
sys.refresh(true);
// Get first disk (usually main disk)
if let Some(disk) = sys.list().first() {
let initial_read = disk.usage().total_read_bytes;
let initial_write = disk.usage().total_written_bytes;
thread::sleep(Duration::from_secs(1)); // 1s
sys.refresh(true);
let disk = sys.list().first().ok_or("Disk disappeared")?;
let new_read = disk.usage().total_read_bytes;
let new_write = disk.usage().total_written_bytes;
// Calculate Bps
let read_per_sec = (new_read - initial_read) as f64;
let write_per_sec = (new_write - initial_write) as f64;
// Get disk type to estimate max speed (these are rough estimates)
let max_speed = match disk.kind() {
sysinfo::DiskKind::SSD => 500_000_000.0, // 500 MBps
sysinfo::DiskKind::HDD => 200_000_000.0, // 200 MBps
_ => 300_000_000.0, // Default
};
let io_percentage = f64::min(100.0, ((read_per_sec + write_per_sec) / max_speed) * 100.0);
Ok(io_percentage)
} else {
Err("No disks found".to_string())
}
}
async fn measure_iperf_bandwidth(server_ip: &str, port: u16) -> Result<f64, Box<dyn std::error::Error>> {
// Configure the test (use the provided port)
let config = Config::client(server_ip.to_string(), port)
.with_duration(Duration::from_secs(10));
// Run the test
let client = Client::new(config)?;
client.run().await?;
// Get results
let measurements = client.get_measurements();
let bandwidth_bps = measurements.total_bits_per_second();
println!("iperf3 reported max bandwidth: {:.2} Mbps", bandwidth_bps / 1_000_000.0);
Ok(bandwidth_bps)
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
// Determine server IP from CLI: `--localhost` -> local, otherwise remote
let args: Vec<String> = env::args().collect();
let server_ip = if args.iter().any(|a| a == "--localhost") {
DEFAULT_LOCAL_IP.to_string()
} else {
DEFAULT_REMOTE_IP.to_string()
};
let mut stream = TcpStream::connect(format!("{}:{}", server_ip, PORT))?;
println!("server connected to {}:{}", server_ip, PORT);
// Initialize the system struct
let mut sys = System::new_with_specifics(
RefreshKind::nothing().with_cpu(CpuRefreshKind::everything()),
);
let mut networks = Networks::new();
networks.refresh(true);
// Probe max bandwidth using iperf3
let mut max_bps: f64 = 0.0;
match measure_iperf_bandwidth(&server_ip, IPERF_PORT).await {
Ok(bps) => {
max_bps = bps;
println!("iperf3 reported max bandwidth: {:.2} bits/sec ({:.2} Mbps)", max_bps, max_bps / 1e6);
}
Err(e) => println!("iperf3 failed: {}", e),
}
// Wait a bit because CPU usage is based on diff.
std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL);
loop {
sys.refresh_all();
sys.refresh_cpu_usage(); // Refreshing CPU usage.
let mut cpu_usage: f64 = 0.0;
for cpu in sys.cpus() {
cpu_usage += cpu.cpu_usage() as f64;
}
cpu_usage /= sys.cpus().len() as f64;
println!("CPU usage is {}%", cpu_usage);
// Memory usage
let total_mem = sys.total_memory();
let used_mem = sys.used_memory();
let mem_usage = used_mem as f64 / total_mem as f64 * 100.0;
println!("Memory usage is {}%", mem_usage);
// Network bandwidth usage
let mut bandwidth: f64 = 0.0; // Bps
for (interface_name, network) in &networks {
if interface_name == "wlp2s0" {
bandwidth = network.transmitted() as f64;
println!("[{interface_name}] transferred {:?} %", bandwidth / max_bps * 100.0);
}
}
networks.refresh(true);
// Calculate percent usage of measured max bandwidth (if available)
let net_usage_pct: f64 = if max_bps > 0.0 {
f64::min(100.0, (bandwidth / max_bps) * 100.0)
} else { 0.0 };
// IO usage
let mut io_usage = 0.0;
match get_io_usage_percentage() {
Ok(percentage) => {
io_usage = percentage;
println!("I/O usage is {}%", percentage)
},
Err(e) => println!("Error: {}", e)
}
println!();
let mut packet: HashMap<String, Value> = HashMap::new();
packet.insert("cpu".to_string(), Value::from(cpu_usage)); // %
packet.insert("mem".to_string(), Value::from(mem_usage)); // %
packet.insert("net".to_string(), Value::from(net_usage_pct));
packet.insert("io".to_string(), Value::from(io_usage));
let mut serialized_packet = serde_json::to_string(&packet)?;
serialized_packet.push('\n');
let _ = stream.write(serialized_packet.as_bytes());
thread::sleep(Duration::from_secs(10));
}
}


@@ -1,134 +0,0 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use serde_json::Value;
use rperf3::{Config, Server};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::io::AsyncBufReadExt;
// Physical server health statistics, used for certain load balancing algorithms
#[derive(Debug, Default)]
pub struct ServerMetrics {
pub cpu: f64,
pub mem: f64,
pub net: f64,
pub io: f64,
}
impl ServerMetrics {
pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
self.cpu = cpu;
self.mem = mem;
self.net = net;
self.io = io;
}
}
pub async fn start_healthcheck_listener(
addr: &str,
healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> std::io::Result<()> {
let addrs = tokio::net::lookup_host(addr).await?;
let mut listener = None;
for a in addrs {
let socket = match a {
SocketAddr::V4(_) => TcpSocket::new_v4()?,
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};
socket.set_reuseaddr(true)?;
if socket.bind(a).is_ok() {
listener = Some(socket.listen(1024)?);
break;
}
}
let listener = listener.ok_or_else(|| {
eprintln!("health listener could not bind to port");
std::io::Error::new(std::io::ErrorKind::Other, "health listener failed")
})?;
println!("healthcheck server listening on {}", addr);
loop {
let (stream, remote_addr) = match listener.accept().await {
Ok(v) => v,
Err(e) => {
continue;
}
};
if let Err(e) = handle_metrics_stream(stream, &healths).await {
eprintln!("connection handler error: {}", e);
}
}
}
pub async fn start_iperf_server(addr: &str) -> Result<(), Box<dyn std::error::Error>> {
let sock = addr.parse::<SocketAddr>()?;
let mut config = Config::server(sock.port());
config.bind_addr = Some(sock.ip());
let server = Server::new(config);
println!("iperf server listening on {}", addr);
server.run().await?;
Ok(())
}
async fn handle_metrics_stream(
stream: TcpStream,
healths: &HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> std::io::Result<()> {
let server_ip = stream.peer_addr()?.ip();
let mut reader = tokio::io::BufReader::new(stream);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line).await {
Ok(0) => break,
Ok(_) => {
if let Err(e) = process_metrics(server_ip, &line, healths) {
eprintln!("skipping invalid packet: {}", e);
}
}
Err(e) => {
eprintln!("connection error: {}", e);
break;
}
}
}
Ok(())
}
fn process_metrics(
server_ip: IpAddr,
json_str: &str,
healths: &HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
) -> Result<(), String> {
let parsed: Value =
serde_json::from_str(json_str).map_err(|e| format!("parse error: {}", e))?;
let metrics_lock = healths
.get(&server_ip)
.ok_or_else(|| format!("unknown server: {}", server_ip))?;
let get_f64 = |key: &str| -> Result<f64, String> {
parsed
.get(key)
.and_then(|v| v.as_f64())
.ok_or_else(|| format!("invalid '{}'", key))
};
if let Ok(mut guard) = metrics_lock.write() {
guard.update(
get_f64("cpu")?,
get_f64("mem")?,
get_f64("net")?,
get_f64("io")?,
);
}
Ok(())
}


@@ -1,76 +0,0 @@
pub mod health;
use crate::backend::health::ServerMetrics;
use core::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering};
// A possible endpoint for a proxied connection.
// Note that multiple may live on the same server, hence the Arc<RwLock<ServerMetric>>
#[derive(Debug)]
pub struct Backend {
pub id: String,
pub address: SocketAddr,
pub active_connections: AtomicUsize,
pub metrics: Arc<RwLock<ServerMetrics>>,
}
impl Backend {
pub fn new(
id: String,
address: SocketAddr,
server_metrics: Arc<RwLock<ServerMetrics>>,
) -> Self {
Self {
id: id.to_string(),
address,
active_connections: AtomicUsize::new(0),
metrics: server_metrics,
}
}
// Ordering::Relaxed means the ops could happen in any order, but since this
// is just a metric, and we assume the underlying system is sane
// enough not to behave poorly, SeqCst is probably overkill.
pub fn inc_connections(&self) {
self.active_connections.fetch_add(1, Ordering::Relaxed);
println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
}
pub fn dec_connections(&self) {
self.active_connections.fetch_sub(1, Ordering::Relaxed);
println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
}
}
impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} ({})", self.address, self.id)
}
}
// A set of endpoints that can be load balanced around.
// Each Balancer owns one of these. Backend instances may be shared
// with other Balancer instances, hence Arc<Backend>.
#[derive(Clone, Debug)]
pub struct BackendPool {
pub backends: Arc<Vec<Arc<Backend>>>,
}
impl BackendPool {
pub fn new(backends: Vec<Arc<Backend>>) -> Self {
BackendPool {
backends: Arc::new(backends),
}
}
}


@@ -1,267 +0,0 @@
use crate::backend::{Backend, BackendPool};
use crate::backend::health::ServerMetrics;
use crate::balancer::{Balancer, ConnectionInfo};
use rand::prelude::*;
use rand::rngs::SmallRng;
use std::fmt::Debug;
use std::fs::Metadata;
use std::sync::{Arc, RwLock};
#[derive(Debug)]
struct AdaptiveNode {
backend: Arc<Backend>,
weight: f64,
}
#[derive(Debug)]
pub struct AdaptiveWeightBalancer {
pool: Vec<AdaptiveNode>,
coefficients: [f64; 4],
alpha: f64,
rng: SmallRng,
}
impl AdaptiveWeightBalancer {
pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self {
let nodes = pool
.backends
.iter()
.map(|b| AdaptiveNode {
backend: b.clone(),
weight: 1f64,
})
.collect();
AdaptiveWeightBalancer {
pool: nodes,
coefficients,
alpha,
rng: SmallRng::from_rng(&mut rand::rng()),
}
}
pub fn metrics_to_weight(&self, metrics: &ServerMetrics) -> f64 {
self.coefficients[0] * metrics.cpu
+ self.coefficients[1] * metrics.mem
+ self.coefficients[2] * metrics.net
+ self.coefficients[3] * metrics.io
}
}
impl Balancer for AdaptiveWeightBalancer {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
if self.pool.is_empty() {
return None;
}
// Compute remaining capacity R_i = 100 - composite_load
let mut r_sum = 0.0;
let mut w_sum = 0.0;
let mut l_sum = 0;
for node in &self.pool {
if let Ok(health) = node.backend.metrics.read() {
r_sum += self.metrics_to_weight(&health);
}
w_sum += node.weight;
l_sum += node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed);
}
let safe_w_sum = w_sum.max(1e-12);
let threshold = self.alpha * (r_sum / safe_w_sum);
for idx in 0..self.pool.len() {
let node = &self.pool[idx];
if node.weight <= 0.001 {
continue;
}
let risk = match node.backend.metrics.read() {
Ok(h) => self.metrics_to_weight(&h),
Err(_) => f64::MAX,
};
let ratio = risk / node.weight;
if ratio <= threshold {
return Some(node.backend.clone());
}
}
// If no server satisfies Ri/Wi <= threshold, every server is relatively
// overloaded, so we adjust the weights using formula (6) from the paper
// and fall back to the least-loaded server.
let mut total_lwi = 0.0;
let l_sum_f64 = l_sum as f64;
for node in &self.pool {
let load = node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed) as f64;
let weight = node.weight.max(1e-12);
let lwi = load * (safe_w_sum / weight) * l_sum_f64;
total_lwi += lwi;
}
let avg_lwi = (total_lwi / self.pool.len() as f64).max(1e-12);
// Compute Li = Wi / Ri and choose server minimizing Li.
let mut best_backend: Option<Arc<Backend>> = None;
let mut min_load = usize::MAX;
for node in &mut self.pool {
let load = node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed);
let load_f64 = load as f64;
let weight = node.weight.max(1e-12);
let lwi = load_f64 * (safe_w_sum / weight) * l_sum_f64;
let adj = 1.0 - (lwi / avg_lwi);
node.weight += adj;
node.weight = node.weight.clamp(0.1, 100.0);
if load < min_load {
min_load = load;
best_backend = Some(node.backend.clone());
}
}
match best_backend {
Some(backend) => Some(backend),
None => {
let i = (self.rng.next_u32() as usize) % self.pool.len();
Some(self.pool[i].backend.clone())
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::Backend;
use std::net::SocketAddr;
fn backend_factory(id: &str, ip: &str, port: u16) -> Arc<Backend> {
Arc::new(Backend::new(
id.to_string(),
SocketAddr::new(ip.parse().unwrap(), port),
Arc::new(RwLock::new(ServerMetrics::default())),
))
}
fn unused_ctx() -> ConnectionInfo {
ConnectionInfo {
client_ip: ("0.0.0.0".parse().unwrap()),
}
}
#[test]
fn basic_weight_update_and_choose() {
let backends = BackendPool::new(vec![
backend_factory("server-0", "127.0.0.1", 3000),
backend_factory("server-1", "127.0.0.1", 3001),
]);
let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.5, 0.2, 0.2, 0.1], 0.5);
// initially equal weights
// update one backend to be heavily loaded
{
let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
sm0_guard.update(90.0, 80.0, 10.0, 5.0);
}
{
let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
sm1_guard.update(10.0, 5.0, 1.0, 1.0);
}
// Choose backend: should pick the less loaded host server1
let chosen = b
.choose_backend(unused_ctx())
.expect("should choose a backend");
let sm0: &ServerMetrics = &backends.backends.get(0).unwrap().metrics.read().unwrap();
let sm1: &ServerMetrics = &backends.backends.get(1).unwrap().metrics.read().unwrap();
println!("{:?}, {:?}", sm0, sm1);
assert_eq!(chosen.id, "server-1");
}
#[test]
fn choose_none_when_empty() {
let mut b =
AdaptiveWeightBalancer::new(BackendPool::new(vec![]), [0.5, 0.2, 0.2, 0.1], 0.5);
assert!(b.choose_backend(unused_ctx()).is_none());
}
#[test]
fn ratio_triggers_immediate_selection() {
// Arrange two servers where server 1 has composite load 0 and server 2 has composite load 100.
// With alpha = 1.0 and two servers, threshold = 1.0 * (r_sum / w_sum) = 1.0 * (100 / 2) = 50.
// Server 0 ratio = 0 / 1 = 0 <= 50 so it should be chosen immediately.
let backends = BackendPool::new(vec![
backend_factory("server-0", "127.0.0.1", 3000),
backend_factory("server-1", "127.0.0.1", 3001),
]);
let mut b = AdaptiveWeightBalancer::new(backends.clone(), [0.25, 0.25, 0.25, 0.25], 1.0);
{
let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
sm0_guard.update(0.0, 0.0, 0.0, 0.0);
}
{
let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
sm1_guard.update(100.0, 100.0, 100.0, 100.0);
}
let chosen = b
.choose_backend(unused_ctx())
.expect("should choose a backend");
assert_eq!(chosen.id, "server-0");
}
#[test]
fn choose_min_current_load_when_no_ratio() {
// Arrange three servers with identical composite loads so no server satisfies Ri/Wi <= threshold
// (set alpha < 1 so threshold < ratio). The implementation then falls back to picking the
// server with minimum current_load
let backends = BackendPool::new(vec![
backend_factory("server-0", "127.0.0.1", 3000),
backend_factory("server-1", "127.0.0.1", 3001),
backend_factory("server-2", "127.0.0.1", 3002),
]);
// set current_loads (field expected to be public)
{
let mut sm0_guard = backends.backends.get(0).unwrap().metrics.write().unwrap();
sm0_guard.update(10.0, 10.0, 10.0, 10.0);
}
{
let mut sm1_guard = backends.backends.get(1).unwrap().metrics.write().unwrap();
sm1_guard.update(5.0, 5.0, 5.0, 5.0);
}
{
let mut sm2_guard = backends.backends.get(2).unwrap().metrics.write().unwrap();
sm2_guard.update(20.0, 20.0, 20.0, 20.0);
}
// Use coeffs that only consider CPU so composite load is easy to reason about.
let mut bal = AdaptiveWeightBalancer::new(backends.clone(), [1.0, 0.0, 0.0, 0.0], 0.5);
// set identical composite loads > 0 for all so ratio = x and threshold = alpha * x < x
// you will have threshold = 25 for all 3 backend servers and ratio = 50
// so that forces to choose the smallest current load backend
let chosen = bal
.choose_backend(unused_ctx())
.expect("should choose a backend");
// expect server with smallest current_load server-1
assert_eq!(chosen.id, "server-1");
}
}


@@ -1,109 +0,0 @@
use crate::backend::{Backend, BackendPool};
use crate::balancer::{Balancer, ConnectionInfo};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
#[derive(Debug)]
pub struct SourceIPHash {
pool: BackendPool,
}
impl SourceIPHash {
pub fn new(pool: BackendPool) -> SourceIPHash {
Self { pool }
}
}
impl Balancer for SourceIPHash {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let client_ip = ctx.client_ip;
let mut hasher = DefaultHasher::new();
client_ip.hash(&mut hasher);
let hash = hasher.finish();
let idx = (hash as usize) % self.pool.backends.len();
Some(self.pool.backends[idx].clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::backend::health::ServerMetrics;
use std::net::IpAddr;
use std::sync::RwLock;
fn create_dummy_backends(count: usize) -> BackendPool {
let mut backends = Vec::new();
for i in 1..=count {
backends.push(Arc::new(Backend::new(
format!("backend {}", i),
format!("127.0.0.1:808{}", i).parse().unwrap(),
Arc::new(RwLock::new(ServerMetrics::default())),
)));
}
BackendPool::new(backends)
}
#[test]
fn test_same_ip_always_selects_same_backend() {
let backends = create_dummy_backends(3);
let mut balancer = SourceIPHash::new(backends);
let client_ip: IpAddr = "192.168.1.100".parse().unwrap();
let first_choice = balancer.choose_backend(ConnectionInfo { client_ip });
let second_choice = balancer.choose_backend(ConnectionInfo { client_ip });
assert!(first_choice.is_some());
assert!(second_choice.is_some());
let first = first_choice.unwrap();
let second = second_choice.unwrap();
assert_eq!(first.id, second.id);
}
#[test]
fn test_different_ips_may_select_different_backends() {
let backends = create_dummy_backends(2);
let mut balancer = SourceIPHash::new(backends);
let ip1: IpAddr = "192.168.1.100".parse().unwrap();
let choice1 = balancer.choose_backend(ConnectionInfo { client_ip: ip1 });
let ip2: IpAddr = "192.168.1.101".parse().unwrap();
let choice2 = balancer.choose_backend(ConnectionInfo { client_ip: ip2 });
assert!(choice1.is_some());
assert!(choice2.is_some());
}
#[test]
fn test_hash_distribution_across_backends() {
let pool = create_dummy_backends(3);
let backends_ref = pool.backends.clone();
let mut balancer = SourceIPHash::new(pool);
let mut distribution = [0, 0, 0];
// Test 30 different IPs
for i in 0..30 {
let client_ip: IpAddr = format!("192.168.1.{}", 100 + i).parse().unwrap();
if let Some(backend) = balancer.choose_backend(ConnectionInfo { client_ip }) {
for (idx, b) in backends_ref.iter().enumerate() {
if backend.id == b.id && backend.address == b.address {
distribution[idx] += 1;
break;
}
}
}
}
assert!(distribution[0] > 0, "Backend 0 received no traffic");
assert!(distribution[1] > 0, "Backend 1 received no traffic");
assert!(distribution[2] > 0, "Backend 2 received no traffic");
}
}


@@ -1 +0,0 @@
// use super::*;


@@ -1,18 +0,0 @@
pub mod adaptive_weight;
pub mod ip_hashing;
pub mod least_connections;
pub mod round_robin;
use crate::backend::Backend;
use std::fmt::Debug;
use std::net::IpAddr;
use std::sync::Arc;
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
pub client_ip: IpAddr,
}
pub trait Balancer: Debug + Send + Sync + 'static {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>>;
}


@@ -1,32 +0,0 @@
use crate::backend::{Backend, BackendPool};
use crate::balancer::{Balancer, ConnectionInfo};
use std::fmt::Debug;
use std::sync::Arc;
// only the main thread for receiving connections should be
// doing the load balancing. alternatively, each thread
// that handles load balancing should get their own instance.
#[derive(Debug)]
pub struct RoundRobinBalancer {
pool: BackendPool,
index: usize,
}
impl RoundRobinBalancer {
pub fn new(pool: BackendPool) -> RoundRobinBalancer {
Self { pool, index: 0 }
}
}
impl Balancer for RoundRobinBalancer {
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let backends = self.pool.backends.clone();
if backends.is_empty() {
return None;
}
let backend = backends[self.index % backends.len()].clone();
self.index = self.index.wrapping_add(1);
Some(backend)
}
}


@@ -1,128 +0,0 @@
use cidr::IpCidr;
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use crate::backend::health::*;
use crate::backend::*;
use crate::balancer::Balancer;
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::balancer::ip_hashing::SourceIPHash;
use crate::config::*;
pub struct RoutingTable {
pub balancers: Vec<Box<dyn Balancer + Send>>,
pub entries: Vec<(IpCidr, usize)>,
}
pub type PortListeners = HashMap<u16, RoutingTable>;
fn parse_client(s: &str) -> Result<(IpCidr, u16), String> {
// just splits "0.0.0.0/0:80" into ("0.0.0.0/0", 80)
let (ip_part, port_part) = s
.rsplit_once(':')
.ok_or_else(|| format!("badly formatted client: {}", s))?;
let port = port_part.parse().map_err(|_| format!("bad port: {}", s))?;
let cidr = ip_part.parse().map_err(|_| format!("bad ip/mask: {}", s))?;
Ok((cidr, port))
}
pub fn build_lb(
config: &AppConfig,
) -> Result<(PortListeners, HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>), String> {
let mut healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = HashMap::new();
let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
for backend_cfg in &config.backends {
let addr: SocketAddr = backend_cfg.ip.parse()
.map_err(|_| format!("bad ip: {}", backend_cfg.ip))?;
let ip = addr.ip();
let health = healths
.entry(ip)
.or_insert_with(|| Arc::new(RwLock::new(ServerMetrics::default())))
.clone();
let backend = Arc::new(Backend::new(backend_cfg.id.clone(), addr, health));
backends.insert(backend_cfg.id.clone(), backend);
}
let mut listeners: PortListeners = HashMap::new();
for rule in &config.rules {
let mut target_backends = Vec::new();
for target_name in &rule.targets {
if let Some(members) = config.clusters.get(target_name) {
for member_id in members {
if let Some(backend) = backends.get(member_id) {
target_backends.push(backend.clone());
}
}
} else if let Some(backend) = backends.get(target_name) {
target_backends.push(backend.clone());
} else {
eprintln!("warning: target {} not found", target_name);
}
}
// possible for multiple targets of the same rule to have common backends.
target_backends.sort_by(|a, b| a.id.cmp(&b.id));
target_backends.dedup_by(|a, b| a.id == b.id);
if target_backends.is_empty() {
eprintln!("warning: rule has no valid targets, skipping.");
continue;
}
// for each distinct client port on this rule, we unfortunately need to make a new
// Balancer, since Balancer is not thread safe (it requires &mut self for backend
// selection).
// making a new one per port is a good enough compromise: it avoids a Mutex, at the
// cost of a minor penalty to load balancing "quality" when there are several client ports.
let mut port_groups: HashMap<u16, Vec<IpCidr>> = HashMap::new();
for client_def in &rule.clients {
let (cidr, port) = parse_client(&client_def)?;
port_groups.entry(port).or_default().push(cidr);
}
for (port, cidrs) in port_groups {
let table = listeners.entry(port).or_insert_with(|| RoutingTable {
balancers: Vec::new(),
entries: Vec::new(),
});
let pool = BackendPool::new(target_backends.clone());
let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
LoadBalancerStrategy::RoundRobin => Box::new(RoundRobinBalancer::new(pool)),
LoadBalancerStrategy::SourceIPHash => Box::new(SourceIPHash::new(pool)),
LoadBalancerStrategy::Adaptive {
coefficients,
alpha,
} => Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha)),
};
let balancer_idx = table.balancers.len();
table.balancers.push(balancer);
for cidr in cidrs {
table.entries.push((cidr, balancer_idx));
}
}
}
// sort to make most specific first, so that first match == longest prefix match
for table in listeners.values_mut() {
table
.entries
.sort_by(|(a, _), (b, _)| b.network_length().cmp(&a.network_length()));
}
Ok((listeners, healths))
}


@@ -1,62 +0,0 @@
// config is written as a YAML file, the path will be passed to the program.
//
// the high level structure of the config is that we
// first define the individual backends (ip + port) we are going
// to load balance around.
//
// next we define some clusters, which are really more like a short
// alias for a group of backends.
//
// next we define the rules. these are written as a list of
// "ip/subnet:port" for the clients, and then a list of clusters
// for which backends these are balanced around. and of course
// specify which algorithm to use.
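//
// as a quick illustration (ids and addresses here are made up, not from the
// real deployment), a config along these lines should parse into the structs
// below:
//
//   backends:
//     - id: "srv-1"
//       ip: "10.0.0.2:8080"
//   clusters:
//     api:
//       - "srv-1"
//   rules:
//     - clients:
//         - "0.0.0.0/0:80"
//       targets:
//         - "api"
//       strategy:
//         type: "RoundRobin"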
pub mod loader;
use serde::Deserialize;
use std::collections::HashMap;
fn default_healthcheck_addr() -> String {
"0.0.0.0:8080".to_string()
}
fn default_iperf_addr() -> String {
"0.0.0.0:5201".to_string()
}
#[derive(Debug, Deserialize)]
pub struct AppConfig {
#[serde(default = "default_healthcheck_addr")]
pub healthcheck_addr: String,
#[serde(default = "default_iperf_addr")]
pub iperf_addr: String,
pub backends: Vec<BackendConfig>,
#[serde(default)]
pub clusters: HashMap<String, Vec<String>>,
pub rules: Vec<RuleConfig>,
}
#[derive(Debug, Deserialize)]
pub struct BackendConfig {
pub id: String,
pub ip: String,
}
#[derive(Debug, Deserialize)]
pub struct RuleConfig {
pub clients: Vec<String>,
pub targets: Vec<String>,
pub strategy: LoadBalancerStrategy,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum LoadBalancerStrategy {
RoundRobin,
SourceIPHash,
Adaptive { coefficients: [f64; 4], alpha: f64 },
}


@@ -1,252 +1,55 @@
old version (252 lines):
mod backend;
mod balancer;
mod config;
mod proxy;
use crate::backend::health::{start_healthcheck_listener, start_iperf_server, ServerMetrics};
use crate::balancer::ConnectionInfo;
use crate::config::loader::{build_lb, RoutingTable};
use crate::proxy::tcp::proxy_tcp_connection;
use anywho::Error;
use std::collections::HashMap;
use std::fs::File;
use std::hash::Hash;
use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::time::Duration;
use tokio::io::AsyncBufReadExt;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use clap::Parser;
use notify::{Event, RecursiveMode, Watcher};
use std::cmp;
static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
struct ProgramState {
tx_rt_map: HashMap<u16, mpsc::UnboundedSender<RoutingTable>>,
healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>>,
health_listener: Option<tokio::task::JoinHandle<()>>,
iperf_server: Option<tokio::task::JoinHandle<()>>,
health_listener_addr: Option<String>,
iperf_server_addr: Option<String>,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(short, long, default_value = "config.yaml")]
config: PathBuf,
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
if !args.config.is_file() {
eprintln!("config file not found or not accessible");
std::process::exit(1);
}
println!("reading config from {:?}", args.config); info!("enginewhy starting on 0.0.0.0:8080");
info!("backends: {:?}", backends);
let state = Arc::new(Mutex::new(ProgramState { let listener = TcpListener::bind("0.0.0.0:8080").await?;
tx_rt_map: HashMap::new(),
healths: HashMap::new(),
health_listener: None,
iperf_server: None,
health_listener_addr: None,
iperf_server_addr: None,
}));
if let Err(e) = load_config(&args.config, state.clone()).await {
eprintln!("config file loading failed: {}", e);
}
let config_path = args.config.clone();
let state_clone = state.clone();
tokio::spawn(async move {
let (tx, mut rx) = mpsc::channel(1);
let mut watcher = notify::recommended_watcher(move |res: Result<Event, notify::Error>| {
if let Ok(event) = res {
if event.kind.is_modify() {
let _ = tx.blocking_send(());
}
}
})
.unwrap();
watcher
.watch(&config_path, RecursiveMode::NonRecursive)
.unwrap();
println!("watching for changes to {:?}", config_path);
while rx.recv().await.is_some() {
// for some reason, saving on certain text editors fires several events,
// and this causes us to reload a lot. try to flush some events, add a tiny delay
// to mitigate this
while rx.try_recv().is_ok() {}
tokio::time::sleep(Duration::from_millis(50)).await;
while rx.try_recv().is_ok() {}
if let Err(e) = load_config(&config_path, state_clone.clone()).await {
eprintln!("loading config failed: {}", e);
}
}
});
loop {
tokio::time::sleep(Duration::from_hours(1)).await;
}
}
async fn load_config(path: &PathBuf, state: Arc<Mutex<ProgramState>>) -> Result<(), Error> {
let f = File::open(path)?;
let app_config: config::AppConfig = match serde_saphyr::from_reader(f) {
Ok(app_config) => app_config,
Err(e) => {
eprintln!("error parsing config {}", e);
return Ok(());
}
};
println!(
"Loaded config, with {} backends, {} rules.",
app_config.backends.len(),
app_config.rules.len()
);
let (mut listeners, health_monitors) = match build_lb(&app_config) {
Ok(v) => v,
Err(e) => {
eprintln!("config has logical errors: {}", e);
return Ok(());
}
};
let mut prog_state = state.lock().unwrap();
let ports_to_remove: Vec<u16> = prog_state
.tx_rt_map
.keys()
.cloned()
.filter(|port| !listeners.contains_key(port))
.collect();
for port in ports_to_remove {
prog_state.tx_rt_map.remove(&port);
}
if let Some(handle) = prog_state.health_listener.take() {
handle.abort();
}
let health_map: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = health_monitors.clone();
let health_addr = app_config.healthcheck_addr.clone();
let health_addr_c = health_addr.clone();
let health_handle = tokio::spawn(async move {
if let Err(e) = start_healthcheck_listener(&health_addr, health_map).await {
eprintln!("health check listener failed: {}", e);
}
});
prog_state.health_listener = Some(health_handle);
prog_state.health_listener_addr = Some(health_addr_c);
// maybe restart iperf server
let iperf_addr = app_config.iperf_addr.clone();
if prog_state.iperf_server_addr.as_ref() != Some(&iperf_addr) {
if let Some(handle) = prog_state.iperf_server.take() {
handle.abort();
}
let iperf_addr_c = iperf_addr.clone();
let iperf_handle = tokio::spawn(async move {
if let Err(e) = start_iperf_server(iperf_addr.as_str()).await {
eprintln!("iperf server failed: {}", e);
}
});
prog_state.iperf_server = Some(iperf_handle);
prog_state.iperf_server_addr = Some(iperf_addr_c);
}
prog_state.healths = health_monitors;
for (port, routing_table) in listeners.drain() {
if let Some(x) = prog_state.tx_rt_map.get_mut(&port) {
x.send(routing_table)?;
println!("updated rules on port {}", port);
} else {
let (tx_rt, rx_rt) = mpsc::unbounded_channel();
prog_state.tx_rt_map.insert(port, tx_rt);
tokio::spawn(run_listener(port, rx_rt, routing_table));
}
}
println!("reload complete");
Ok(())
}
async fn run_listener(
port: u16,
mut rx_rt: mpsc::UnboundedReceiver<RoutingTable>,
mut current_table: RoutingTable,
) {
let addr = format!("0.0.0.0:{}", port);
println!("Starting tcp listener on {}", addr);
let listener = TcpListener::bind(&addr).await.expect("Failed to bind port");
loop {
tokio::select! {
msg = rx_rt.recv() => {
match msg {
Some(new_table) => {
current_table = new_table;
}
None => {
println!("Unbinding listener on port {}", port);
break;
}
}
}
accept_result = listener.accept() => {
match accept_result {
Ok((socket, remote_addr)) => {
let remote_ip = remote_addr.ip();
let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
let mut chosen_backend = None;
for (cidr, balancer_idx) in &mut current_table.entries {
if cidr.contains(&remote_ip) {
let balancer = &mut current_table.balancers[*balancer_idx];
chosen_backend = balancer.choose_backend(ConnectionInfo {
client_ip: remote_ip,
});
break;
}
}
if let Some(backend) = chosen_backend {
tokio::spawn(async move {
if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await {
eprintln!("error: conn_id={} proxy failed: {}", conn_id, e);
}
});
} else {
println!("error: no matching rule for {} on port {}", remote_ip, port);
}
}
Err(e) => {
eprintln!("error: listener port {}: {}", port, e);
continue;
}
}
}
}
}
}

new version (55 lines):
macro_rules! info {
($($arg:tt)*) => {{
print!("info: ");
println!($($arg)*);
}};
}
macro_rules! error {
($($arg:tt)*) => {
eprint!("error: ");
eprintln!($($arg)*);
};
}
mod netutils;
use anywho::Error;
use netutils::{Backend, tunnel};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
#[tokio::main]
async fn main() -> Result<(), Error> {
let backends = Arc::new(vec![
Backend::new("127.0.0.1:8081".to_string()),
Backend::new("127.0.0.1:8082".to_string()),
]);
let current_index = Arc::new(Mutex::new(0));
info!("enginewhy starting on 0.0.0.0:8080");
info!("backends: {:?}", backends);
let listener = TcpListener::bind("0.0.0.0:8080").await?;
loop {
let (client, addr) = listener.accept().await?;
info!("new connection from {}", addr);
let backend = {
let mut index = current_index.lock().await;
let selected_backend = backends[*index].clone();
*index = (*index + 1) % backends.len();
selected_backend
};
info!("routing client {} to backend {}", addr, backend);
if let Err(e) = tunnel(client, backend).await {
error!("proxy failed for {}: {}", addr, e);
}
}
}

52
src/netutils.rs Normal file

@@ -0,0 +1,52 @@
use std::fmt;
use tokio::io;
use tokio::net::TcpStream;
use std::error::Error;
#[derive(Clone, Debug)]
pub struct Backend {
address: String,
}
impl Backend {
pub fn new(address: String) -> Self {
Backend { address }
}
}
impl fmt::Display for Backend {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{}", self.address)
}
}
pub async fn tunnel(client_stream: TcpStream, backend: Backend) -> Result<(), Box<dyn Error>> {
let backend_address: String = backend.address.clone();
tokio::spawn(async move {
let backend_stream: TcpStream = match TcpStream::connect(&backend_address).await {
Ok(s) => {
info!("connected to backend {backend_address}");
s
}
Err(e) => {
error!("failed connecting to backend {backend_address}: {e}");
return;
}
};
let (mut read_client, mut write_client) = client_stream.into_split();
let (mut read_backend, mut write_backend) = backend_stream.into_split();
let client_to_backend =
tokio::spawn(async move { io::copy(&mut read_client, &mut write_backend).await });
let backend_to_client =
tokio::spawn(async move { io::copy(&mut read_backend, &mut write_client).await });
let _ = tokio::join!(client_to_backend, backend_to_client);
});
Ok(())
}


@@ -1,44 +0,0 @@
use crate::backend::Backend;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
pub mod tcp;
pub struct ConnectionContext {
pub id: u64,
pub client_addr: SocketAddr,
pub start_time: Instant,
pub backend: Arc<Backend>,
pub bytes_transferred: u64,
}
impl ConnectionContext {
pub fn new(id: u64, client_addr: SocketAddr, backend: Arc<Backend>) -> Self {
backend.inc_connections();
Self {
id,
client_addr,
start_time: Instant::now(),
backend,
bytes_transferred: 0,
}
}
}
impl Drop for ConnectionContext {
fn drop(&mut self) {
self.backend.dec_connections();
let duration = self.start_time.elapsed();
println!(
"info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
self.id,
self.client_addr,
self.backend.address,
self.bytes_transferred,
duration.as_secs_f64()
);
}
}


@@ -1,30 +0,0 @@
use crate::backend::Backend;
use crate::proxy::ConnectionContext;
use anywho::Error;
use std::sync::Arc;
use tokio::io;
use tokio::net::TcpStream;
pub async fn proxy_tcp_connection(
connection_id: u64,
mut client_stream: TcpStream,
backend: Arc<Backend>,
) -> Result<(), Error> {
let client_addr = client_stream.peer_addr()?;
let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
#[cfg(debug_assertions)]
println!(
"info: conn_id={} connecting to {}",
connection_id, ctx.backend.id
);
let mut backend_stream = TcpStream::connect(&backend.address).await?;
let (tx, rx) = io::copy_bidirectional(&mut client_stream, &mut backend_stream).await?;
ctx.bytes_transferred = tx + rx;
Ok(())
}