Implemented a better routing system and config parsing from YAML.

psun256
2025-12-10 01:49:45 -05:00
parent 9046a85d84
commit 8170d2a6bf
7 changed files with 498 additions and 45 deletions


@@ -1,6 +0,0 @@
// TODO: "routing" rules
// backends defined as ip + port
// define sets of backends
// the only allowed set operation for now is union
// rules are ip + mask and ports, maps to some of the sets
// defined earlier, along with a routing strategy

src/config/loader.rs (new file, +124)

@@ -0,0 +1,124 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::{Arc, RwLock};
use cidr::IpCidr;
use crate::backend::*;
use crate::balancer::Balancer;
use crate::balancer::round_robin::RoundRobinBalancer;
use crate::balancer::adaptive_weight::AdaptiveWeightBalancer;
use crate::config::*;
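
// Routing table for one listening port. `entries` pairs a client CIDR with an
// index into `balancers`; entries are checked in order, first match wins.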
pub struct RoutingTable {
    pub balancers: Vec<Box<dyn Balancer + Send>>,
    pub entries: Vec<(IpCidr, usize)>,
}

fn parse_client(s: &str) -> (IpCidr, u16) {
    // just splits "0.0.0.0/0:80" into ("0.0.0.0/0", 80)
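    // rsplit_once is used so an IPv6 client like "::1/128:80" still splits at the last colon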
    let (ip_part, port_part) = s.rsplit_once(':').expect("badly formatted client");
    let port: u16 = port_part.parse().expect("bad port");
    let cidr: IpCidr = ip_part.parse().expect("bad ip/mask");
    (cidr, port)
}

pub type PortListeners = HashMap<u16, RoutingTable>;

pub fn build_lb(config: AppConfig) -> PortListeners {
    let mut healths: HashMap<IpAddr, Arc<RwLock<ServerHealth>>> = HashMap::new();
    let mut backends: HashMap<String, Arc<Backend>> = HashMap::new();
    for backend_cfg in config.backends {
        let ip: IpAddr = backend_cfg.ip.parse().unwrap();
        let addr = SocketAddr::new(ip, backend_cfg.port);
        let health = healths
            .entry(ip)
            .or_insert_with(|| Arc::new(RwLock::new(ServerHealth::default())))
            .clone();
        let backend = Arc::new(Backend::new(
            backend_cfg.id.clone(),
            addr,
            health,
        ));
        backends.insert(backend_cfg.id, backend);
    }

    let mut listeners: PortListeners = HashMap::new();
    for rule in config.rules {
        let mut target_backends = Vec::new();
        for target_name in &rule.targets {
            if let Some(members) = config.clusters.get(target_name) {
                for member_id in members {
                    if let Some(backend) = backends.get(member_id) {
                        target_backends.push(backend.clone());
                    }
                }
            } else if let Some(backend) = backends.get(target_name) {
                target_backends.push(backend.clone());
            } else {
                eprintln!("warning: target {} not found", target_name);
            }
        }
        // possible for multiple targets of the same rule to have common backends.
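        // dedup_by only removes consecutive duplicates, so sort by id first.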
        target_backends.sort_by(|a, b| a.id.cmp(&b.id));
        target_backends.dedup_by(|a, b| a.id == b.id);
        if target_backends.is_empty() {
            eprintln!("warning: rule has no valid targets, skipping.");
            continue;
        }

        // for each distinct client port on this rule we unfortunately need to make a new
        // Balancer, since Balancer is not thread safe and requires &mut self for backend
        // selection.
        // making a new one per port is a good enough compromise: it avoids a Mutex, at the
        // cost of a minor penalty to load balancing "quality" when a rule spans several
        // client ports.
        let mut port_groups: HashMap<u16, Vec<IpCidr>> = HashMap::new();
        for client_def in rule.clients {
            let (cidr, port) = parse_client(&client_def);
            port_groups.entry(port).or_default().push(cidr);
        }
        for (port, cidrs) in port_groups {
            let table = listeners.entry(port).or_insert_with(|| RoutingTable {
                balancers: Vec::new(),
                entries: Vec::new(),
            });
            let pool = BackendPool::new(target_backends.clone());
            let balancer: Box<dyn Balancer + Send> = match &rule.strategy {
                LoadBalancerStrategy::RoundRobin => {
                    Box::new(RoundRobinBalancer::new(pool))
                },
                LoadBalancerStrategy::Adaptive { coefficients, alpha } => {
                    Box::new(AdaptiveWeightBalancer::new(pool, *coefficients, *alpha))
                }
            };
            let balancer_idx = table.balancers.len();
            table.balancers.push(balancer);
            for cidr in cidrs {
                table.entries.push((cidr, balancer_idx));
            }
        }
    }

    // sort descending by prefix length so the most specific CIDR comes first,
    // making first match == longest prefix match
    for table in listeners.values_mut() {
        table.entries.sort_by(|(a, _), (b, _)| b.network_length().cmp(&a.network_length()));
    }
    listeners
}

src/config/mod.rs (new file, +49)

@@ -0,0 +1,49 @@
// config is written as a YAML file, the path will be passed to the program.
//
// the high level structure of the config is that we
// first define the individual backends (ip + port) we are going
// to load balance around.
//
// next we define some clusters, which are really more like a short
// alias for a group of backends.
//
// next we define the rules. each rule lists the clients as
// "ip/subnet:port" strings, the clusters (or individual backends)
// these clients are balanced across, and of course the balancing
// strategy to use.
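//
// for example, a config might look like this (illustrative values only):
//
// backends:
//   - id: web1
//     ip: 127.0.0.1
//     port: 8081
//   - id: web2
//     ip: 127.0.0.1
//     port: 8082
// clusters:
//   web: [web1, web2]
// rules:
//   - clients: ["0.0.0.0/0:80"]
//     targets: [web]
//     strategy:
//       type: RoundRobin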

pub mod loader;

use std::collections::HashMap;
use serde::Deserialize;

#[derive(Debug, Deserialize)]
pub struct AppConfig {
    pub backends: Vec<BackendConfig>,
    #[serde(default)]
    pub clusters: HashMap<String, Vec<String>>,
    pub rules: Vec<RuleConfig>,
}

#[derive(Debug, Deserialize)]
pub struct BackendConfig {
    pub id: String,
    pub ip: String,
    pub port: u16,
}

#[derive(Debug, Deserialize)]
pub struct RuleConfig {
    pub clients: Vec<String>,
    pub targets: Vec<String>,
    pub strategy: LoadBalancerStrategy,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
pub enum LoadBalancerStrategy {
    RoundRobin,
    Adaptive {
        coefficients: [f64; 4],
        alpha: f64,
    },
}

src/main.rs

@@ -1,55 +1,79 @@
 extern crate core;

 mod balancer;
 mod config;
 mod backend;
 mod proxy;

-use tokio::net::TcpListener;
-use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use std::sync::{Arc, RwLock};
+use std::fs::File;
 use std::sync::atomic::{AtomicU64, Ordering};
 use crate::backend::{Backend, BackendPool, ServerHealth};
 use crate::balancer::Balancer;
 use crate::balancer::round_robin::RoundRobinBalancer;
+use tokio::net::TcpListener;
 use crate::proxy::tcp::proxy_tcp_connection;

 static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);

 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let mut pool: Vec<Arc<Backend>> = Vec::new();
-    let server_metric = Arc::new(RwLock::new(ServerHealth::default()));
-    pool.push(Arc::new(Backend::new(
-        "backend 1".into(),
-        "127.0.0.1:8081".parse().unwrap(),
-        server_metric.clone()
-    )));
-    pool.push(Arc::new(Backend::new(
-        "backend 2".into(),
-        "127.0.0.1:8082".parse().unwrap(),
-        server_metric.clone()
-    )));
-    let mut balancer = RoundRobinBalancer::new(BackendPool::new(pool));
-    let listener = TcpListener::bind("127.0.0.1:8080").await?;
-    loop {
-        let (socket, _) = listener.accept().await?;
-        let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
-        if let Some(backend) = balancer.choose_backend() {
-            tokio::spawn(async move {
-                if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await {
-                    eprintln!("error: conn_id={} proxy failed: {}", conn_id, e);
-                }
-            });
-        } else {
-            eprintln!("error: no backends for conn_id={}", conn_id);
-        }
-    }
+    let f = File::open("config.yaml").expect("couldn't open config.yaml");
+    let app_config: config::AppConfig = serde_saphyr::from_reader(f)?;
+    println!("Loaded {} backends, {} rules.",
+        app_config.backends.len(),
+        app_config.rules.len()
+    );
+    let listeners = config::loader::build_lb(app_config);
+    if listeners.is_empty() {
+        eprintln!("it's a lawless land");
+        return Ok(());
+    }
+    // one listener task per port; each task owns its RoutingTable, so the
+    // balancers inside need no locking.
+    let mut handles = Vec::new();
+    for (port, mut routing_table) in listeners {
+        handles.push(tokio::spawn(async move {
+            let addr = format!("0.0.0.0:{}", port);
+            println!("Starting tcp listener on {}", addr);
+            let listener = TcpListener::bind(&addr).await.expect("Failed to bind port");
+            loop {
+                let (socket, remote_addr) = match listener.accept().await {
+                    Ok(v) => v,
+                    Err(e) => {
+                        eprintln!("error: listener port {}: {}", port, e);
+                        continue;
+                    }
+                };
+                let remote_ip = remote_addr.ip();
+                let conn_id = NEXT_CONN_ID.fetch_add(1, Ordering::Relaxed);
+                // entries are sorted most-specific-first, so the first matching
+                // CIDR is the longest prefix match.
+                let mut chosen_backend = None;
+                for (cidr, balancer_idx) in &mut routing_table.entries {
+                    if cidr.contains(&remote_ip) {
+                        let balancer = &mut routing_table.balancers[*balancer_idx];
+                        chosen_backend = balancer.choose_backend();
+                        break;
+                    }
+                }
+                if let Some(backend) = chosen_backend {
+                    tokio::spawn(async move {
+                        if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await {
+                            eprintln!("error: conn_id={} proxy failed: {}", conn_id, e);
+                        }
+                    });
+                } else {
+                    eprintln!("error: no matching rule for {} on port {}", remote_ip, port);
+                }
+            }
+        }));
+    }
+    for h in handles {
+        let _ = h.await;
+    }
+    Ok(())
 }