12 Commits

Author SHA1 Message Date
nnhphong
6d2b8115f3 add some more tests for lb 2025-12-09 22:01:16 -05:00
nnhphong
08cb522f93 the algorithm is working, but will need more test 2025-12-07 23:04:29 -05:00
nnhphong
742827b16f prune some comment 2025-12-07 21:59:43 -05:00
nnhphong
e19efee895 part of the algorithm, waiting for paul s and jeremy to complete refactoring 2025-12-07 21:56:27 -05:00
nnhphong
393c35bdf8 code for docker infra image 2025-12-07 14:09:38 -05:00
Ning Qi (Paul) Sun
cd23bfdf5a Merge pull request #1 from psun256/merge
Merge & Refactor
2025-12-06 16:09:51 -05:00
4cdf2db0c9 feat: improved logging 2025-12-06 02:16:40 -05:00
606880f928 feat: merged repos 2025-12-06 01:31:33 -05:00
19cd5b7f2a feat: modularized proxy 2025-12-06 00:21:53 -05:00
Ning Qi (Paul) Sun
25c3eb9511 gh action
gh action
2025-12-03 22:07:40 -05:00
psun256
e27bd2aaf0 layer 4 load balancing (round robin, hardcoded backends) 2025-11-29 21:46:26 -05:00
Phong Nguyen
1235d3611d Update README with load balancing details
Added a note about load balancing algorithms from a referenced paper.
2025-12-03 12:47:46 -05:00
17 changed files with 650 additions and 235 deletions

22
.github/workflows/rust.yml vendored Normal file
View File

@@ -0,0 +1,22 @@
name: Rust
on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
env:
CARGO_TERM_COLOR: always
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose

71
Cargo.lock generated
View File

