From a3f50c1f0a7d153805b2c9cee2f19c55eae324dd Mon Sep 17 00:00:00 2001
From: psun256
Date: Mon, 8 Dec 2025 14:31:59 -0500
Subject: [PATCH] should be good to extend functionality now

---
 src/backend/mod.rs          | 89 ++++++++++++++++---------------------
 src/balancer/mod.rs         |  9 ++++
 src/balancer/round_robin.rs | 33 ++++++++++++++
 src/config.rs               |  6 +++
 src/main.rs                 | 53 ++++++++++++++--------
 src/proxy/mod.rs            | 40 ++++++++++++-----
 src/proxy/tcp.rs            | 26 +++++++++++
 src/proxy/tcp_proxy.rs      |  2 -
 8 files changed, 174 insertions(+), 84 deletions(-)
 create mode 100644 src/balancer/round_robin.rs
 create mode 100644 src/proxy/tcp.rs
 delete mode 100644 src/proxy/tcp_proxy.rs

diff --git a/src/backend/mod.rs b/src/backend/mod.rs
index 4bfd49c..9efcfc2 100644
--- a/src/backend/mod.rs
+++ b/src/backend/mod.rs
@@ -1,71 +1,58 @@
-use std::collections::HashMap;
+use core::fmt;
 use std::net::SocketAddr;
 use std::sync::RwLock;
 use std::sync::Arc;
-use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
-
-pub struct BackendPool {
-    pub backends: Arc<RwLock<HashMap<String, Arc<Backend>>>>,
-}
+use std::sync::atomic::{AtomicUsize, Ordering};
 
 #[derive(Debug)]
 pub struct Backend {
     pub id: String,
     pub address: SocketAddr,
-    pub is_healthy: AtomicBool, // no clue how this should work, for now
-    pub current_load: AtomicUsize, // no clue how this should work, for now
-}
-
-impl BackendPool {
-    pub fn new(initial_backends: Vec<Arc<Backend>>) -> Self {
-        let mut map = HashMap::new();
-        for backend in initial_backends {
-            map.insert(backend.id.clone(), backend);
-        }
-
-        Self {
-            backends: Arc::new(RwLock::new(map)),
-        }
-    }
-
-    pub fn add_backend(&self, backend: Arc<Backend>) {
-        let mut backends_guard = self.backends
-            .write()
-            .expect("BackendPool lock poisoned");
-        // let backends_guard = self.backends.read().unwrap_or_else(|poisoned| poisoned.into_inner());
-        backends_guard.insert(backend.id.clone(), backend);
-    }
-
-    pub fn get_backend(&self, id: &str) -> Option<Arc<Backend>> {
-        let backends_guard = self.backends
-            .read()
-            .expect("BackendPool lock poisoned");
-        // let backends_guard = self.backends.read().unwrap_or_else(|poisoned| poisoned.into_inner());
-        backends_guard.get(id).cloned()
-    }
-
-    pub fn bruh_amogus_sus(&self) {
-        for k in self.backends.read().unwrap().keys() {
-            self.backends.write().unwrap().get(k).unwrap().increment_current_load();
-        }
-    }
+    pub active_connections: AtomicUsize,
 }
 
 impl Backend {
     pub fn new(id: String, address: SocketAddr) -> Self {
         Self {
-            id: id,
-            address: address,
-            is_healthy: AtomicBool::new(false),
-            current_load: AtomicUsize::new(0),
+            id,
+            address,
+            active_connections: AtomicUsize::new(0),
        }
    }
 
-    pub fn increment_current_load(&self) {
-        self.current_load.fetch_add(1, Ordering::SeqCst);
+    // Ordering::Relaxed gives us atomic increments and decrements without
+    // imposing any ordering on surrounding memory operations. Since this
+    // counter is only a metric, that is all we need; SeqCst would be overkill.
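+    // (Relaxed still makes each fetch_add/fetch_sub atomic; it only gives up
+    // cross-variable ordering guarantees, which nothing here relies on.)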
+    pub fn inc_connections(&self) {
+        self.active_connections.fetch_add(1, Ordering::Relaxed);
+        println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed));
     }
 
-    pub fn decrement_current_load(&self) {
-        self.current_load.fetch_sub(1, Ordering::SeqCst);
+    pub fn dec_connections(&self) {
+        self.active_connections.fetch_sub(1, Ordering::Relaxed);
+        println!("{} has {} connections open", self.id, self.active_connections.load(Ordering::Relaxed));
+    }
+}
+
+impl fmt::Display for Backend {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "{} ({})", self.address, self.id)
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct BackendPool {
+    pub backends: Arc<RwLock<Vec<Arc<Backend>>>>,
+}
+
+impl BackendPool {
+    pub fn new() -> Self {
+        BackendPool {
+            backends: Arc::new(RwLock::new(Vec::new())),
+        }
+    }
+
+    pub fn add(&self, backend: Backend) {
+        self.backends.write().unwrap().push(Arc::new(backend));
+    }
 }
\ No newline at end of file
diff --git a/src/balancer/mod.rs b/src/balancer/mod.rs
index e69de29..555f6c2 100644
--- a/src/balancer/mod.rs
+++ b/src/balancer/mod.rs
@@ -0,0 +1,9 @@
+pub mod round_robin;
+
+use std::fmt::Debug;
+use std::sync::Arc;
+use crate::backend::Backend;
+
+pub trait Balancer: Debug + Send + Sync + 'static {
+    fn choose_backend(&mut self) -> Option<Arc<Backend>>;
+}
\ No newline at end of file
diff --git a/src/balancer/round_robin.rs b/src/balancer/round_robin.rs
new file mode 100644
index 0000000..3269a69
--- /dev/null
+++ b/src/balancer/round_robin.rs
@@ -0,0 +1,33 @@
+use std::sync::{Arc, RwLock};
+use std::fmt::Debug;
+use crate::backend::{Backend, BackendPool};
+use crate::balancer::Balancer;
+
+// Only the main connection-accepting thread should do the load balancing;
+// alternatively, each thread that handles load balancing should get its
+// own instance of the balancer.
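+//
+// Usage sketch (mirrors main.rs, assuming the pool already has backends):
+//     let mut rr = RoundRobinBalancer::new(pool.clone());
+//     if let Some(backend) = rr.choose_backend() { /* proxy to it */ }
+//
+// If the balancer ever has to be shared across tasks instead, the index
+// could become an AtomicUsize (fetch_add(1, Relaxed) % len) so that
+// choose_backend no longer needs &mut self.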
+#[derive(Debug)]
+pub struct RoundRobinBalancer {
+    pool: BackendPool,
+    index: usize,
+}
+
+impl RoundRobinBalancer {
+    pub fn new(pool: BackendPool) -> Self {
+        Self {
+            pool,
+            index: 0,
+        }
+    }
+}
+
+impl Balancer for RoundRobinBalancer {
+    fn choose_backend(&mut self) -> Option<Arc<Backend>> {
+        let backends = self.pool.backends.read().unwrap();
+        if backends.is_empty() { return None; }
+
+        let backend = backends[self.index % backends.len()].clone();
+        self.index = self.index.wrapping_add(1);
+        Some(backend)
+    }
+}
\ No newline at end of file
diff --git a/src/config.rs b/src/config.rs
index e69de29..6c59f64 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -0,0 +1,6 @@
+// TODO: "routing" rules
+//   backends are defined as ip + port
+//   sets of backends can be defined
+//   the only set operation allowed for now is union
+//   rules are ip + mask and ports, mapping to some of the sets
+//   defined earlier, along with a routing strategy
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index ca2015f..2b4a373 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,3 +1,5 @@
+extern crate core;
+
 mod balancer;
 mod config;
 mod backend;
@@ -6,32 +8,45 @@ mod proxy;
 use tokio::net::TcpListener;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use std::sync::Arc;
+use std::sync::atomic::AtomicU64;
+use crate::backend::{Backend, BackendPool};
+use crate::balancer::Balancer;
+use crate::balancer::round_robin::RoundRobinBalancer;
+use crate::proxy::tcp::proxy_tcp_connection;
+
+static NEXT_CONN_ID: AtomicU64 = AtomicU64::new(1);
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
-    let listener = TcpListener::bind("0.0.0.0:8080").await?;
+    let pool = BackendPool::new();
+
+    pool.add(Backend::new(
+        "backend 1".into(),
+        "127.0.0.1:8081".parse().unwrap(),
+    ));
+
+    pool.add(Backend::new(
+        "backend 2".into(),
+        "127.0.0.1:8082".parse().unwrap(),
+    ));
+
+    let mut balancer = RoundRobinBalancer::new(pool.clone());
+
+    let listener = TcpListener::bind("127.0.0.1:8080").await?;
 
     loop {
-        let (mut socket, _) = listener.accept().await?;
+        let (socket, _) = listener.accept().await?;
 
-        tokio::spawn(async move {
-            let mut buf = [0; 1024];
+        let conn_id = NEXT_CONN_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
 
-            loop {
-                let n = match socket.read(&mut buf).await {
-                    Ok(0) => return,
-                    Ok(n) => n,
-                    Err(e) => {
-                        eprintln!("failed to read from socket; err = {:?}", e);
-                        return;
-                    }
-                };
-
-                if let Err(e) = socket.write_all(&buf[0..n]).await {
-                    eprintln!("failed to write to socket; err = {:?}", e);
-                    return;
+        if let Some(backend) = balancer.choose_backend() {
+            tokio::spawn(async move {
+                if let Err(e) = proxy_tcp_connection(conn_id, socket, backend).await {
+                    eprintln!("error: conn_id={} proxy failed: {}", conn_id, e);
                 }
-            }
-        });
+            });
+        } else {
+            eprintln!("error: no backends available for conn_id={}", conn_id);
+        }
     }
 }
diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs
index 9da0885..7b8e6df 100644
--- a/src/proxy/mod.rs
+++ b/src/proxy/mod.rs
@@ -1,27 +1,43 @@
-mod tcp_proxy;
-
 use std::net::SocketAddr;
+use std::sync::Arc;
 use std::time::Instant;
+use crate::backend::Backend;
+
+pub mod tcp;
 
-// owned and accessed by only one thread.
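+// Per-connection bookkeeping, created once per proxied connection: building
+// it increments the chosen backend's connection count, and Drop decrements
+// the count and logs a one-line summary when the connection closes.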
 pub struct ConnectionContext {
-    pub connection_id: u64,
+    pub id: u64,
     pub client_addr: SocketAddr,
     pub start_time: Instant,
-    pub backend_addr: Option<SocketAddr>,
-    pub bytes_transferred: usize,
-    // pub protocol: String,
-    // pub sticky_id: Option<String>,
+    pub backend: Arc<Backend>,
+    pub bytes_transferred: u64,
 }
 
 impl ConnectionContext {
-    pub fn new(connection_id: u64, client_addr: SocketAddr) -> Self {
+    pub fn new(id: u64, client_addr: SocketAddr, backend: Arc<Backend>) -> Self {
+        backend.inc_connections();
+
         Self {
-            connection_id: connection_id,
-            client_addr: client_addr,
+            id,
+            client_addr,
             start_time: Instant::now(),
-            backend_addr: Default::default(),
+            backend,
             bytes_transferred: 0,
         }
     }
+}
+
+impl Drop for ConnectionContext {
+    fn drop(&mut self) {
+        self.backend.dec_connections();
+        let duration = self.start_time.elapsed();
+
+        println!("info: conn_id={} closed. client={} backend={} bytes={} duration={:.2?}",
+            self.id,
+            self.client_addr,
+            self.backend.address,
+            self.bytes_transferred,
+            duration.as_secs_f64()
+        );
+    }
 }
\ No newline at end of file
diff --git a/src/proxy/tcp.rs b/src/proxy/tcp.rs
new file mode 100644
index 0000000..c03bfb6
--- /dev/null
+++ b/src/proxy/tcp.rs
@@ -0,0 +1,26 @@
+use std::sync::Arc;
+use tokio::io;
+use tokio::net::TcpStream;
+use anywho::Error;
+use crate::backend::Backend;
+use crate::proxy::ConnectionContext;
+
+pub async fn proxy_tcp_connection(connection_id: u64, mut client_stream: TcpStream, backend: Arc<Backend>) -> Result<(), Error> {
+    let client_addr = client_stream.peer_addr()?;
+
+    let mut ctx = ConnectionContext::new(connection_id, client_addr, backend.clone());
+
+    #[cfg(debug_assertions)]
+    println!("info: conn_id={} connecting to {}", connection_id, ctx.backend.id);
+
+    let mut backend_stream = TcpStream::connect(&backend.address).await?;
+
+    let (tx, rx) = io::copy_bidirectional(
+        &mut client_stream,
+        &mut backend_stream,
+    ).await?;
+
+    ctx.bytes_transferred = tx + rx;
+
+    Ok(())
+}
\ No newline at end of file
diff --git a/src/proxy/tcp_proxy.rs b/src/proxy/tcp_proxy.rs
deleted file mode 100644
index c5d584a..0000000
--- a/src/proxy/tcp_proxy.rs
+++ /dev/null
@@ -1,2 +0,0 @@
-use super::*;
-