rustfmt, fixed ip hash issue
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
|
|||||||
@@ -2,27 +2,20 @@ pub mod health;
|
|||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::RwLock;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::RwLock;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
// Physical server information
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub struct Server {
|
|
||||||
pub endpoints: Arc<Vec<Arc<Backend>>>,
|
|
||||||
pub metrics: Arc<RwLock<ServerHealth>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Physical server health statistics, used for certain load balancing algorithms
|
// Physical server health statistics, used for certain load balancing algorithms
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct ServerHealth {
|
pub struct ServerMetrics {
|
||||||
pub cpu: f64,
|
pub cpu: f64,
|
||||||
pub mem: f64,
|
pub mem: f64,
|
||||||
pub net: f64,
|
pub net: f64,
|
||||||
pub io: f64,
|
pub io: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ServerHealth {
|
impl ServerMetrics {
|
||||||
pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
|
pub fn update(&mut self, cpu: f64, mem: f64, net: f64, io: f64) {
|
||||||
self.cpu = cpu;
|
self.cpu = cpu;
|
||||||
self.mem = mem;
|
self.mem = mem;
|
||||||
@@ -38,11 +31,15 @@ pub struct Backend {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub address: SocketAddr,
|
pub address: SocketAddr,
|
||||||
pub active_connections: AtomicUsize,
|
pub active_connections: AtomicUsize,
|
||||||
pub metrics: Arc<RwLock<ServerHealth>>,
|
pub metrics: Arc<RwLock<ServerMetrics>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Backend {
|
impl Backend {
|
||||||
pub fn new(id: String, address: SocketAddr, server_metrics: Arc<RwLock<ServerHealth>>) -> Self {
|
pub fn new(
|
||||||
|
id: String,
|
||||||
|
address: SocketAddr,
|
||||||
|
server_metrics: Arc<RwLock<ServerMetrics>>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id: id.to_string(),
|
id: id.to_string(),
|
||||||
address,
|
address,
|
||||||
@@ -56,12 +53,20 @@ impl Backend {
|
|||||||
// enough not to behave poorly, so SeqCst is probably overkill.
|
// enough not to behave poorly, so SeqCst is probably overkill.
|
||||||
pub fn inc_connections(&self) {
|
pub fn inc_connections(&self) {
|
||||||
self.active_connections.fetch_add(1, Ordering::Relaxed);
|
self.active_connections.fetch_add(1, Ordering::Relaxed);
|
||||||
println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed));
|
println!(
|
||||||
|
"{} has {} connections open",
|
||||||
|
self.id,
|
||||||
|
self.active_connections.load(Ordering::Relaxed)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dec_connections(&self) {
|
pub fn dec_connections(&self) {
|
||||||
self.active_connections.fetch_sub(1, Ordering::Relaxed);
|
self.active_connections.fetch_sub(1, Ordering::Relaxed);
|
||||||
println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed));
|
println!(
|
||||||
|
"{} has {} connections open",
|
||||||
|
self.id,
|
||||||
|
self.active_connections.load(Ordering::Relaxed)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
use std::sync::{Arc, RwLock};
|
use crate::backend::{Backend, BackendPool, ServerMetrics};
|
||||||
use std::fmt::Debug;
|
use crate::balancer::{Balancer, ConnectionInfo};
|
||||||
use std::fs::Metadata;
|
|
||||||
use crate::backend::{Backend, BackendPool, ServerHealth};
|
|
||||||
use crate::balancer::Balancer;
|
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use rand::rngs::SmallRng;
|
use rand::rngs::SmallRng;
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::fs::Metadata;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct AdaptiveNode {
|
struct AdaptiveNode {
|
||||||
@@ -22,7 +22,8 @@ pub struct AdaptiveWeightBalancer {
|
|||||||
|
|
||||||
impl AdaptiveWeightBalancer {
|
impl AdaptiveWeightBalancer {
|
||||||
pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self {
|
pub fn new(pool: BackendPool, coefficients: [f64; 4], alpha: f64) -> Self {
|
||||||
let nodes = pool.backends
|
let nodes = pool
|
||||||
|
.backends
|
||||||
.iter()
|
.iter()
|
||||||
.map(|b| AdaptiveNode {
|
.map(|b| AdaptiveNode {
|
||||||
backend: b.clone(),
|
backend: b.clone(),
|
||||||
@@ -34,20 +35,20 @@ impl AdaptiveWeightBalancer {
|
|||||||
pool: nodes,
|
pool: nodes,
|
||||||
coefficients,
|
coefficients,
|
||||||
alpha,
|
alpha,
|
||||||
rng: SmallRng::from_rng(&mut rand::rng())
|
rng: SmallRng::from_rng(&mut rand::rng()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn metrics_to_weight(&self, metrics: &ServerHealth) -> f64 {
|
pub fn metrics_to_weight(&self, metrics: &ServerMetrics) -> f64 {
|
||||||
self.coefficients[0] * metrics.cpu +
|
self.coefficients[0] * metrics.cpu
|
||||||
self.coefficients[1] * metrics.mem +
|
+ self.coefficients[1] * metrics.mem
|
||||||
self.coefficients[2] * metrics.net +
|
+ self.coefficients[2] * metrics.net
|
||||||
self.coefficients[3] * metrics.io
|
+ self.coefficients[3] * metrics.io
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Balancer for AdaptiveWeightBalancer {
|
impl Balancer for AdaptiveWeightBalancer {
|
||||||
fn choose_backend(&mut self) -> Option<Arc<Backend>> {
|
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
|
||||||
if self.pool.is_empty() {
|
if self.pool.is_empty() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
@@ -62,7 +63,9 @@ impl Balancer for AdaptiveWeightBalancer {
|
|||||||
r_sum += self.metrics_to_weight(&health);
|
r_sum += self.metrics_to_weight(&health);
|
||||||
}
|
}
|
||||||
w_sum += node.weight;
|
w_sum += node.weight;
|
||||||
l_sum += node.backend.active_connections
|
l_sum += node
|
||||||
|
.backend
|
||||||
|
.active_connections
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,7 +75,9 @@ impl Balancer for AdaptiveWeightBalancer {
|
|||||||
for idx in 0..self.pool.len() {
|
for idx in 0..self.pool.len() {
|
||||||
let node = &self.pool[idx];
|
let node = &self.pool[idx];
|
||||||
|
|
||||||
if node.weight <= 0.001 { continue; }
|
if node.weight <= 0.001 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let risk = match node.backend.metrics.read() {
|
let risk = match node.backend.metrics.read() {
|
||||||
Ok(h) => self.metrics_to_weight(&h),
|
Ok(h) => self.metrics_to_weight(&h),
|
||||||
@@ -93,7 +98,9 @@ impl Balancer for AdaptiveWeightBalancer {
|
|||||||
let l_sum_f64 = l_sum as f64;
|
let l_sum_f64 = l_sum as f64;
|
||||||
|
|
||||||
for node in &self.pool {
|
for node in &self.pool {
|
||||||
let load = node.backend.active_connections
|
let load = node
|
||||||
|
.backend
|
||||||
|
.active_connections
|
||||||
.load(std::sync::atomic::Ordering::Relaxed) as f64;
|
.load(std::sync::atomic::Ordering::Relaxed) as f64;
|
||||||
let weight = node.weight.max(1e-12);
|
let weight = node.weight.max(1e-12);
|
||||||
let lwi = load * (safe_w_sum / weight) * l_sum_f64;
|
let lwi = load * (safe_w_sum / weight) * l_sum_f64;
|
||||||
@@ -107,7 +114,9 @@ impl Balancer for AdaptiveWeightBalancer {
|
|||||||
let mut min_load = usize::MAX;
|
let mut min_load = usize::MAX;
|
||||||
|
|
||||||
for node in &mut self.pool {
|
for node in &mut self.pool {
|
||||||
let load = node.backend.active_connections
|
let load = node
|
||||||
|
.backend
|
||||||
|
.active_connections
|
||||||
.load(std::sync::atomic::Ordering::Relaxed);
|
.load(std::sync::atomic::Ordering::Relaxed);
|
||||||
let load_f64 = load as f64;
|
let load_f64 = load as f64;
|
||||||
let weight = node.weight.max(1e-12);
|
let weight = node.weight.max(1e-12);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use crate::backend::{Backend, BackendPool, ServerHealth};
|
use crate::backend::{Backend, BackendPool};
|
||||||
use crate::balancer::{Balancer, CURRENT_CONNECTION_INFO, ConnectionInfo};
|
use crate::balancer::{Balancer, ConnectionInfo};
|
||||||
use std::hash::{Hasher, DefaultHasher, Hash};
|
use std::hash::{DefaultHasher, Hash, Hasher};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -15,165 +15,84 @@ impl SourceIPHash {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Balancer for SourceIPHash {
|
impl Balancer for SourceIPHash {
|
||||||
fn choose_backend(&mut self) -> Option<Arc<Backend>>{
|
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
|
||||||
let client_ip = CURRENT_CONNECTION_INFO.with(|info| {
|
let client_ip = ctx.client_ip;
|
||||||
info.borrow().as_ref().map(|c| c.client_ip.clone())
|
|
||||||
});
|
|
||||||
|
|
||||||
let client_ip = match client_ip {
|
|
||||||
Some(ip) => ip,
|
|
||||||
None => return None, // no client info available
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
client_ip.hash(&mut hasher);
|
client_ip.hash(&mut hasher);
|
||||||
let hash = hasher.finish();
|
let hash = hasher.finish();
|
||||||
let idx = (hash as usize) % self.pool.backends.len();
|
let idx = (hash as usize) % self.pool.backends.len();
|
||||||
|
|
||||||
return Some(self.pool.backends[idx].clone());
|
Some(self.pool.backends[idx].clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use crate::backend::ServerMetrics;
|
||||||
|
|
||||||
|
fn create_dummy_backends(count: usize) -> BackendPool {
|
||||||
|
let mut backends = Vec::new();
|
||||||
|
for i in 1..=count {
|
||||||
|
backends.push(Arc::new(Backend::new(
|
||||||
|
format!("backend {}", i),
|
||||||
|
format!("127.0.0.1:808{}", i).parse().unwrap(),
|
||||||
|
Arc::new(RwLock::new(ServerMetrics::default())),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
BackendPool::new(backends)
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_same_ip_always_selects_same_backend() {
|
fn test_same_ip_always_selects_same_backend() {
|
||||||
let backends = vec![
|
let backends = create_dummy_backends(3);
|
||||||
Arc::new(Backend::new(
|
let mut balancer = SourceIPHash::new(backends);
|
||||||
"backend 1".into(),
|
|
||||||
"127.0.0.1:8081".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
Arc::new(Backend::new(
|
|
||||||
"backend 2".into(),
|
|
||||||
"127.0.0.1:8082".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
Arc::new(Backend::new(
|
|
||||||
"backend 3".into(),
|
|
||||||
"127.0.0.1:8083".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut balancer = SourceIPHash::new(BackendPool::new(backends));
|
let client_ip: IpAddr = "192.168.1.100".parse().unwrap();
|
||||||
let client_ip = "192.168.1.100:54321".parse().unwrap();
|
|
||||||
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
let first_choice = balancer.choose_backend(ConnectionInfo { client_ip });
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip });
|
let second_choice = balancer.choose_backend(ConnectionInfo { client_ip });
|
||||||
});
|
|
||||||
|
|
||||||
let first_choice = balancer.choose_backend();
|
|
||||||
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip });
|
|
||||||
});
|
|
||||||
|
|
||||||
let second_choice = balancer.choose_backend();
|
|
||||||
|
|
||||||
assert!(first_choice.is_some());
|
assert!(first_choice.is_some());
|
||||||
assert!(second_choice.is_some());
|
assert!(second_choice.is_some());
|
||||||
|
|
||||||
let first = first_choice.unwrap();
|
let first = first_choice.unwrap();
|
||||||
let second = second_choice.unwrap();
|
let second = second_choice.unwrap();
|
||||||
assert_eq!(first.id, second.id);
|
|
||||||
|
|
||||||
// Cleanup
|
assert_eq!(first.id, second.id);
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = None;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_different_ips_may_select_different_backends() {
|
fn test_different_ips_may_select_different_backends() {
|
||||||
let backends = vec![
|
let backends = create_dummy_backends(2);
|
||||||
Arc::new(Backend::new(
|
let mut balancer = SourceIPHash::new(backends);
|
||||||
"backend 1".into(),
|
|
||||||
"127.0.0.1:8081".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
Arc::new(Backend::new(
|
|
||||||
"backend 2".into(),
|
|
||||||
"127.0.0.1:8082".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut balancer = SourceIPHash::new(BackendPool::new(backends));
|
let ip1: IpAddr = "192.168.1.100".parse().unwrap();
|
||||||
|
let choice1 = balancer.choose_backend(ConnectionInfo { client_ip: ip1 });
|
||||||
|
|
||||||
let ip1 = "192.168.1.100:54321".parse().unwrap();
|
let ip2: IpAddr = "192.168.1.101".parse().unwrap();
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
let choice2 = balancer.choose_backend(ConnectionInfo { client_ip: ip2 });
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip: ip1 });
|
|
||||||
});
|
|
||||||
let choice1 = balancer.choose_backend();
|
|
||||||
|
|
||||||
let ip2 = "192.168.1.101:54322".parse().unwrap();
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip: ip2 });
|
|
||||||
});
|
|
||||||
let choice2 = balancer.choose_backend();
|
|
||||||
|
|
||||||
assert!(choice1.is_some());
|
assert!(choice1.is_some());
|
||||||
assert!(choice2.is_some());
|
assert!(choice2.is_some());
|
||||||
// Note: choice1 and choice2 might be equal by chance, but statistically should differ
|
|
||||||
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = None;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_returns_none_when_no_connection_info() {
|
|
||||||
let backends = vec![Arc::new(Backend::new(
|
|
||||||
"backend 1".into(),
|
|
||||||
"127.0.0.1:8081".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
))];
|
|
||||||
|
|
||||||
let mut balancer = SourceIPHash::new(BackendPool::new(backends));
|
|
||||||
|
|
||||||
// Don't set any connection info
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = None;
|
|
||||||
});
|
|
||||||
|
|
||||||
let choice = balancer.choose_backend();
|
|
||||||
assert!(choice.is_none());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_hash_distribution_across_backends() {
|
fn test_hash_distribution_across_backends() {
|
||||||
let backends = vec![
|
let pool = create_dummy_backends(3);
|
||||||
Arc::new(Backend::new(
|
let backends_ref = pool.backends.clone();
|
||||||
"backend 1".into(),
|
|
||||||
"127.0.0.1:8081".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
Arc::new(Backend::new(
|
|
||||||
"backend 2".into(),
|
|
||||||
"127.0.0.1:8082".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
Arc::new(Backend::new(
|
|
||||||
"backend 3".into(),
|
|
||||||
"127.0.0.1:8083".parse().unwrap(),
|
|
||||||
Arc::new(RwLock::new(ServerHealth::default())),
|
|
||||||
)),
|
|
||||||
];
|
|
||||||
|
|
||||||
let mut balancer = SourceIPHash::new(BackendPool::new(backends.clone()));
|
let mut balancer = SourceIPHash::new(pool);
|
||||||
let mut distribution = [0, 0, 0];
|
let mut distribution = [0, 0, 0];
|
||||||
|
|
||||||
// Test 30 different IPs to see if they distribute across backends
|
// Test 30 different IPs
|
||||||
for i in 0..30 {
|
for i in 0..30 {
|
||||||
let client_ip = format!("192.168.1.{}:54321", 100 + i).parse().unwrap();
|
let client_ip: IpAddr = format!("192.168.1.{}", 100 + i).parse().unwrap();
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip });
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(backend) = balancer.choose_backend() {
|
if let Some(backend) = balancer.choose_backend(ConnectionInfo { client_ip }) {
|
||||||
for (idx, b) in backends.iter().enumerate() {
|
for (idx, b) in backends_ref.iter().enumerate() {
|
||||||
if backend.id == b.id && backend.address == b.address {
|
if backend.id == b.id && backend.address == b.address {
|
||||||
distribution[idx] += 1;
|
distribution[idx] += 1;
|
||||||
break;
|
break;
|
||||||
@@ -182,12 +101,8 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert!(distribution[0] > 0);
|
assert!(distribution[0] > 0, "Backend 0 received no traffic");
|
||||||
assert!(distribution[1] > 0);
|
assert!(distribution[1] > 0, "Backend 1 received no traffic");
|
||||||
assert!(distribution[2] > 0);
|
assert!(distribution[2] > 0, "Backend 2 received no traffic");
|
||||||
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = None;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1,24 +1,18 @@
|
|||||||
pub mod round_robin;
|
|
||||||
pub mod adaptive_weight;
|
pub mod adaptive_weight;
|
||||||
pub mod least_connections;
|
|
||||||
pub mod ip_hashing;
|
pub mod ip_hashing;
|
||||||
|
pub mod least_connections;
|
||||||
|
pub mod round_robin;
|
||||||
|
|
||||||
use std::fmt::Debug;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use crate::backend::Backend;
|
use crate::backend::Backend;
|
||||||
use std::cell::RefCell;
|
use std::fmt::Debug;
|
||||||
use std::net::{SocketAddr};
|
use std::net::IpAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
thread_local! {
|
|
||||||
pub static CURRENT_CONNECTION_INFO: RefCell<Option<ConnectionInfo>> = RefCell::new(None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ConnectionInfo {
|
pub struct ConnectionInfo {
|
||||||
pub client_ip : SocketAddr,
|
pub client_ip: IpAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait Balancer: Debug + Send + Sync + 'static {
|
pub trait Balancer: Debug + Send + Sync + 'static {
|
||||||
fn choose_backend(&mut self) -> Option<Arc<Backend>>;
|
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::sync::{Arc, RwLock};
|
|
||||||
use std::fmt::Debug;
|
|
||||||
use crate::backend::{Backend, BackendPool};
|
use crate::backend::{Backend, BackendPool};
|
||||||
use crate::balancer::Balancer;
|
use crate::balancer::{Balancer, ConnectionInfo};
|
||||||
|
use std::fmt::Debug;
|
||||||
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
// only the main thread for receiving connections should be
|
// only the main thread for receiving connections should be
|
||||||
// doing the load balancing. alternatively, each thread
|
// doing the load balancing. alternatively, each thread
|
||||||
@@ -14,17 +14,16 @@ pub struct RoundRobinBalancer {
|
|||||||
|
|
||||||
impl RoundRobinBalancer {
|
impl RoundRobinBalancer {
|
||||||
pub fn new(pool: BackendPool) -> RoundRobinBalancer {
|
pub fn new(pool: BackendPool) -> RoundRobinBalancer {
|
||||||
Self {
|
Self { pool, index: 0 }
|
||||||
pool,
|
|
||||||
index: 0,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Balancer for RoundRobinBalancer {
|
impl Balancer for RoundRobinBalancer {
|
||||||
fn choose_backend(&mut self) -> Option<Arc<Backend>> {
|
fn choose_backend(&mut self, ctx: ConnectionInfo) -> Option<Arc<Backend>> {
|
||||||
let backends = self.pool.backends.clone();
|
let backends = self.pool.backends.clone();
|
||||||
if backends.is_empty() { return None; }
|
if backends.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
let backend = backends[self.index % backends.len()].clone();
|
let backend = backends[self.index % backends.len()].clone();
|
||||||
self.index = self.index.wrapping_add(1);
|
self.index = self.index.wrapping_add(1);
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
|
use cidr::IpCidr;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::{IpAddr, SocketAddr};
|
use std::net::{IpAddr, SocketAddr};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use cidr::IpCidr;
|
|
||||||
|
|
||||||
use crate::backend::*;
|
use crate::backend::*;
|
||||||
use crate::balancer::Balancer;
|
use crate::balancer::Balancer;
|
||||||
use crate::balancer::round_robin::RoundRobinBalancer;
|
|
||||||
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
|
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
|
||||||
|
use crate::balancer::round_robin::RoundRobinBalancer;
|
||||||
use crate::config::*;
|
use crate::config::*;
|
||||||
|
|
||||||
pub struct RoutingTable {
|
pub struct RoutingTable {
|
||||||
@@ -27,7 +27,7 @@ fn parse_client(s: &str) -> (IpCidr, u16) {
|
|||||||
pub type PortListeners = HashMap<u16, RoutingTable>;
|
pub type PortListeners = HashMap<u16, RoutingTable>;
|
||||||
|
|
||||||
pub fn build_lb(config: AppConfig) -> PortListeners {
|
pub fn build_lb(config: AppConfig) -> PortListeners {
|
||||||
let mut healths: HashMap<IpAddr, Arc<RwLock<ServerHealth>>> = HashMap::new();
|
let mut healths: HashMap<IpAddr, Arc<RwLock<ServerMetrics>>> = HashMap::new();
|
||||||
let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
|
let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
|
||||||
|
|
||||||
for backend_cfg in config.backends {
|
for backend_cfg in config.backends {
|
||||||
@@ -36,14 +36,10 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
|
|||||||
|
|
||||||
let health = healths
|
let health = healths
|
||||||
.entry(ip)
|
.entry(ip)
|
||||||
.or_insert_with(|| Arc::new(RwLock::new(ServerHealth::default())))
|
.or_insert_with(|| Arc::new(RwLock::new(ServerMetrics::default())))
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
let backend = Arc::new(Backend::new(
|
let backend = Arc::new(Backend::new(backend_cfg.id.clone(), addr, health));
|
||||||
backend_cfg.id.clone(),
|
|
||||||
addr,
|
|
||||||
health,
|
|
||||||
));
|
|
||||||
|
|
||||||
backends.insert(backend_cfg.id, backend);
|
backends.insert(backend_cfg.id, backend);
|
||||||
}
|
}
|
||||||
@@ -60,8 +56,7 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
|
|||||||
target_backends.push(backend.clone());
|
target_backends.push(backend.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
} else if let Some(backend) = backends.get(target_name) {
|
||||||
else if let Some(backend) = backends.get(target_name) {
|
|
||||||
target_backends.push(backend.clone());
|
target_backends.push(backend.clone());
|
||||||
} else {
|
} else {
|
||||||
eprintln!("warning: target {} not found", target_name);
|
eprintln!("warning: target {} not found", target_name);
|
||||||
@@ -98,12 +93,11 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
|
|||||||
let pool = BackendPool::new(target_backends.clone());
|
let pool = BackendPool::new(target_backends.clone());
|
||||||
|
|
||||||
let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
|
let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
|
||||||
LoadBalancerStrategy::RoundRobin => {
|
LoadBalancerStrategy::RoundRobin => Box::new(RoundRobinBalancer::new(pool)),
|
||||||
Box::new(RoundRobinBalancer::new(pool))
|
LoadBalancerStrategy::Adaptive {
|
||||||
},
|
coefficients,
|
||||||
LoadBalancerStrategy::Adaptive { coefficients, alpha } => {
|
alpha,
|
||||||
Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha))
|
} => Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha)),
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let balancer_idx = table.balancers.len();
|
let balancer_idx = table.balancers.len();
|
||||||
@@ -117,7 +111,9 @@ pub fn build_lb(config: AppConfig) -> PortListeners {
|
|||||||
|
|
||||||
// sort to make most specific first, so that first match == longest prefix match
|
// sort to make most specific first, so that first match == longest prefix match
|
||||||
for table in listeners.values_mut() {
|
for table in listeners.values_mut() {
|
||||||
table.entries.sort_by(|(a, _), (b, _)| a.network_length().cmp(&b.network_length()));
|
table
|
||||||
|
.entries
|
||||||
|
.sort_by(|(a, _), (b, _)| a.network_length().cmp(&b.network_length()));
|
||||||
}
|
}
|
||||||
|
|
||||||
listeners
|
listeners
|
||||||
|
|||||||
@@ -13,8 +13,8 @@
|
|||||||
// specify which algorithm to use.
|
// specify which algorithm to use.
|
||||||
pub mod loader;
|
pub mod loader;
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct AppConfig {
|
pub struct AppConfig {
|
||||||
@@ -42,8 +42,5 @@ pub struct RuleConfig {
|
|||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum LoadBalancerStrategy {
|
pub enum LoadBalancerStrategy {
|
||||||
RoundRobin,
|
RoundRobin,
|
||||||
Adaptive {
|
Adaptive { coefficients: [f64; 4], alpha: f64 },
|
||||||
coefficients: [f64; 4],
|
|
||||||
alpha: f64,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
26
src/main.rs
26
src/main.rs
@@ -1,16 +1,13 @@
|
|||||||
|
mod backend;
|
||||||
mod balancer;
|
mod balancer;
|
||||||
mod config;
|
mod config;
|
||||||
mod backend;
|
|
||||||
mod proxy;
|
mod proxy;
|
||||||
|
|
||||||
|
use crate::balancer::{Balancer, ConnectionInfo};
|
||||||
|
use crate::proxy::tcp::proxy_tcp_connection;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
use crate::backend::{Backend, BackendPool, ServerHealth};
|
|
||||||
use crate::balancer::{Balancer, CURRENT_CONNECTION_INFO, ConnectionInfo};
|
|
||||||
use crate::balancer::round_robin::RoundRobinBalancer;
|
|
||||||
use crate::balancer::ip_hashing::SourceIPHash;
|
|
||||||
use crate::proxy::tcp::proxy_tcp_connection;
|
|
||||||
|
|
||||||
static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
|
static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
|
||||||
|
|
||||||
@@ -19,7 +16,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let f = File::open("config.yaml").expect("couldn't open config.yaml");
|
let f = File::open("config.yaml").expect("couldn't open config.yaml");
|
||||||
let app_config: config::AppConfig = serde_saphyr::from_reader(f)?;
|
let app_config: config::AppConfig = serde_saphyr::from_reader(f)?;
|
||||||
|
|
||||||
println!("Loaded {} backends, {} rules.",
|
println!(
|
||||||
|
"Loaded {} backends, {} rules.",
|
||||||
app_config.backends.len(),
|
app_config.backends.len(),
|
||||||
app_config.rules.len()
|
app_config.rules.len()
|
||||||
);
|
);
|
||||||
@@ -51,18 +49,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
|
|
||||||
let remote_ip = remote_addr.ip();
|
let remote_ip = remote_addr.ip();
|
||||||
let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
|
let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
|
||||||
let client_ip = socket.local_addr()?;
|
|
||||||
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = Some(ConnectionInfo { client_ip : client_ip });
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut chosen_backend = None;
|
let mut chosen_backend = None;
|
||||||
|
|
||||||
for (cidr, balancer_idx) in &mut routing_table.entries {
|
for (cidr, balancer_idx) in &mut routing_table.entries {
|
||||||
if cidr.contains(&remote_ip) {
|
if cidr.contains(&remote_ip) {
|
||||||
let balancer = &mut routing_table.balancers[*balancer_idx];
|
let balancer = &mut routing_table.balancers[*balancer_idx];
|
||||||
chosen_backend = balancer.choose_backend();
|
chosen_backend = balancer.choose_backend(ConnectionInfo {
|
||||||
|
client_ip: remote_ip,
|
||||||
|
});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -76,11 +71,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
} else {
|
} else {
|
||||||
println!("error: no matching rule for {} on port {}", remote_ip, port);
|
println!("error: no matching rule for {} on port {}", remote_ip, port);
|
||||||
}
|
}
|
||||||
|
|
||||||
// clear the slot after use to avoid stale data
|
|
||||||
CURRENT_CONNECTION_INFO.with(|info| {
|
|
||||||
*info.borrow_mut() = None;
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
|
use crate::backend::Backend;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use crate::backend::Backend;
|
|
||||||
|
|
||||||
pub mod tcp;
|
pub mod tcp;
|
||||||
|
|
||||||
@@ -32,7 +32,8 @@ impl Drop for ConnectionContext {
|
|||||||
self.backend.dec_connections();
|
self.backend.dec_connections();
|
||||||
let duration = self.start_time.elapsed();
|
let duration = self.start_time.elapsed();
|
||||||
|
|
||||||
println!("info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
|
println!(
|
||||||
|
"info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
|
||||||
self.id,
|
self.id,
|
||||||
self.client_addr,
|
self.client_addr,
|
||||||
self.backend.address,
|
self.backend.address,
|
||||||
|
|||||||
@@ -1,24 +1,28 @@
|
|||||||
|
use crate::backend::Backend;
|
||||||
|
use crate::proxy::ConnectionContext;
|
||||||
|
use anywho::Error;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io;
|
use tokio::io;
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use anywho::Error;
|
|
||||||
use crate::backend::Backend;
|
|
||||||
use crate::proxy::ConnectionContext;
|
|
||||||
|
|
||||||
pub async fn proxy_tcp_connection(connection_id: u64, mut client_stream: TcpStream, backend: Arc<Backend>) -> Result<(), Error> {
|
pub async fn proxy_tcp_connection(
|
||||||
|
connection_id: u64,
|
||||||
|
mut client_stream: TcpStream,
|
||||||
|
backend: Arc<Backend>,
|
||||||
|
) -> Result<(), Error> {
|
||||||
let client_addr = client_stream.peer_addr()?;
|
let client_addr = client_stream.peer_addr()?;
|
||||||
|
|
||||||
let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
|
let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
println!("info: conn_id={} connecting to {}", connection_id, ctx.backend.id);
|
println!(
|
||||||
|
"info: conn_id={} connecting to {}",
|
||||||
|
connection_id, ctx.backend.id
|
||||||
|
);
|
||||||
|
|
||||||
let mut backend_stream = TcpStream::connect(&backend.address).await?;
|
let mut backend_stream = TcpStream::connect(&backend.address).await?;
|
||||||
|
|
||||||
let (tx, rx) = io::copy_bidirectional(
|
let (tx, rx) = io::copy_bidirectional(&mut client_stream, &mut backend_stream).await?;
|
||||||
&mut client_stream,
|
|
||||||
&mut backend_stream,
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
ctx.bytes_transferred = tx + rx;
|
ctx.bytes_transferred = tx + rx;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user