@@ -26,11 +26,23 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "getrandom"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592"
dependencies = [
"cfg-if",
"libc",
"wasi",
]
[[package]] [[package]]
name = "l4lb" name = "l4lb"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anywho", "anywho",
"rand",
"tokio", "tokio",
] ]
@@ -89,6 +101,15 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "ppv-lite86"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9"
dependencies = [
"zerocopy",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.103" version = "1.0.103"
@@ -107,6 +128,36 @@ dependencies = [
"proc-macro2", "proc-macro2",
] ]
[[package]]
name = "rand"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
]
[[package]]
name = "rand_chacha"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
]
[[package]]
name = "rand_core"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
"getrandom",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.5.18" version = "0.5.18"
@@ -286,3 +337,23 @@ name = "windows_x86_64_msvc"
version = "0.53.1" version = "0.53.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650"
[[package]]
name = "zerocopy"
version = "0.8.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3"
dependencies = [
"zerocopy-derive",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -6,3 +6,4 @@ edition = "2024"
[dependencies] [dependencies]
anywho = "0.1.2" anywho = "0.1.2"
tokio = { version = "1.48.0", features = ["full"] } tokio = { version = "1.48.0", features = ["full"] }
rand = { version = "0.8", features = ["small_rng"] }

View File

@@ -5,25 +5,10 @@ Production't graden't load balancer.
## Todo ## Todo
- [ ] architecture astronauting - [ ] architecture astronauting
- balancer module
- just the algorithms i guess
-
- backend module
- manages the backend pool
- deals with health / load check
- BackendPool for all the backends stored together
- Backend for individual backends
- has some methods used by balancer module to pick a suitable backend
- proxy module
- all the different supported protocols to handle
- will create a session / stream context structure (ConnectionContext)
- not globally tracked (this might change for UDP!)
- mainly some metadata
- config module
- set up all the stuff or something
- [ ] stream / session handling (i think wrapper around tokio TcpStream) - [ ] stream / session handling (i think wrapper around tokio TcpStream)
- [ ] basic backend pooling - [ ] basic backend pooling
- [ ] layer 4 load balancing - [ ] layer 4 load balancing
- [ ] load balancing algorithm from the paper (https://www.wcse.org/WCSE_2018/W110.pdf)
## notes ## notes
tcp, for nginx (and haproxy, its similar): tcp, for nginx (and haproxy, its similar):
@@ -36,7 +21,7 @@ struct ngx_connection_s {
ngx_socket_t fd; ngx_socket_t fd;
ngx_recv_pt recv; // fn pointer to whatever recv fn used (different for idfferent platforms / protocol ngx_recv_pt recv; // fn pointer to whatever recv fn used (different for dfferent platforms / protocol
ngx_send_pt send; // ditto ngx_send_pt send; // ditto
ngx_recv_chain_pt recv_chain; ngx_recv_chain_pt recv_chain;
ngx_send_chain_pt send_chain; ngx_send_chain_pt send_chain;
@@ -120,9 +105,11 @@ process to load balance:
- proxy the data (copy_bidirectional? maybe we want some metrics or logging, so might do manually) - proxy the data (copy_bidirectional? maybe we want some metrics or logging, so might do manually)
- cleanup when someone leaves or something goes wrong (with TCP, OS / tokio will tell us, with UDP probably just timeout based, and a periodic sweep of all sessions) - cleanup when someone leaves or something goes wrong (with TCP, OS / tokio will tell us, with UDP probably just timeout based, and a periodic sweep of all sessions)
## Load balancer algorithm
### UDP - Choose a fixed weight coefficient for the resource parameter
UDP is connectionless, and i don't think UdpSocket or UdpFramed implement the traits required for tokio copy_bidirectional - Spawn a thread on a load balancer to host the iperf server, used for new onboarding server connecting to the load balancer to measure their maximum bandwidth
but async write and read don't work on just regular datagrams, so probably not possible. - Spawn another thread for listening to resource update from connected server
- Update the comprehensive load sum from eq (1), update the formula in eq (2) to (5)
Would require us to implement our own bidirectional copying / proxying, as well as tracking "active" connections. - Choose alpha for eq (8), and run the algorithm to choose which server
- Extract the server from the server id using ```get_backend()```
- Use ```tunnel()``` to proxy the packet

BIN
W110.pdf Normal file
View File

Binary file not shown.

68
infra/enginewhy-lb.rs Normal file
View File

@@ -0,0 +1,68 @@
use rperf3::{Server, Config};
use std::net::{TcpListener, TcpStream};
use std::thread;
use std::io::{Read, Write};
use std::env;
use tokio::task;
/// Runs the rperf3 bandwidth-measurement server on port 5001.
///
/// Resolves once the server's `run` future completes; an error from that
/// future is boxed and returned to the caller.
async fn start_iperf_server() -> Result<(), Box<dyn std::error::Error>> {
    let server = Server::new(Config::server(5001));
    server.run().await?;
    Ok(())
}
/// Reads from `stream` until the peer closes the connection, printing every
/// received chunk (lossily decoded as UTF-8) to stdout.
///
/// # Errors
///
/// Returns the underlying I/O error if a read fails. A clean end-of-stream
/// yields `Ok(())`.
fn handle_connection(mut stream: TcpStream) -> std::io::Result<()> {
    // One buffer is reused for every read; only the filled prefix is decoded.
    let mut buffer = [0u8; 512];
    loop {
        let bytes_read = stream.read(&mut buffer)?;
        // A zero-length read means the peer closed its side. Without this
        // check the loop would spin forever printing empty messages.
        if bytes_read == 0 {
            return Ok(());
        }
        let received = String::from_utf8_lossy(&buffer[..bytes_read]);
        println!("Received: {}", received);
    }
}
/// Binds a TCP listener on `addr` and serves each accepted connection on its
/// own thread via `handle_connection`.
///
/// # Errors
///
/// Returns an error if the listener cannot be bound. Failures to accept an
/// individual connection, and errors inside a handler, are logged to stderr
/// and do not stop the server.
fn start_tcp_server(addr: &str) -> std::io::Result<()> {
    let listener = TcpListener::bind(addr)?;
    println!("TCP server listening on {}", addr);

    // `incoming()` never yields `None`, so this loop runs for the lifetime of
    // the process. Handler threads are therefore detached rather than stored:
    // collecting their join handles would grow without bound and the join
    // could never be reached.
    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                thread::spawn(move || {
                    if let Err(e) = handle_connection(stream) {
                        eprintln!("connection handler error: {}", e);
                    }
                });
            }
            Err(e) => eprintln!("incoming connection failed: {}", e),
        }
    }

    Ok(())
}
/// Entry point of the load-balancer side of the test infrastructure.
///
/// Starts the iperf server as a tokio task and the metrics TCP server on a
/// dedicated OS thread, then waits for both.
#[tokio::main]
async fn main() {
    // Choose IP based on `--localhost` flag for debugging
    let use_localhost = env::args().any(|a| a == "--localhost");
    let ip = if use_localhost { "127.0.0.1" } else { "192.67.67.67" };
    let tcp_addr = format!("{}:8080", ip);

    let iperf_server = task::spawn(async {
        // Report the failure instead of silently discarding the `Result`.
        if let Err(e) = start_iperf_server().await {
            eprintln!("iperf server error: {}", e);
        }
    });

    // The blocking accept loop gets its own thread so it does not occupy a
    // runtime worker; the address is moved in, no clone is needed.
    let tcp_server = thread::spawn(move || {
        start_tcp_server(&tcp_addr).unwrap();
    });

    iperf_server.await.unwrap();
    tcp_server.join().unwrap();
}

168
infra/enginewhy-server.rs Normal file
View File

