rustfmt, fixed ip hash issue

This commit is contained in:
psun256
2025-12-10 03:03:10 -05:00
parent 90d326ba33
commit 9fb423b949
11 changed files with 150 additions and 239 deletions

View File

@@ -0,0 +1 @@

View File

@@ -2,27 +2,20 @@ pub mod health;
use core::fmt; use core::fmt;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::RwLock;
use std::sync::Arc; use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::atomic::{AtomicUsize, Ordering};
// Physical server information
#[derive(Debug)]
pub struct Server {
pub endpoints: Arc<Vec<Arc<Backend>>>,
pub metrics: Arc<RwLock<ServerHealth>>,
}
// Physical server health statistics, used for certain load balancing algorithms // Physical server health statistics, used for certain load balancing algorithms
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct ServerHealth { pub struct ServerMetrics {
pub cpu: f64, pub cpu: f64,
pub mem: f64, pub mem: f64,
pub net: f64, pub net: f64,
pub io: f64, pub io: f64,
} }
impl ServerHealth { impl ServerMetrics {
pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) { pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
self.cpu = cpu; self.cpu = cpu;
self.mem = mem; self.mem = mem;
@@ -38,11 +31,15 @@ pub struct Backend {
pub id: String, pub id: String,
pub address: SocketAddr, pub address: SocketAddr,
pub active_connections: AtomicUsize, pub active_connections: AtomicUsize,
pub metrics: Arc<RwLock<ServerHealth>>, pub metrics: Arc<RwLock<ServerMetrics>>,
} }
impl Backend { impl Backend {
pub fn new(id: String, address: SocketAddr, server_metrics: Arc<RwLock<ServerHealth>>) -> Self { pub fn new(
id: String,
address: SocketAddr,
server_metrics: Arc<RwLock<ServerMetrics>>,
) -> Self {
Self { Self {
id: id.to_string(), id: id.to_string(),
address, address,
@@ -56,12 +53,20 @@ impl Backend {
// enough not to behave poorly, so SeqCst is probably overkill. // enough not to behave poorly, so SeqCst is probably overkill.
pub fn inc_connections(&self) { pub fn inc_connections(&self) {
self.active_connections.fetch_add(1, Ordering::Relaxed); self.active_connections.fetch_add(1, Ordering::Relaxed);
println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed)); println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
} }
pub fn dec_connections(&self) { pub fn dec_connections(&self) {
self.active_connections.fetch_sub(1, Ordering::Relaxed); self.active_connections.fetch_sub(1, Ordering::Relaxed);
println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed)); println!(
"{} has {} connections open",
self.id,
self.active_connections.load(Ordering::Relaxed)
);
} }
} }

View File

@@ -1,10 +1,10 @@
use std::sync::{Arc, RwLock}; use crate::backend::{Backend, BackendPool, ServerMetrics};
use std::fmt::Debug; use crate::balancer::{Balancer, ConnectionInfo};
use std::fs::Metadata;
use crate::backend::{Backend, BackendPool, ServerHealth};
use crate::balancer::Balancer;
use rand::prelude::*; use rand::prelude::*;
use rand::rngs::SmallRng; use rand::rngs::SmallRng;
use std::fmt::Debug;
use std::fs::Metadata;
use std::sync::{Arc, RwLock};
#[derive(Debug)] #[derive(Debug)]
struct AdaptiveNode { struct AdaptiveNode {
@@ -22,7 +22,8 @@ pub struct AdaptiveWeightBalancer {
impl AdaptiveWeightBalancer { impl AdaptiveWeightBalancer {
pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self { pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self {
let nodes = pool.backends let nodes = pool
.backends
.iter() .iter()
.map(|b| AdaptiveNode { .map(|b| AdaptiveNode {
backend: b.clone(), backend: b.clone(),
@@ -34,20 +35,20 @@ impl AdaptiveWeightBalancer {
pool: nodes, pool: nodes,
coefficients, coefficients,
alpha, alpha,
rng: SmallRng::from_rng(&mut rand::rng()) rng: SmallRng::from_rng(&mut rand::rng()),
} }
} }
pub fn metrics_to_weight(&self, metrics: &ServerHealth) -> f64 { pub fn metrics_to_weight(&self, metrics: &ServerMetrics) -> f64 {
self.coefficients[0] * metrics.cpu + self.coefficients[0] * metrics.cpu
self.coefficients[1] * metrics.mem + + self.coefficients[1] * metrics.mem
self.coefficients[2] * metrics.net + + self.coefficients[2] * metrics.net
self.coefficients[3] * metrics.io + self.coefficients[3] * metrics.io
} }
} }
impl Balancer for AdaptiveWeightBalancer { impl Balancer for AdaptiveWeightBalancer {
fn choose_backend(&mut self) -> Option<Arc<Backend>> { fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
if self.pool.is_empty() { if self.pool.is_empty() {
return None; return None;
} }
@@ -62,7 +63,9 @@ impl Balancer for AdaptiveWeightBalancer {
r_sum += self.metrics_to_weight(&health); r_sum += self.metrics_to_weight(&health);
} }
w_sum += node.weight; w_sum += node.weight;
l_sum += node.backend.active_connections l_sum += node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed); .load(std::sync::atomic::Ordering::Relaxed);
} }
@@ -72,7 +75,9 @@ impl Balancer for AdaptiveWeightBalancer {
for idx in 0..self.pool.len() { for idx in 0..self.pool.len() {
let node = &self.pool[idx]; let node = &self.pool[idx];
if node.weight <= 0.001 { continue; } if node.weight <= 0.001 {
continue;
}
let risk = match node.backend.metrics.read() { let risk = match node.backend.metrics.read() {
Ok(h) => self.metrics_to_weight(&h), Ok(h) => self.metrics_to_weight(&h),
@@ -93,7 +98,9 @@ impl Balancer for AdaptiveWeightBalancer {
let l_sum_f64 = l_sum as f64; let l_sum_f64 = l_sum as f64;
for node in &self.pool { for node in &self.pool {
let load = node.backend.active_connections let load = node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed) as f64; .load(std::sync::atomic::Ordering::Relaxed) as f64;
let weight = node.weight.max(1e-12); let weight = node.weight.max(1e-12);
let lwi = load * (safe_w_sum / weight) * l_sum_f64; let lwi = load * (safe_w_sum / weight) * l_sum_f64;
@@ -107,7 +114,9 @@ impl Balancer for AdaptiveWeightBalancer {
let mut min_load = usize::MAX; let mut min_load = usize::MAX;
for node in &mut self.pool { for node in &mut self.pool {
let load = node.backend.active_connections let load = node
.backend
.active_connections
.load(std::sync::atomic::Ordering::Relaxed); .load(std::sync::atomic::Ordering::Relaxed);
let load_f64 = load as f64; let load_f64 = load as f64;
let weight = node.weight.max(1e-12); let weight = node.weight.max(1e-12);

View File

@@ -1,11 +1,11 @@
use crate::backend::{Backend, BackendPool, ServerHealth}; use crate::backend::{Backend, BackendPool};
use crate::balancer::{Balancer, CURRENT_CONNECTION_INFO, ConnectionInfo}; use crate::balancer::{Balancer, ConnectionInfo};
use std::hash::{Hasher, DefaultHasher, Hash}; use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
#[derive(Debug)] #[derive(Debug)]
pub struct SourceIPHash { pub struct SourceIPHash {
pool : BackendPool, pool: BackendPool,
} }
impl SourceIPHash { impl SourceIPHash {
@@ -15,165 +15,84 @@ impl SourceIPHash {
} }
impl Balancer for SourceIPHash { impl Balancer for SourceIPHash {
fn choose_backend(&mut self) -> Option<Arc<Backend>>{ fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let client_ip = CURRENT_CONNECTION_INFO.with(|info| { let client_ip = ctx.client_ip;
info.borrow().as_ref().map(|c| c.client_ip.clone())
});
let client_ip = match client_ip {
Some(ip) => ip,
None => return None, // no client info available
};
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
client_ip.hash(&mut hasher); client_ip.hash(&mut hasher);
let hash = hasher.finish(); let hash = hasher.finish();
let idx = (hash as usize) % self.pool.backends.len(); let idx = (hash as usize) % self.pool.backends.len();
return Some(self.pool.backends[idx].clone()); Some(self.pool.backends[idx].clone())
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::net::IpAddr;
use crate::backend::ServerMetrics;
fn create_dummy_backends(count: usize) -> BackendPool {
let mut backends = Vec::new();
for i in 1..=count {
backends.push(Arc::new(Backend::new(
format!("backend {}", i),
format!("127.0.0.1:808{}", i).parse().unwrap(),
Arc::new(RwLock::new(ServerMetrics::default())),
)));
}
BackendPool::new(backends)
}
#[test] #[test]
fn test_same_ip_always_selects_same_backend() { fn test_same_ip_always_selects_same_backend() {
let backends = vec![ let backends = create_dummy_backends(3);
Arc::new(Backend::new( let mut balancer = SourceIPHash::new(backends);
"backend 1".into(),
"127.0.0.1:8081".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
Arc::new(Backend::new(
"backend 2".into(),
"127.0.0.1:8082".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
Arc::new(Backend::new(
"backend 3".into(),
"127.0.0.1:8083".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
];
let mut balancer = SourceIPHash::new(BackendPool::new(backends)); let client_ip: IpAddr = "192.168.1.100".parse().unwrap();
let client_ip = "192.168.1.100:54321".parse().unwrap();
CURRENT_CONNECTION_INFO.with(|info| { let first_choice = balancer.choose_backend(ConnectionInfo { client_ip });
*info.borrow_mut() = Some(ConnectionInfo { client_ip }); let second_choice = balancer.choose_backend(ConnectionInfo { client_ip });
});
let first_choice = balancer.choose_backend();
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = Some(ConnectionInfo { client_ip });
});
let second_choice = balancer.choose_backend();
assert!(first_choice.is_some()); assert!(first_choice.is_some());
assert!(second_choice.is_some()); assert!(second_choice.is_some());
let first = first_choice.unwrap(); let first = first_choice.unwrap();
let second = second_choice.unwrap(); let second = second_choice.unwrap();
assert_eq!(first.id, second.id);
// Cleanup assert_eq!(first.id, second.id);
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = None;
});
} }
#[test] #[test]
fn test_different_ips_may_select_different_backends() { fn test_different_ips_may_select_different_backends() {
let backends = vec![ let backends = create_dummy_backends(2);
Arc::new(Backend::new( let mut balancer = SourceIPHash::new(backends);
"backend 1".into(),
"127.0.0.1:8081".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
Arc::new(Backend::new(
"backend 2".into(),
"127.0.0.1:8082".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
];
let mut balancer = SourceIPHash::new(BackendPool::new(backends)); let ip1: IpAddr = "192.168.1.100".parse().unwrap();
let choice1 = balancer.choose_backend(ConnectionInfo { client_ip: ip1 });
let ip1 = "192.168.1.100:54321".parse().unwrap(); let ip2: IpAddr = "192.168.1.101".parse().unwrap();
CURRENT_CONNECTION_INFO.with(|info| { let choice2 = balancer.choose_backend(ConnectionInfo { client_ip: ip2 });
*info.borrow_mut() = Some(ConnectionInfo { client_ip: ip1 });
});
let choice1 = balancer.choose_backend();
let ip2 = "192.168.1.101:54322".parse().unwrap();
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = Some(ConnectionInfo { client_ip: ip2 });
});
let choice2 = balancer.choose_backend();
assert!(choice1.is_some()); assert!(choice1.is_some());
assert!(choice2.is_some()); assert!(choice2.is_some());
// Note: choice1 and choice2 might be equal by chance, but statistically should differ
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = None;
});
}
#[test]
fn test_returns_none_when_no_connection_info() {
let backends = vec![Arc::new(Backend::new(
"backend 1".into(),
"127.0.0.1:8081".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
))];
let mut balancer = SourceIPHash::new(BackendPool::new(backends));
// Don't set any connection info
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = None;
});
let choice = balancer.choose_backend();
assert!(choice.is_none());
} }
#[test] #[test]
fn test_hash_distribution_across_backends() { fn test_hash_distribution_across_backends() {
let backends = vec![ let pool = create_dummy_backends(3);
Arc::new(Backend::new( let backends_ref = pool.backends.clone();
"backend 1".into(),
"127.0.0.1:8081".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
Arc::new(Backend::new(
"backend 2".into(),
"127.0.0.1:8082".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
Arc::new(Backend::new(
"backend 3".into(),
"127.0.0.1:8083".parse().unwrap(),
Arc::new(RwLock::new(ServerHealth::default())),
)),
];
let mut balancer = SourceIPHash::new(BackendPool::new(backends.clone())); let mut balancer = SourceIPHash::new(pool);
let mut distribution = [0, 0, 0]; let mut distribution = [0, 0, 0];
// Test 30 different IPs to see if they distribute across backends // Test 30 different IPs
for i in 0..30 { for i in 0..30 {
let client_ip = format!("192.168.1.{}:54321", 100 + i).parse().unwrap(); let client_ip: IpAddr = format!("192.168.1.{}", 100 + i).parse().unwrap();
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = Some(ConnectionInfo { client_ip });
});
if let Some(backend) = balancer.choose_backend() { if let Some(backend) = balancer.choose_backend(ConnectionInfo { client_ip }) {
for (idx, b) in backends.iter().enumerate() { for (idx, b) in backends_ref.iter().enumerate() {
if backend.id == b.id && backend.address == b.address { if backend.id == b.id && backend.address == b.address {
distribution[idx] += 1; distribution[idx] += 1;
break; break;
@@ -182,12 +101,8 @@ mod tests {
} }
} }
assert!(distribution[0] > 0); assert!(distribution[0] > 0, "Backend 0 received no traffic");
assert!(distribution[1] > 0); assert!(distribution[1] > 0, "Backend 1 received no traffic");
assert!(distribution[2] > 0); assert!(distribution[2] > 0, "Backend 2 received no traffic");
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = None;
});
} }
} }

View File

@@ -1,24 +1,18 @@
pub mod round_robin;
pub mod adaptive_weight; pub mod adaptive_weight;
pub mod least_connections;
pub mod ip_hashing; pub mod ip_hashing;
pub mod least_connections;
pub mod round_robin;
use std::fmt::Debug;
use std::sync::Arc;
use crate::backend::Backend; use crate::backend::Backend;
use std::cell::RefCell; use std::fmt::Debug;
use std::net::{SocketAddr}; use std::net::IpAddr;
use std::sync::Arc;
thread_local! {
pub static CURRENT_CONNECTION_INFO: RefCell<Option<ConnectionInfo>> = RefCell::new(None);
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ConnectionInfo { pub struct ConnectionInfo {
pub client_ip : SocketAddr, pub client_ip: IpAddr,
} }
pub trait Balancer: Debug + Send + Sync + 'static { pub trait Balancer: Debug + Send + Sync + 'static {
fn choose_backend(&mut self) -> Option<Arc<Backend>>; fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>>;
} }

View File

@@ -1,7 +1,7 @@
use std::sync::{Arc, RwLock};
use std::fmt::Debug;
use crate::backend::{Backend, BackendPool}; use crate::backend::{Backend, BackendPool};
use crate::balancer::Balancer; use crate::balancer::{Balancer, ConnectionInfo};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};
// only the main thread for receiving connections should be // only the main thread for receiving connections should be
// doing the load balancing. alternatively, each thread // doing the load balancing. alternatively, each thread
@@ -14,17 +14,16 @@ pub struct RoundRobinBalancer {
impl RoundRobinBalancer { impl RoundRobinBalancer {
pub fn new(pool: BackendPool) -> RoundRobinBalancer { pub fn new(pool: BackendPool) -> RoundRobinBalancer {
Self { Self { pool, index: 0 }
pool,
index: 0,
}
} }
} }
impl Balancer for RoundRobinBalancer { impl Balancer for RoundRobinBalancer {
fn choose_backend(&mut self) -> Option<Arc<Backend>> { fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
let backends = self.pool.backends.clone(); let backends = self.pool.backends.clone();
if backends.is_empty() { return None; } if backends.is_empty() {
return None;
}
let backend = backends[self.index % backends.len()].clone(); let backend = backends[self.index % backends.len()].clone();
self.index = self.index.wrapping_add(1); self.index = self.index.wrapping_add(1);

View File

@@ -1,12 +1,12 @@
use cidr::IpCidr;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr}; use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use cidr::IpCidr;
use crate::backend::*; use crate::backend::*;
use crate::balancer::Balancer; use crate::balancer::Balancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer; use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::config::*; use crate::config::*;
pub struct RoutingTable { pub struct RoutingTable {
@@ -27,7 +27,7 @@ fn parse_client(s: &str) -> (IpCidr, u16) {
pub type PortListeners = HashMap<u16, RoutingTable>; pub type PortListeners = HashMap<u16, RoutingTable>;
pub fn build_lb(config: AppConfig) -> PortListeners { pub fn build_lb(config: AppConfig) -> PortListeners {
let mut healths: HashMap<IpAddr, Arc<RwLock<ServerHealth>>> = HashMap::new(); let mut healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = HashMap::new();
let mut backends: HashMap<String, Arc<Backend>> = HashMap::new(); let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
for backend_cfg in config.backends { for backend_cfg in config.backends {
@@ -36,14 +36,10 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
let health = healths let health = healths
.entry(ip) .entry(ip)
.or_insert_with(|| Arc::new(RwLock::new(ServerHealth::default()))) .or_insert_with(|| Arc::new(RwLock::new(ServerMetrics::default())))
.clone(); .clone();
let backend = Arc::new(Backend::new( let backend = Arc::new(Backend::new(backend_cfg.id.clone(), addr, health));
backend_cfg.id.clone(),
addr,
health,
));
backends.insert(backend_cfg.id, backend); backends.insert(backend_cfg.id, backend);
} }
@@ -60,8 +56,7 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
target_backends.push(backend.clone()); target_backends.push(backend.clone());
} }
} }
} } else if let Some(backend) = backends.get(target_name) {
else if let Some(backend) = backends.get(target_name) {
target_backends.push(backend.clone()); target_backends.push(backend.clone());
} else { } else {
eprintln!("warning: target {} not found", target_name); eprintln!("warning: target {} not found", target_name);
@@ -98,12 +93,11 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
let pool = BackendPool::new(target_backends.clone()); let pool = BackendPool::new(target_backends.clone());
let balancer: Box<dyn Balancer + Send> = match &rule.strategy { let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
LoadBalancerStrategy::RoundRobin => { LoadBalancerStrategy::RoundRobin => Box::new(RoundRobinBalancer::new(pool)),
Box::new(RoundRobinBalancer::new(pool)) LoadBalancerStrategy::Adaptive {
}, coefficients,
LoadBalancerStrategy::Adaptive { coefficients, alpha } => { alpha,
Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha)) } => Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha)),
}
}; };
let balancer_idx = table.balancers.len(); let balancer_idx = table.balancers.len();
@@ -117,7 +111,9 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
// sort to make most specific first, so that first match == longest prefix match // sort to make most specific first, so that first match == longest prefix match
for table in listeners.values_mut() { for table in listeners.values_mut() {
table.entries.sort_by(|(a, _), (b, _)| a.network_length().cmp(&b.network_length())); table
.entries
.sort_by(|(a, _), (b, _)| a.network_length().cmp(&b.network_length()));
} }
listeners listeners

View File

@@ -13,8 +13,8 @@
// specify which algorithm to use. // specify which algorithm to use.
pub mod loader; pub mod loader;
use std::collections::HashMap;
use serde::Deserialize; use serde::Deserialize;
use std::collections::HashMap;
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
pub struct AppConfig { pub struct AppConfig {
@@ -42,8 +42,5 @@ pub struct RuleConfig {
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum LoadBalancerStrategy { pub enum LoadBalancerStrategy {
RoundRobin, RoundRobin,
Adaptive { Adaptive { coefficients: [f64; 4], alpha: f64 },
coefficients: [f64; 4],
alpha: f64,
},
} }

View File

@@ -1,16 +1,13 @@
mod backend;
mod balancer; mod balancer;
mod config; mod config;
mod backend;
mod proxy; mod proxy;
use crate::balancer::{Balancer, ConnectionInfo};
use crate::proxy::tcp::proxy_tcp_connection;
use std::fs::File; use std::fs::File;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use tokio::net::TcpListener; use tokio::net::TcpListener;
use crate::backend::{Backend, BackendPool, ServerHealth};
use crate::balancer::{Balancer, CURRENT_CONNECTION_INFO, ConnectionInfo};
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::balancer::ip_hashing::SourceIPHash;
use crate::proxy::tcp::proxy_tcp_connection;
static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1); static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
@@ -19,7 +16,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let f = File::open("config.yaml").expect("couldn't open config.yaml"); let f = File::open("config.yaml").expect("couldn't open config.yaml");
let app_config: config::AppConfig = serde_saphyr::from_reader(f)?; let app_config: config::AppConfig = serde_saphyr::from_reader(f)?;
println!("Loaded {} backends, {} rules.", println!(
"Loaded {} backends, {} rules.",
app_config.backends.len(), app_config.backends.len(),
app_config.rules.len() app_config.rules.len()
); );
@@ -51,18 +49,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let remote_ip = remote_addr.ip(); let remote_ip = remote_addr.ip();
let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed); let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
let client_ip = socket.local_addr()?;
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = Some(ConnectionInfo { client_ip : client_ip });
});
let mut chosen_backend = None; let mut chosen_backend = None;
for (cidr, balancer_idx) in &mut routing_table.entries { for (cidr, balancer_idx) in &mut routing_table.entries {
if cidr.contains(&remote_ip) { if cidr.contains(&remote_ip) {
let balancer = &mut routing_table.balancers[*balancer_idx]; let balancer = &mut routing_table.balancers[*balancer_idx];
chosen_backend = balancer.choose_backend(); chosen_backend = balancer.choose_backend(ConnectionInfo {
client_ip: remote_ip,
});
break; break;
} }
} }
@@ -76,11 +71,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
} else { } else {
println!("error: no matching rule for {} on port {}", remote_ip, port); println!("error: no matching rule for {} on port {}", remote_ip, port);
} }
// clear the slot after use to avoid stale data
CURRENT_CONNECTION_INFO.with(|info| {
*info.borrow_mut() = None;
});
} }
})); }));
} }

View File

@@ -1,7 +1,7 @@
use crate::backend::Backend;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use crate::backend::Backend;
pub mod tcp; pub mod tcp;
@@ -32,7 +32,8 @@ impl Drop for ConnectionContext {
self.backend.dec_connections(); self.backend.dec_connections();
let duration = self.start_time.elapsed(); let duration = self.start_time.elapsed();
println!("info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}", println!(
"info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
self.id, self.id,
self.client_addr, self.client_addr,
self.backend.address, self.backend.address,

View File

@@ -1,24 +1,28 @@
use crate::backend::Backend;
use crate::proxy::ConnectionContext;
use anywho::Error;
use std::sync::Arc; use std::sync::Arc;
use tokio::io; use tokio::io;
use tokio::net::TcpStream; use tokio::net::TcpStream;
use anywho::Error;
use crate::backend::Backend;
use crate::proxy::ConnectionContext;
pub async fn proxy_tcp_connection(connection_id: u64, mut client_stream: TcpStream, backend: Arc<Backend>) -> Result<(), Error> { pub async fn proxy_tcp_connection(
connection_id: u64,
mut client_stream: TcpStream,
backend: Arc<Backend>,
) -> Result<(), Error> {
let client_addr = client_stream.peer_addr()?; let client_addr = client_stream.peer_addr()?;
let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone()); let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
println!("info: conn_id={} connecting to {}", connection_id, ctx.backend.id); println!(
"info: conn_id={} connecting to {}",
connection_id, ctx.backend.id
);
let mut backend_stream = TcpStream::connect(&backend.address).await?; let mut backend_stream = TcpStream::connect(&backend.address).await?;
let (tx, rx) = io::copy_bidirectional( let (tx, rx) = io::copy_bidirectional(&mut client_stream, &mut backend_stream).await?;
&mut client_stream,
&mut backend_stream,
).await?;
ctx.bytes_transferred = tx + rx; ctx.bytes_transferred = tx + rx;