use anyhow::{Context, Result};
use hyper::{server::conn::Http, service::service_fn, Body, Method, Request, Response};
use openssl::ssl::Ssl;
use sep2_common::{deserialize, packages::pubsub::Notification, traits::SEResource};
use std::collections::HashMap;
use std::net;
use std::path::Path;
use std::{future::Future, net::SocketAddr, pin::Pin, sync::Arc};
use tokio::net::TcpListener;
use tokio_openssl::SslStream;
use crate::client::SEPResponse;
use crate::tls::{create_server_tls_config, TlsServerConfig};
/// A handler invoked with a deserialized [`Notification`] posted to a route.
///
/// Implementors must be `Send + Sync + 'static` so a single instance can be
/// stored in the server's routing table and shared between connection tasks.
/// Any `Fn(Notification<T>) -> impl Future<Output = SEPResponse>` closure or
/// function with matching bounds implements this trait automatically.
pub trait RouteCallback<T>: Send + Sync + 'static
where
    T: SEResource,
{
    /// Handles one notification.
    ///
    /// Returns a boxed, `Send` future that resolves to the [`SEPResponse`]
    /// to be sent back for the request.
    fn callback(
        &self,
        notif: Notification<T>,
    ) -> Pin<Box<dyn Future<Output = SEPResponse> + Send + 'static>>;
}
/// Blanket implementation so that plain async closures and functions can be
/// used directly as route callbacks.
impl<F, R, T> RouteCallback<T> for F
where
    T: SEResource,
    F: Fn(Notification<T>) -> R + Send + Sync + 'static,
    R: Future<Output = SEPResponse> + Send + 'static,
{
    /// Calls the wrapped function with `notif` and boxes the future it returns.
    fn callback(
        &self,
        notif: Notification<T>,
    ) -> Pin<Box<dyn Future<Output = SEPResponse> + Send + 'static>> {
        let pending = self(notif);
        Box::pin(pending)
    }
}
/// Type-erased handler stored per route.
///
/// It is called with the request body as a `&str` and returns a boxed future
/// that resolves to the [`SEPResponse`] for that request. Erasing the concrete
/// resource type lets handlers for different `SEResource` types share one map.
type RouteHandler = Box<
dyn Fn(&str) -> Pin<Box<dyn Future<Output = SEPResponse> + Send + 'static>>
+ Send
+ Sync
+ 'static,
>;
/// Routing table mapping request paths to their notification handlers.
struct Router {
// Keyed by the exact request path; looked up once per incoming request.
routes: HashMap<String, RouteHandler, ahash::RandomState>,
}
impl Router {
fn new() -> Self {
Router {
routes: HashMap::default(),
}
}
async fn router(&self, req: Request<Body>) -> Result<Response<Body>> {
let path = req.uri().path().to_owned();
match self.routes.get(&path) {
Some(func) => {
let method = req.method();
match method {
&Method::POST => {
let body = req.into_body();
let bytes = hyper::body::to_bytes(body).await?;
let xml = String::from_utf8(bytes.to_vec())?;
Ok(hyper::Response::try_from(func(&xml).await)?)
}
_ => {
hyper::Response::try_from(SEPResponse::MethodNotAllowed("POST".to_owned()))
}
}
}
None => hyper::Response::try_from(SEPResponse::NotFound),
}
}
}
/// A lightweight HTTP(S) server that receives notifications and dispatches
/// them to callbacks registered per path.
///
/// Build one with [`ClientNotifServer::new`], optionally enable TLS with
/// `with_https`, register routes with `add`, then drive it with `run`.
pub struct ClientNotifServer {
// Address the listener binds to when `run` is called.
addr: SocketAddr,
// TLS configuration; `None` means connections are served over plain HTTP.
cfg: Option<TlsServerConfig>,
// Path -> handler table populated by `add`.
router: Router,
}
impl ClientNotifServer {
pub fn new(addr: impl net::ToSocketAddrs) -> Result<Self> {
Ok(ClientNotifServer {
addr: addr
.to_socket_addrs()?
.next()
.context("Given server address did not yield a SocketAddr")?,
cfg: None,
router: Router::new(),
})
}
pub fn with_https(
mut self,
cert_path: impl AsRef<Path>,
pk_path: impl AsRef<Path>,
rootca_path: impl AsRef<Path>,
) -> Result<Self> {
self.cfg = Some(create_server_tls_config(cert_path, pk_path, rootca_path)?);
Ok(self)
}
pub fn add<T>(mut self, path: impl Into<String>, callback: impl RouteCallback<T>) -> Self
where
T: SEResource,
{
let path = path.into();
let new: RouteHandler = Box::new({
let log_path = path.clone();
move |e| {
let e = deserialize::<Notification<T>>(e);
match e {
Ok(resource) => {
log::debug!(
"NotifServer: Successfully deserialized a resource on {log_path}"
);
Box::pin(callback.callback(resource))
}
Err(err) => {
log::error!(
"NotifServer: Failed to deserialize resource on {log_path}: {err}"
);
Box::pin(async { SEPResponse::BadRequest(None) })
}
}
}
});
self.router.routes.insert(path, new);
self
}
pub async fn run(self, shutdown: impl Future) -> Result<()> {
tokio::pin!(shutdown);
let acceptor = self.cfg.map(|cfg| cfg.build());
let router = Arc::new(self.router);
let listener = TcpListener::bind(self.addr).await?;
let mut set = tokio::task::JoinSet::new();
log::info!("NotifServer: Listening on {}", self.addr);
loop {
let (stream, addr) = tokio::select! {
_ = &mut shutdown => break,
res = listener.accept() => match res {
Ok((s,a)) => (s,a),
Err(err) => {
log::error!("NotifServer: Failed to accept connection: {err}");
continue;
}
}
};
log::debug!("NotifServer: Remote connecting from {}", addr);
let service = service_fn({
let router = router.clone();
move |req| {
let router = router.clone();
async move { router.router(req).await }
}
});
if let Some(acceptor) = &acceptor {
let ssl = Ssl::new(acceptor.context())?;
let stream = SslStream::new(ssl, stream)?;
let mut stream = Box::pin(stream);
if let Err(e) = stream.as_mut().accept().await {
log::error!("NotifServer: Failed to perform TLS handshake: {e}");
continue;
}
set.spawn(async move {
if let Err(err) = Http::new().serve_connection(stream, service).await {
log::error!("NotifServer: Failed to handle HTTPS connection: {err}");
}
});
} else {
set.spawn(async move {
if let Err(err) = Http::new().serve_connection(stream, service).await {
log::error!("NotifServer: Failed to handle HTTP connection: {err}");
}
});
}
}
log::debug!("NotifServer: Attempting graceful shutdown");
set.shutdown().await;
log::info!("NotifServer: Server has been shutdown.");
Ok(())
}
}