@@ -0,0 +1,168 @@
use sysinfo::{CpuRefreshKind, RefreshKind, System};
use sysinfo::{Networks};
use sysinfo::{Disks};
use std::thread;
use std::time::Duration;
use std::net::{TcpStream};
use std::env;
use std::collections::HashMap;
use std::io::Write;
use serde_json::Value;
use rperf3::{Client, Config, Protocol};
// Default server addresses
const DEFAULT_REMOTE_IP: &str = "192.67.67.67";
const DEFAULT_LOCAL_IP: &str = "127.0.0.1";
const PORT: u16 = 8080;
const IPERF_PORT: u16 = 5001;
/// Estimates how busy the first listed disk is, as a percentage in `0..=100`.
///
/// Samples the disk's read/written byte counters, sleeps for one second,
/// samples again, and compares the combined throughput with a rough maximum
/// speed chosen from the disk kind.
///
/// # Errors
///
/// Returns a message if no disk is listed, or if the list is empty after the
/// second refresh.
fn get_io_usage_percentage() -> Result<f64, String> {
    let mut disks = Disks::new_with_refreshed_list();
    disks.refresh(true);

    // Only the first disk is sampled (usually the main disk).
    let Some(first) = disks.list().first() else {
        return Err("No disks found".to_string());
    };
    let read_before = first.usage().total_read_bytes;
    let written_before = first.usage().total_written_bytes;

    // One-second sampling window, so the byte deltas below are per second.
    thread::sleep(Duration::from_secs(1));
    disks.refresh(true);

    let disk = disks.list().first().ok_or("Disk disappeared")?;
    let read_per_sec = (disk.usage().total_read_bytes - read_before) as f64;
    let write_per_sec = (disk.usage().total_written_bytes - written_before) as f64;

    // Rough throughput ceilings in bytes per second, by disk kind.
    let max_speed = match disk.kind() {
        sysinfo::DiskKind::SSD => 500_000_000.0,
        sysinfo::DiskKind::HDD => 200_000_000.0,
        _ => 300_000_000.0,
    };

    let busy_fraction = (read_per_sec + write_per_sec) / max_speed;
    Ok((busy_fraction * 100.0).min(100.0))
}
/// Runs a 10-second rperf3 client test against `server_ip:port` and returns
/// the measured throughput in bits per second.
///
/// # Errors
///
/// Propagates any error from creating or running the client.
async fn measure_iperf_bandwidth(server_ip: &str, port: u16) -> Result<f64, Box<dyn std::error::Error>> {
    let test_length = Duration::from_secs(10);
    let config = Config::client(server_ip.to_string(), port).with_duration(test_length);

    let client = Client::new(config)?;
    client.run().await?;

    let bandwidth_bps = client.get_measurements().total_bits_per_second();
    println!("iperf3 reported max bandwidth: {:.2} Mbps", bandwidth_bps / 1_000_000.0);

    Ok(bandwidth_bps)
}
/// Entry point of the backend-side metrics reporter.
///
/// Connects to the load balancer, probes the maximum bandwidth once with
/// iperf, then every cycle samples CPU, memory, network and disk usage and
/// sends them to the load balancer as one JSON object.
///
/// Pass `--localhost` to talk to a load balancer on 127.0.0.1.
#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Interface whose transmit counter is reported as network usage.
    const MONITORED_INTERFACE: &str = "wlp2s0";

    // Determine server IP from CLI: `--localhost` -> local, otherwise remote
    let server_ip = if env::args().any(|a| a == "--localhost") {
        DEFAULT_LOCAL_IP.to_string()
    } else {
        DEFAULT_REMOTE_IP.to_string()
    };

    let mut stream = TcpStream::connect(format!("{}:{}", server_ip, PORT))?;
    println!("server connected to {}:{}", server_ip, PORT);

    // Initialize the system struct
    let mut sys = System::new_with_specifics(
        RefreshKind::nothing().with_cpu(CpuRefreshKind::everything()),
    );

    let mut networks = Networks::new();
    networks.refresh(true);

    // Probe max bandwidth using iperf3; 0.0 means "unknown".
    let max_bps: f64 = match measure_iperf_bandwidth(&server_ip, IPERF_PORT).await {
        Ok(bps) => {
            println!("iperf3 reported max bandwidth: {:.2} bits/sec ({:.2} Mbps)", bps, bps / 1e6);
            bps
        }
        Err(e) => {
            println!("iperf3 failed: {}", e);
            0.0
        }
    };

    // Restart the network sampling window so the probe's own traffic is not
    // reported as load in the first cycle.
    networks.refresh(true);
    let mut last_net_refresh = std::time::Instant::now();

    // Wait a bit because CPU usage is based on diff.
    std::thread::sleep(sysinfo::MINIMUM_CPU_UPDATE_INTERVAL);

    loop {
        sys.refresh_all();
        sys.refresh_cpu_usage(); // Refreshing CPU usage.

        // Average over all cores; guard against an empty list so the value
        // sent is never NaN.
        let cpus = sys.cpus();
        let cpu_usage: f64 = if cpus.is_empty() {
            0.0
        } else {
            cpus.iter().map(|cpu| cpu.cpu_usage() as f64).sum::<f64>() / cpus.len() as f64
        };
        println!("CPU usage is {}%", cpu_usage);

        // Memory usage as a percentage: used / total * 100.
        let total_mem = sys.total_memory();
        let used_mem = sys.used_memory();
        let mem_usage = if total_mem == 0 {
            0.0
        } else {
            used_mem as f64 / total_mem as f64 * 100.0
        };
        println!("Memory usage is {}%", mem_usage);

        // Network usage. `transmitted()` is taken to be the number of bytes
        // sent between the two most recent refreshes (per the sysinfo docs),
        // so it is converted to bits per second before being compared with
        // the iperf maximum, which is in bits per second.
        networks.refresh(true);
        let elapsed_secs = last_net_refresh.elapsed().as_secs_f64();
        last_net_refresh = std::time::Instant::now();

        let mut sent_bytes: Option<u64> = None;
        for (interface_name, network) in &networks {
            if interface_name == MONITORED_INTERFACE {
                sent_bytes = Some(network.transmitted());
            }
        }

        // Percent of the measured max bandwidth; 0 when the maximum is
        // unknown or the interface was not found.
        let net_usage_pct: f64 = match sent_bytes {
            Some(bytes) if max_bps > 0.0 && elapsed_secs > 0.0 => {
                let sent_bps = bytes as f64 * 8.0 / elapsed_secs;
                f64::min(100.0, sent_bps / max_bps * 100.0)
            }
            _ => 0.0,
        };
        if sent_bytes.is_some() {
            println!("[{MONITORED_INTERFACE}] transferred {:?} %", net_usage_pct);
        }

        // IO usage (this call blocks for about one second while sampling).
        let io_usage = match get_io_usage_percentage() {
            Ok(percentage) => {
                println!("I/O usage is {}%", percentage);
                percentage
            }
            Err(e) => {
                println!("Error: {}", e);
                0.0
            }
        };
        println!();

        // Identify this process (client) by the local socket address used to connect
        let server_identifier = match stream.local_addr() {
            Ok(addr) => addr.to_string(),
            Err(_) => format!("localhost:{}", PORT),
        };

        let mut packet: HashMap<String, Value> = HashMap::new();
        packet.insert("server_ip".to_string(), Value::String(server_identifier));
        packet.insert("cpu".to_string(), Value::from(cpu_usage)); // %
        packet.insert("mem".to_string(), Value::from(mem_usage)); // %
        packet.insert("net".to_string(), Value::from(net_usage_pct)); // %
        packet.insert("io".to_string(), Value::from(io_usage)); // %

        let serialized_packet = serde_json::to_string(&packet)?;
        // `write_all` retries partial writes; a failure is reported instead of
        // being silently dropped, but reporting stays best-effort.
        if let Err(e) = stream.write_all(serialized_packet.as_bytes()) {
            eprintln!("failed to send metrics: {}", e);
        }

        thread::sleep(Duration::from_secs(10));
    }
}

