1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
use std::{
    future::Future,
    net::{self, SocketAddr},
    path::Path,
};

use anyhow::{anyhow, Result};
use hyper::{
    header::LOCATION, server::conn::Http, service::service_fn, Body, Method, Request, Response,
    StatusCode,
};
use openssl::ssl::{Ssl, SslAcceptor, SslAcceptorBuilder, SslFiletype, SslMethod, SslVerifyMode};

use sep2_common::examples::{
    DC_16_04_11, EDL_16_02_08, ED_16_01_08, ED_16_03_06, ER_16_04_06, FSAL_16_03_11, REG_16_01_10,
};
use tokio::net::TcpListener;
use tokio_openssl::SslStream;

/// TLS settings for the test server; built into an [`SslAcceptor`] when the server starts.
type TlsServerConfig = SslAcceptorBuilder;

/// Builds the server-side TLS configuration used by [`TestServer`].
///
/// Starting from OpenSSL's Mozilla "intermediate" server profile, this:
/// - restricts the cipher list to `ECDHE-ECDSA-AES128-CCM8`,
/// - loads the server certificate (`cert_path`) and private key (`pk_path`), both PEM,
/// - trusts the certificates in `rootca_path` when verifying peers,
/// - requires every client to present a certificate (mutual TLS).
///
/// NOTE(review): `set_cipher_list` only governs protocols before TLS 1.3; TLS 1.3 suites are
/// configured separately (`set_ciphersuites`) and are not restricted here — confirm whether
/// TLS 1.3 should be disabled for this server.
///
/// # Errors
///
/// Returns an error if OpenSSL cannot create the acceptor builder, rejects the cipher list,
/// or fails to load the certificate, private key, or CA file.
fn create_server_tls_config(
    cert_path: impl AsRef<Path>,
    pk_path: impl AsRef<Path>,
    rootca_path: impl AsRef<Path>,
) -> Result<TlsServerConfig> {
    // Propagate instead of panicking: this function already reports failures via `Result`.
    let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls_server())?;
    log::debug!("Setting CipherSuite");
    builder.set_cipher_list("ECDHE-ECDSA-AES128-CCM8")?;
    log::debug!("Loading Certificate File");
    builder.set_certificate_file(cert_path, SslFiletype::PEM)?;
    log::debug!("Loading Private Key File");
    builder.set_private_key_file(pk_path, SslFiletype::PEM)?;
    log::debug!("Loading Certificate Authority File");
    builder.set_ca_file(rootca_path)?;
    log::debug!("Setting verification mode");
    // Reject handshakes from clients that do not present a certificate.
    builder.set_verify(SslVerifyMode::FAIL_IF_NO_PEER_CERT | SslVerifyMode::PEER);
    Ok(builder)
}

pub struct TestServer {
    addr: SocketAddr,
    cfg: TlsServerConfig,
}

impl TestServer {
    pub fn new(
        addr: impl net::ToSocketAddrs,
        cert_path: impl AsRef<Path>,
        pk_path: impl AsRef<Path>,
        rootca_path: impl AsRef<Path>,
    ) -> Result<Self> {
        let cfg = create_server_tls_config(cert_path, pk_path, rootca_path)?;
        Ok(TestServer {
            addr: addr
                .to_socket_addrs()?
                .next()
                .ok_or(anyhow!("Given server address did not yield a SocketAddr"))?,
            cfg,
        })
    }

    pub async fn run(self, shutdown: impl Future) -> Result<()> {
        tokio::pin!(shutdown);
        let acceptor = self.cfg.build();
        let listener = TcpListener::bind(self.addr).await?;
        let mut set = tokio::task::JoinSet::new();
        log::info!("TestServer: Listening on {}", self.addr);
        loop {
            // Accept TCP Connection
            let (stream, addr) = tokio::select! {
                _ = &mut shutdown => break,
                res = listener.accept() => match res {
                    Ok((s,a)) => (s,a),
                    Err(err) => {
                        log::error!("TestServer: Failed to accept connection: {err}");
                        continue;
                    }
                }
            };
            log::debug!("TestServer: Remote connecting from {}", addr);

            // Perform TLS handshake
            let ssl = Ssl::new(acceptor.context())?;
            let stream = SslStream::new(ssl, stream)?;
            let mut stream = Box::pin(stream);
            if let Err(e) = stream.as_mut().accept().await {
                log::error!("TestServer: Failed to perform TLS handshake: {e}");
                continue;
            }

            // Bind connection to service
            let service = service_fn(move |req| async move { router(req).await });
            set.spawn(async move {
                if let Err(err) = Http::new().serve_connection(stream, service).await {
                    log::error!("TestServer: Failed to handle connection: {err}");
                }
            });
        }
        // Wait for all connection handlers to finish
        log::debug!("TestServer: Attempting graceful shutdown");
        set.shutdown().await;
        log::info!("TestServer: Server has been shutdown.");
        Ok(())
    }
}

async fn router(req: Request<Body>) -> Result<Response<Body>> {
    log::info!("Incoming Request: {:?}", req);
    let mut response = Response::new(Body::empty());
    match (req.method(), req.uri().path()) {
        (&Method::GET, "/dcap") => {
            *response.body_mut() = Body::from(DC_16_04_11);
        }
        (&Method::GET, "/edev") => {
            *response.body_mut() = Body::from(EDL_16_02_08);
        }
        (&Method::POST, "/edev") => {
            *response.status_mut() = StatusCode::CREATED;
            response
                .headers_mut()
                .insert(LOCATION, "/edev/4".parse().unwrap());
        }
        (&Method::GET, "/edev/3") => {
            *response.body_mut() = Body::from(ED_16_01_08);
        }
        (&Method::PUT, "/edev/3") => {
            *response.status_mut() = StatusCode::NO_CONTENT;
        }
        (&Method::DELETE, "/edev/3") => {
            *response.status_mut() = StatusCode::NO_CONTENT;
        }
        (&Method::GET, "/edev/4/fsal") => {
            *response.body_mut() = Body::from(FSAL_16_03_11);
        }
        (&Method::GET, "/edev/4") => {
            *response.body_mut() = Body::from(ED_16_03_06);
        }
        (&Method::GET, "/edev/5") => {
            *response.body_mut() = Body::from(ER_16_04_06);
        }
        (&Method::GET, "/edev/3/reg") => {
            *response.body_mut() = Body::from(REG_16_01_10);
        }
        (&Method::POST, "/rsp") => {
            *response.status_mut() = StatusCode::CREATED;
            // Location header is unset in examples, but is technically always required by spec?
            // Client will handle missing location header regardless.
        }
        _ => {
            *response.status_mut() = StatusCode::NOT_FOUND;
        }
    };
    log::info!("Outgoing Response: {:?}", response);
    Ok(response)
}