//! A module providing IRC connections for use by `IrcServer`s. use futures_util::{sink::Sink, stream::Stream}; use pin_project::pin_project; use std::{ fmt, pin::Pin, task::{Context, Poll}, }; use tokio::net::TcpStream; use tokio::sync::mpsc::UnboundedSender; use tokio_util::codec::Framed; #[cfg(feature = "proxy")] use tokio_socks::tcp::Socks5Stream; #[cfg(feature = "proxy")] use crate::client::data::ProxyType; #[cfg(feature = "tls-native")] use std::{fs::File, io::Read}; #[cfg(feature = "tls-native")] use native_tls::{Certificate, Identity, TlsConnector}; #[cfg(feature = "tls-native")] use tokio_native_tls::{self, TlsStream}; #[cfg(feature = "tls-rust")] use std::{ fs::File, io::{BufReader, Error, ErrorKind}, sync::Arc, }; #[cfg(feature = "tls-rust")] use webpki_roots::TLS_SERVER_ROOTS; #[cfg(feature = "tls-rust")] use tokio_rustls::{ client::TlsStream, rustls::{internal::pemfile::certs, ClientConfig, PrivateKey}, webpki::DNSNameRef, TlsConnector, }; use crate::{ client::{ data::Config, mock::MockStream, transport::{LogView, Logged, Transport}, }, error, proto::{IrcCodec, Message}, }; /// An IRC connection used internally by `IrcServer`. #[pin_project(project = ConnectionProj)] pub enum Connection { #[doc(hidden)] Unsecured(#[pin] Transport), #[doc(hidden)] #[cfg(any(feature = "tls-native", feature = "tls-rust"))] Secured(#[pin] Transport>), #[doc(hidden)] Mock(#[pin] Logged), } impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!( f, "{}", match *self { Connection::Unsecured(_) => "Connection::Unsecured(...)", #[cfg(any(feature = "tls-native", feature = "tls-rust"))] Connection::Secured(_) => "Connection::Secured(...)", Connection::Mock(_) => "Connection::Mock(...)", } ) } } impl Connection { /// Creates a new `Connection` using the specified `Config` pub(crate) async fn new( config: &Config, tx: UnboundedSender, ) -> error::Result { if config.use_mock_connection() { log::info!("Connecting via mock to {}.", config.server()?); return Ok(Connection::Mock(Logged::wrap( Self::new_mocked_transport(config, tx).await?, ))); } #[cfg(any(feature = "tls-native", feature = "tls-rust"))] { if config.use_tls() { log::info!("Connecting via TLS to {}.", config.server()?); return Ok(Connection::Secured( Self::new_secured_transport(config, tx).await?, )); } } log::info!("Connecting to {}.", config.server()?); Ok(Connection::Unsecured( Self::new_unsecured_transport(config, tx).await?, )) } #[cfg(not(feature = "proxy"))] async fn new_stream(config: &Config) -> error::Result { Ok(TcpStream::connect((config.server()?, config.port())).await?) } #[cfg(feature = "proxy")] async fn new_stream(config: &Config) -> error::Result { let server = config.server()?; let port = config.port(); let address = (server, port); match config.proxy_type() { ProxyType::None => Ok(TcpStream::connect(address).await?), ProxyType::Socks5 => { let proxy_server = config.proxy_server(); let proxy_port = config.proxy_port(); let proxy = (proxy_server, proxy_port); log::info!("Setup proxy {:?}.", proxy); let proxy_username = config.proxy_username(); let proxy_password = config.proxy_password(); if !proxy_username.is_empty() || !proxy_password.is_empty() { return Ok(Socks5Stream::connect_with_password( proxy, address, proxy_username, proxy_password, ) .await? .into_inner()); } Ok(Socks5Stream::connect(proxy, address).await?.into_inner()) } } } async fn new_unsecured_transport( config: &Config, tx: UnboundedSender, ) -> error::Result> { let stream = Self::new_stream(config).await?; let framed = Framed::new(stream, IrcCodec::new(config.encoding())?); Ok(Transport::new(&config, framed, tx)) } #[cfg(feature = "tls-native")] async fn new_secured_transport( config: &Config, tx: UnboundedSender, ) -> error::Result>> { let mut builder = TlsConnector::builder(); if let Some(cert_path) = config.cert_path() { if let Ok(mut file) = File::open(cert_path) { let mut cert_data = vec![]; file.read_to_end(&mut cert_data)?; let cert = Certificate::from_der(&cert_data)?; builder.add_root_certificate(cert); log::info!("Added {} to trusted certificates.", cert_path); } else { return Err(error::Error::InvalidConfig { path: config.path(), cause: error::ConfigError::FileMissing { file: cert_path.to_string(), }, }); } } if let Some(client_cert_path) = config.client_cert_path() { if let Ok(mut file) = File::open(client_cert_path) { let mut client_cert_data = vec![]; file.read_to_end(&mut client_cert_data)?; let client_cert_pass = config.client_cert_pass(); let pkcs12_archive = Identity::from_pkcs12(&client_cert_data, &client_cert_pass)?; builder.identity(pkcs12_archive); log::info!( "Using {} for client certificate authentication.", client_cert_path ); } else { return Err(error::Error::InvalidConfig { path: config.path(), cause: error::ConfigError::FileMissing { file: client_cert_path.to_string(), }, }); } } let connector: tokio_native_tls::TlsConnector = builder.build()?.into(); let domain = config.server()?; let stream = Self::new_stream(config).await?; let stream = connector.connect(domain, stream).await?; let framed = Framed::new(stream, IrcCodec::new(config.encoding())?); Ok(Transport::new(&config, framed, tx)) } #[cfg(feature = "tls-rust")] async fn new_secured_transport( config: &Config, tx: UnboundedSender, ) -> error::Result>> { let mut builder = ClientConfig::default(); builder .root_store .add_server_trust_anchors(&TLS_SERVER_ROOTS); if let Some(cert_path) = config.cert_path() { if let Ok(mut file) = File::open(cert_path) { let mut cert_data = BufReader::new(file); builder .root_store .add_pem_file(&mut cert_data) .map_err(|_| { error::Error::Io(Error::new(ErrorKind::InvalidInput, "invalid cert")) })?; log::info!("Added {} to trusted certificates.", cert_path); } else { return Err(error::Error::InvalidConfig { path: config.path(), cause: error::ConfigError::FileMissing { file: cert_path.to_string(), }, }); } } if let Some(client_cert_path) = config.client_cert_path() { if let Ok(mut file) = File::open(client_cert_path) { let client_cert_data = certs(&mut BufReader::new(file)).map_err(|_| { error::Error::Io(Error::new(ErrorKind::InvalidInput, "invalid cert")) })?; let client_cert_pass = PrivateKey(Vec::from(config.client_cert_pass())); builder .set_single_client_cert(client_cert_data, client_cert_pass) .map_err(|err| error::Error::Io(Error::new(ErrorKind::InvalidInput, err)))?; log::info!( "Using {} for client certificate authentication.", client_cert_path ); } else { return Err(error::Error::InvalidConfig { path: config.path(), cause: error::ConfigError::FileMissing { file: client_cert_path.to_string(), }, }); } } let connector = TlsConnector::from(Arc::new(builder)); let domain = DNSNameRef::try_from_ascii_str(config.server()?)?; let stream = Self::new_stream(config).await?; let stream = connector.connect(domain, stream).await?; let framed = Framed::new(stream, IrcCodec::new(config.encoding())?); Ok(Transport::new(&config, framed, tx)) } async fn new_mocked_transport( config: &Config, tx: UnboundedSender, ) -> error::Result> { use encoding::{label::encoding_from_whatwg_label, EncoderTrap}; let encoding = encoding_from_whatwg_label(config.encoding()).ok_or_else(|| { error::Error::UnknownCodec { codec: config.encoding().to_owned(), } })?; let init_str = config.mock_initial_value(); let initial = encoding .encode(init_str, EncoderTrap::Replace) .map_err(|data| error::Error::CodecFailed { codec: encoding.name(), data: data.into_owned(), })?; let stream = MockStream::new(&initial); let framed = Framed::new(stream, IrcCodec::new(config.encoding())?); Ok(Transport::new(&config, framed, tx)) } /// Gets a view of the internal logging if and only if this connection is using a mock stream. /// Otherwise, this will always return `None`. This is used for unit testing. pub fn log_view(&self) -> Option { match *self { Connection::Mock(ref inner) => Some(inner.view()), _ => None, } } } impl Stream for Connection { type Item = error::Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { ConnectionProj::Unsecured(inner) => inner.poll_next(cx), #[cfg(any(feature = "tls-native", feature = "tls-rust"))] ConnectionProj::Secured(inner) => inner.poll_next(cx), ConnectionProj::Mock(inner) => inner.poll_next(cx), } } } impl Sink for Connection { type Error = error::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { ConnectionProj::Unsecured(inner) => inner.poll_ready(cx), #[cfg(any(feature = "tls-native", feature = "tls-rust"))] ConnectionProj::Secured(inner) => inner.poll_ready(cx), ConnectionProj::Mock(inner) => inner.poll_ready(cx), } } fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { match self.project() { ConnectionProj::Unsecured(inner) => inner.start_send(item), #[cfg(any(feature = "tls-native", feature = "tls-rust"))] ConnectionProj::Secured(inner) => inner.start_send(item), ConnectionProj::Mock(inner) => inner.start_send(item), } } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { ConnectionProj::Unsecured(inner) => inner.poll_flush(cx), #[cfg(any(feature = "tls-native", feature = "tls-rust"))] ConnectionProj::Secured(inner) => inner.poll_flush(cx), ConnectionProj::Mock(inner) => inner.poll_flush(cx), } } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.project() { ConnectionProj::Unsecured(inner) => inner.poll_close(cx), #[cfg(any(feature = "tls-native", feature = "tls-rust"))] ConnectionProj::Secured(inner) => inner.poll_close(cx), ConnectionProj::Mock(inner) => inner.poll_close(cx), } } }