View File

@@ -1,58 +0,0 @@
use core::fmt;
use std::net::SocketAddr;
use std::sync::RwLock;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// A single upstream server together with a live count of the connections
/// currently proxied to it.
#[derive(Debug)]
pub struct Backend {
    /// Human-readable name used in log output.
    pub id: String,
    /// Socket address connections are forwarded to.
    pub address: SocketAddr,
    /// Number of connections currently open to this backend.
    pub active_connections: AtomicUsize,
}

impl Backend {
    /// Creates a backend with no open connections.
    pub fn new(id: String, address: SocketAddr) -> Self {
        Self {
            // `id` is already an owned `String`; no extra copy is needed.
            id,
            address,
            active_connections: AtomicUsize::new(0),
        }
    }

    // Ordering::Relaxed is enough here: the counter is only a metric and no
    // other memory access is synchronised through it, so SeqCst would be
    // overkill.

    /// Records one more open connection and logs the new count.
    pub fn inc_connections(&self) {
        // Derive the logged value from the result of this very operation; a
        // separate `load` could observe another thread's update instead.
        let open = self
            .active_connections
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        println!("{} has {} connections open", self.id, open);
    }

    /// Records one fewer open connection and logs the new count.
    pub fn dec_connections(&self) {
        // `fetch_sub` wraps on underflow; mirror that so the log shows what is
        // actually stored.
        let open = self
            .active_connections
            .fetch_sub(1, Ordering::Relaxed)
            .wrapping_sub(1);
        println!("{} has {} connections open", self.id, open);
    }
}

impl fmt::Display for Backend {
    /// Formats the backend as `address (id)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} ({})", self.address, self.id)
    }
}
/// A shared, growable list of backends.
///
/// Cloning the pool clones the `Arc`, so every clone refers to the same
/// underlying list.
#[derive(Clone, Debug, Default)]
pub struct BackendPool {
    pub backends: Arc<RwLock<Vec<Arc<Backend>>>>,
}

impl BackendPool {
    /// Creates an empty pool (same as `BackendPool::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `backend` to the pool.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, i.e. another thread panicked while
    /// holding it.
    pub fn add(&self, backend: Backend) {
        self.backends.write().unwrap().push(Arc::new(backend));
    }
}

View File

@@ -0,0 +1,211 @@
use crate::netutils::Backend;
use rand::prelude::*;
use rand::rngs::SmallRng;
use std::sync::Arc;
/// Latest resource usage reported by one backend.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ServerMetrics {
    // metrics are percents (0..100)
    pub cpu: f64,
    pub mem: f64,
    pub net: f64,
    pub io: f64,
}

impl ServerMetrics {
    /// Creates metrics with every value at zero (same as `Default`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Overwrites all four metrics with freshly reported values.
    pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
        *self = ServerMetrics { cpu, mem, net, io };
    }
}
/// Per-backend bookkeeping kept by the balancer: the backend itself, its most
/// recently reported metrics, and its current weight.
#[derive(Debug, Clone)]
pub struct ServerState {
    pub backend: Backend,
    pub metrics: ServerMetrics,
    pub weight: f64,
}

impl ServerState {
    /// Wraps `backend` with freshly constructed metrics and a weight of 1.0.
    pub fn new(backend: Backend) -> Self {
        let metrics = ServerMetrics::new();
        Self {
            backend,
            metrics,
            weight: 1.0,
        }
    }
}
/// Load balancer that selects a backend from the resource metrics each
/// backend reports, adjusting per-backend weights as it goes.
pub struct AdaptiveBalancer {
    servers: Vec<ServerState>,
    // resource coefficients (cpu, mem, net, io) - sum to 1.0
    coeffs: [f64; 4],
    // scales the average load-per-weight into the selection threshold
    alpha: f64,
    // only used for the random fallback in `choose_backend`
    rng: SmallRng,
}

impl AdaptiveBalancer {
    /// Creates a balancer over `backends`; every backend starts with zeroed
    /// metrics and a weight of 1.0.
    pub fn new(backends: Vec<Backend>, coeffs: [f64; 4], alpha: f64) -> Self {
        let servers = backends.into_iter().map(ServerState::new).collect();
        let rng = SmallRng::from_entropy();
        AdaptiveBalancer { servers, coeffs, alpha, rng }
    }

    /// Registers an additional backend with zeroed metrics and weight 1.0.
    pub fn add_backend(&mut self, backend: Backend) {
        self.servers.push(ServerState::new(backend));
    }

    /// Update metrics reported by a backend identified by its display/address.
    /// If the backend isn't found this is a no-op.
    pub fn update_metrics(&mut self, backend_addr: &str, cpu: f64, mem: f64, net: f64, io: f64) {
        let found = self
            .servers
            .iter_mut()
            .find(|s| s.backend.to_string() == backend_addr);
        if let Some(s) = found {
            s.metrics.update(cpu, mem, net, io);
        }
    }

    /// Composite load of a server: the coefficient-weighted sum of its four
    /// metrics. Despite the name, the result is a load figure, not a weight.
    fn metrics_to_weight(metrics: &ServerMetrics, coeffs: &[f64; 4]) -> f64 {
        coeffs[0] * metrics.cpu + coeffs[1] * metrics.mem + coeffs[2] * metrics.net + coeffs[3] * metrics.io
    }

    /// Chooses the backend that should receive the next connection.
    ///
    /// 1. The first server whose composite load per unit of weight is at most
    ///    `alpha * (sum of loads / sum of weights)` is returned immediately.
    /// 2. Otherwise the weights are adjusted from the servers' `current_load`
    ///    values and the server with the smallest `current_load` is returned.
    ///
    /// Returns `None` only when no backends are registered. The returned
    /// value is a fresh `Arc` holding a clone of the chosen backend.
    pub fn choose_backend(&mut self) -> Option<Arc<Backend>> {
        if self.servers.is_empty() {
            return None;
        }

        // Composite load of every server (R_i).
        let rs: Vec<f64> = self
            .servers
            .iter()
            .map(|s| Self::metrics_to_weight(&s.metrics, &self.coeffs))
            .collect();
        let r_sum: f64 = rs.iter().sum();
        let w_sum: f64 = self.servers.iter().map(|s| s.weight).sum::<f64>().max(1e-12);

        let threshold = self.alpha * (r_sum / w_sum);

        // A server with R_i / W_i at or below the threshold is lightly loaded
        // relative to the pool average: pick the first such server.
        for (s, &load) in self.servers.iter().zip(&rs) {
            let ratio = if s.weight <= 0.0 { f64::INFINITY } else { load / s.weight };
            if ratio <= threshold {
                return Some(Arc::new(s.backend.clone()));
            }
        }

        // No server was below the threshold: adjust the weights from the
        // connection loads. Summing as f64 avoids overflowing a u32 total.
        let l_sum: f64 = self
            .servers
            .iter()
            .map(|s| f64::from(s.backend.current_load))
            .sum();
        // Only the ratio lwi[i] / a_lwi is used below, so the constant factor
        // involving `l_sum` cancels out and just the relative values matter.
        let lwi: Vec<f64> = self
            .servers
            .iter()
            .map(|s| f64::from(s.backend.current_load) * w_sum / s.weight * l_sum)
            .collect();
        let a_lwi: f64 = lwi.iter().sum::<f64>() / lwi.len() as f64;

        // When every server has zero load (the initial state) or a weight is
        // zero, the average is 0, infinite or NaN. Applying the update then
        // would turn weights into NaN/inf and permanently break the threshold
        // test above, so it is skipped.
        // NOTE(review): weights can still become negative here; confirm
        // against the reference algorithm whether they should be clamped.
        if a_lwi.is_finite() && a_lwi > 0.0 {
            for (s, relative_load) in self.servers.iter_mut().zip(&lwi) {
                s.weight += 1.0 - relative_load / a_lwi;
            }
        }

        // Pick the first server with the smallest current load.
        let least_loaded = self
            .servers
            .iter()
            .enumerate()
            .filter(|(_, s)| s.backend.current_load < u32::MAX)
            .min_by_key(|(_, s)| s.backend.current_load)
            .map(|(i, _)| i);

        // If nothing chosen (every load is saturated), fall back to random selection
        let idx = match least_loaded {
            Some(i) => i,
            None => (self.rng.next_u32() as usize) % self.servers.len(),
        };
        Some(Arc::new(self.servers[idx].backend.clone()))
    }

    // Expose a snapshot of server weights (for monitoring/testing)
    pub fn snapshot_weights(&self) -> Vec<(String, f64)> {
        self.servers
            .iter()
            .map(|s| (s.backend.to_string(), s.weight))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a balancer over one fresh backend per address.
    fn balancer_for(addrs: &[&str], coeffs: [f64; 4], alpha: f64) -> AdaptiveBalancer {
        let backends = addrs.iter().map(|a| Backend::new(a.to_string())).collect();
        AdaptiveBalancer::new(backends, coeffs, alpha)
    }

    #[test]
    fn basic_weight_update_and_choose() {
        let mut b = balancer_for(&["127.0.0.1:1", "127.0.0.1:2"], [0.5, 0.2, 0.2, 0.1], 0.5);

        // Two servers are tracked from the start.
        let snaps = b.snapshot_weights();
        assert_eq!(snaps.len(), 2);

        // Server 1 reports heavy usage, server 2 is nearly idle.
        b.update_metrics("127.0.0.1:1", 90.0, 80.0, 10.0, 5.0);
        b.update_metrics("127.0.0.1:2", 10.0, 5.0, 1.0, 1.0);

        // The less loaded host (127.0.0.1:2) must win.
        let chosen = b.choose_backend().expect("should choose a backend");
        let snaps2 = b.snapshot_weights();
        println!("{:?}, {:?}", snaps, snaps2);
        assert_eq!(chosen.to_string(), "127.0.0.1:2");
    }

    #[test]
    fn choose_none_when_empty() {
        let mut b = balancer_for(&[], [0.5, 0.2, 0.2, 0.1], 0.5);
        assert!(b.choose_backend().is_none());
    }

    #[test]
    fn ratio_triggers_immediate_selection() {
        // Server 1 has composite load 0, server 2 has composite load 100.
        // With alpha = 1.0 the threshold is 1.0 * (100 / 2) = 50, and server
        // 1's ratio of 0 / 1 = 0 is below it, so it is returned immediately.
        let mut b = balancer_for(&["127.0.0.1:1", "127.0.0.1:2"], [0.25, 0.25, 0.25, 0.25], 1.0);
        b.update_metrics("127.0.0.1:1", 0.0, 0.0, 0.0, 0.0);
        b.update_metrics("127.0.0.1:2", 100.0, 100.0, 100.0, 100.0);

        let chosen = b.choose_backend().expect("should choose a backend");
        assert_eq!(chosen.to_string(), "127.0.0.1:1");
    }

    #[test]
    fn choose_min_current_load_when_no_ratio() {
        // Three servers with identical composite loads and alpha < 1: every
        // ratio (50) is above the threshold (25), so the balancer falls back
        // to the server with the smallest current_load.
        let backends: Vec<Backend> = [("127.0.0.1:1", 10), ("127.0.0.1:2", 5), ("127.0.0.1:3", 20)]
            .into_iter()
            .map(|(addr, load)| {
                let mut backend = Backend::new(addr.to_string());
                backend.current_load = load;
                backend
            })
            .collect();

        // Only the CPU coefficient is non-zero, so composite load == cpu.
        let mut bal = AdaptiveBalancer::new(backends, [1.0, 0.0, 0.0, 0.0], 0.5);
        for addr in ["127.0.0.1:1", "127.0.0.1:2", "127.0.0.1:3"] {
            bal.update_metrics(addr, 50.0, 0.0, 0.0, 0.0);
        }

        let chosen = bal.choose_backend().expect("should choose a backend");
        // 127.0.0.1:2 carries the smallest current_load.
        assert_eq!(chosen.to_string(), "127.0.0.1:2");
    }
}

View File

@@ -1,9 +1,2 @@
pub mod round_robin; pub mod adaptive_weight;
pub use adaptive_weight::AdaptiveBalancer;
use std::fmt::Debug;
use std::sync::Arc;
use crate::backend::Backend;
/// Strategy for selecting the backend that receives the next connection.
///
/// Implementors must be `Debug`, `Send`, `Sync` and `'static`.
pub trait Balancer: Debug + Send + Sync + 'static {
/// Returns the backend chosen for the next connection, or `None` if the
/// implementation has no backend to offer. Takes `&mut self`, so an
/// implementation may update internal state on every call.
fn choose_backend(&mut self) -> Option<Arc<Backend>>;
}

0
src/balancer/random.rs Normal file
View File

View File

@@ -1,33 +0,0 @@
use std::sync::{Arc, RwLock};
use std::fmt::Debug;
use crate::backend::{Backend, BackendPool};
use crate::balancer::Balancer;
// only the main thread for receiving connections should be
// doing the load balancing. alternatively, each thread
// that handles load balancing should get their own instance.
/// Balancer that hands out the pool's backends in order, one per call,
/// starting over after the last one.
#[derive(Debug)]
pub struct RoundRobinBalancer {
    pool: BackendPool,
    // Monotonic call counter; reduced modulo the pool size on every pick.
    index: usize,
}

impl RoundRobinBalancer {
    /// Creates a balancer over `pool`, starting at the first backend.
    pub fn new(pool: BackendPool) -> RoundRobinBalancer {
        RoundRobinBalancer { pool, index: 0 }
    }
}

impl Balancer for RoundRobinBalancer {
    /// Returns the next backend in rotation, or `None` when the pool is empty.
    ///
    /// Panics if the pool's lock is poisoned.
    fn choose_backend(&mut self) -> Option<Arc<Backend>> {
        let guard = self.pool.backends.read().unwrap();
        // `checked_rem` is `None` exactly when the pool is empty.
        let slot = self.index.checked_rem(guard.len())?;
        let chosen = Arc::clone(&guard[slot]);
        self.index = self.index.wrapping_add(1);
        Some(chosen)
    }
}

View File

@@ -1,6 +0,0 @@
// TODO: "routing" rules
// backends defined as ip + port
// define sets of backends
// allowed set operations for now is just union
// rules are ip + mask and ports, maps to some of the sets
// defined earlier, along with a routing strategy

View File

@@ -1,52 +1,56 @@
extern crate core; macro_rules! info {
($($arg:tt)*) => {{
print!("info: ");
println!($($arg)*);
}};
}
/// Prints an `error: `-prefixed line to stderr, taking `format!`-style
/// arguments.
///
/// The expansion is a single block, so the macro can be used in expression
/// position (for example as a match arm), and the whole line is emitted by
/// one `eprintln!` call instead of two separate writes.
macro_rules! error {
    // No arguments: just the prefix, as before.
    () => {{
        eprintln!("error: ");
    }};
    ($($arg:tt)*) => {{
        eprintln!("error: {}", format_args!($($arg)*));
    }};
}
mod netutils;
mod balancer; mod balancer;
mod config;
mod backend; use anywho::Error;
mod proxy; use netutils::{Backend, tunnel};
use std::sync::Arc;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::Mutex;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use crate::backend::{Backend, BackendPool};
use crate::balancer::Balancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::proxy::tcp::proxy_tcp_connection;
static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Error> {
let pool = BackendPool::new(); let backends = Arc::new(vec![
Backend::new("127.0.0.1:8081".to_string()),
Backend::new("127.0.0.1:8082".to_string()),
]);
pool.add(Backend::new( let current_index = Arc::new(Mutex::new(0));
"backend 1".into(),
"127.0.0.1:8081".parse().unwrap(),
));
pool.add(Backend::new( info!("enginewhy starting on 0.0.0.0:8080");
"backend 2".into(), info!("backends: {:?}", backends);
"127.0.0.1:8082".parse().unwrap(),
));
let mut balancer = RoundRobinBalancer::new(pool.clone()); let listener = TcpListener::bind("0.0.0.0:8080").await?;
let listener = TcpListener::bind("127.0.0.1:8080").await?;
loop { loop {
let (socket, _) = listener.accept().await?; let (client, addr) = listener.accept().await?;
info!("new connection from {}", addr);
let conn_id = NEXT_CONN_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst); let backend = {
let mut index = current_index.lock().await;
let selected_backend = backends[*index].clone();
*index = (*index + 1) % backends.len();
selected_backend
};
if let Some(backend) = balancer.choose_backend() { info!("routing client {} to backend {}", addr, backend);
tokio::spawn(async move {
if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await { if let Err(e) = tunnel(client, backend).await {
eprintln!("error: conn_id={} proxy failed: {}", conn_id, e); error!("proxy failed for {}: {}", addr, e);
}
});
} else {
eprintln!("error: no backendsd for conn_id={}", conn_id);
} }
} }
} }

56
src/netutils.rs Normal file
View File

@@ -0,0 +1,56 @@
use std::fmt;
use tokio::io;
use tokio::net::TcpStream;
use std::error::Error;
/// A backend server the load balancer can forward connections to.
#[derive(Clone, Debug)]
pub struct Backend {
    // Address as given to `new`; it is what `Display` prints and what
    // `tunnel` connects to.
    address: String,
    // Load figure read by the balancer when choosing a backend.
    pub current_load: u32,
}

impl Backend {
    /// Creates a backend for `address` with a load of zero.
    pub fn new(address: String) -> Self {
        Self {
            address,
            current_load: 0,
        }
    }
}

impl fmt::Display for Backend {
    /// Writes the backend's address.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.address)
    }
}
pub async fn tunnel(client_stream: TcpStream, backend: Backend) -> Result<(), Box<dyn Error>> {
let backend_address: String = backend.address.clone();
tokio::spawn(async move {
let backend_stream: TcpStream = match TcpStream::connect(&backend_address).await {
Ok(s) => {
info!("connected to backend {backend_address}");
s
}
Err(e) => {
error!("failed connecting to backend {backend_address}: {e}");
return;
}
};
let (mut read_client, mut write_client) = client_stream.into_split();
let (mut read_backend, mut write_backend) = backend_stream.into_split();
let client_to_backend =
tokio::spawn(async move { io::copy(&mut read_client, &mut write_backend).await });
let backend_to_client =
tokio::spawn(async move { io::copy(&mut read_backend, &mut write_client).await });
let _ = tokio::join!(client_to_backend, backend_to_client);
});
Ok(())
}

View File

@@ -1,43 +0,0 @@
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use crate::backend::Backend;
pub mod tcp;
/// Metadata for one proxied connection.
///
/// Creating a context increments the backend's connection counter; dropping
/// it decrements the counter again and prints a summary line.
pub struct ConnectionContext {
    pub id: u64,
    pub client_addr: SocketAddr,
    pub start_time: Instant,
    pub backend: Arc<Backend>,
    pub bytes_transferred: u64,
}

impl ConnectionContext {
    /// Registers a new connection on `backend` and starts its timer.
    pub fn new(id: u64, client_addr: SocketAddr, backend: Arc<Backend>) -> Self {
        backend.inc_connections();
        let start_time = Instant::now();
        ConnectionContext {
            id,
            client_addr,
            start_time,
            backend,
            bytes_transferred: 0,
        }
    }
}

impl Drop for ConnectionContext {
    /// Releases the backend's connection slot and logs the connection summary.
    fn drop(&mut self) {
        self.backend.dec_connections();

        let elapsed_secs = self.start_time.elapsed().as_secs_f64();
        println!(
            "info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
            self.id, self.client_addr, self.backend.address, self.bytes_transferred, elapsed_secs
        );
    }
}

View File

@@ -1,26 +0,0 @@
use std::sync::Arc;
use tokio::io;
use tokio::net::TcpStream;
use anywho::Error;
use crate::backend::Backend;
use crate::proxy::ConnectionContext;
/// Proxies one client connection to `backend` until both directions finish.
///
/// A `ConnectionContext` is created before dialling the backend and lives
/// until this function returns, whether it succeeds or fails.
///
/// # Errors
///
/// Fails if the client's peer address cannot be read, the backend cannot be
/// reached, or the bidirectional copy reports an error.
pub async fn proxy_tcp_connection(connection_id: u64, mut client_stream: TcpStream, backend: Arc<Backend>) -> Result<(), Error> {
    let peer = client_stream.peer_addr()?;
    let mut ctx = ConnectionContext::new(connection_id, peer, Arc::clone(&backend));

    #[cfg(debug_assertions)]
    println!("info: conn_id={} connecting to {}", connection_id, ctx.backend.id);

    let mut upstream = TcpStream::connect(&backend.address).await?;

    // The copy returns the byte counts of the two directions; their sum is
    // recorded on the context.
    let (sent, received) = io::copy_bidirectional(&mut client_stream, &mut upstream).await?;
    ctx.bytes_transferred = sent + received;

    Ok(())
}