rust-irc/src/client/conn.rs
2017-06-19 14:42:27 -04:00

327 lines
10 KiB
Rust

//! Thread-safe connections on `IrcStreams`.
use std::io::prelude::*;
use std::io::{Cursor, Result};
use std::net::TcpStream;
use std::sync::Mutex;
use bufstream::BufStream;
use encoding::DecoderTrap;
use encoding::label::encoding_from_whatwg_label;
/// A connection that can be driven through a shared reference.
///
/// Every method takes `&self`; the implementations in this file keep their state behind a
/// `Mutex`. The `encoding` arguments name a character encoding, which the implementations in
/// this file resolve with `encoding_from_whatwg_label` (the tests use `"UTF-8"` and `"l9"`).
pub trait Connection {
/// Sends a message over this connection, encoded with `encoding`.
fn send(&self, msg: &str, encoding: &str) -> Result<()>;
/// Receives a single line from this connection, decoded with `encoding`.
fn recv(&self, encoding: &str) -> Result<String>;
/// Gets the full record of all sent messages, decoded with `encoding`, if the Connection
/// records this; implementations that keep no record return `None`.
/// This is intended for use in writing tests.
fn written(&self, encoding: &str) -> Option<String>;
/// Re-establishes this connection, disconnecting from the existing connection if necessary.
fn reconnect(&self) -> Result<()>;
}
/// Internal alias for a `NetStream` wrapped in the `bufstream` crate's `BufStream`.
/// This is the stream type stored inside `NetConnection`.
type NetBufStream = BufStream<NetStream>;
/// A thread-safe connection over a buffered `NetStream`.
pub struct NetConnection {
host: Mutex<String>,
port: Mutex<u16>,
stream: Mutex<NetBufStream>,
}
impl NetConnection {
fn new(host: &str, port: u16, stream: NetBufStream) -> NetConnection {
NetConnection {
host: Mutex::new(host.to_owned()),
port: Mutex::new(port),
stream: Mutex::new(stream),
}
}
/// Creates a thread-safe TCP connection to the specified server.
pub fn connect(host: &str, port: u16) -> Result<NetConnection> {
let stream = try!(NetConnection::connect_internal(host, port));
Ok(NetConnection::new(host, port, stream))
}
/// connects to the specified server and returns a reader-writer pair.
fn connect_internal(host: &str, port: u16) -> Result<NetBufStream> {
let socket = try!(TcpStream::connect((host, port)));
Ok(BufStream::new(NetStream::Unsecured(socket)))
}
/// Creates a thread-safe TCP connection to the specified server over SSL.
/// If the library is compiled without SSL support, this method panics.
pub fn connect_ssl(host: &str, port: u16) -> Result<NetConnection> {
let stream = try!(NetConnection::connect_ssl_internal(host, port));
Ok(NetConnection::new(host, port, stream))
}
/// Panics because SSL support is not compiled in.
fn connect_ssl_internal(host: &str, port: u16) -> Result<NetBufStream> {
panic!("Cannot connect to {}:{} over SSL without compiling with SSL support.", host, port)
}
}
impl Connection for NetConnection {
fn send(&self, msg: &str, encoding: &str) -> Result<()> {
imp::send(&self.stream, msg, encoding)
}
fn recv(&self, encoding: &str) -> Result<String> {
imp::recv(&self.stream, encoding)
}
fn written(&self, _: &str) -> Option<String> {
None
}
fn reconnect(&self) -> Result<()> {
let use_ssl = match *self.stream.lock().unwrap().get_ref() {
NetStream::Unsecured(_) => false,
};
let host = self.host.lock().unwrap();
let port = self.port.lock().unwrap();
let stream = if use_ssl {
try!(NetConnection::connect_ssl_internal(&host, *port))
} else {
try!(NetConnection::connect_internal(&host, *port))
};
*self.stream.lock().unwrap() = stream;
Ok(())
}
}
/// A mock connection for testing purposes.
///
/// Reads are served from an in-memory buffer and writes are collected in memory.
pub struct MockConnection {
    /// Bytes handed out by reads, with a cursor tracking how far reading has progressed.
    reader: Mutex<Cursor<Vec<u8>>>,
    /// Every byte written so far.
    writer: Mutex<Vec<u8>>,
}

impl MockConnection {
    /// Creates a new mock connection with an empty read buffer.
    pub fn empty() -> MockConnection {
        MockConnection::from_byte_vec(vec![])
    }

    /// Creates a new mock connection with the specified string in the read buffer.
    pub fn new(input: &str) -> MockConnection {
        MockConnection::from_byte_vec(Vec::from(input.as_bytes()))
    }

    /// Creates a new mock connection with the specified bytes in the read buffer.
    pub fn from_byte_vec(input: Vec<u8>) -> MockConnection {
        let incoming = Cursor::new(input);
        let outgoing = Vec::new();
        MockConnection {
            reader: Mutex::new(incoming),
            writer: Mutex::new(outgoing),
        }
    }
}
impl Connection for MockConnection {
    /// Encodes `msg` with `encoding` and appends it to the in-memory write buffer.
    fn send(&self, msg: &str, encoding: &str) -> Result<()> {
        imp::send(&self.writer, msg, encoding)
    }

    /// Reads one line from the in-memory read buffer and decodes it with `encoding`.
    fn recv(&self, encoding: &str) -> Result<String> {
        imp::recv(&self.reader, encoding)
    }

    /// Decodes everything written so far with `encoding`.
    ///
    /// Returns `None` when the encoding label is unknown or decoding fails.
    fn written(&self, encoding: &str) -> Option<String> {
        let codec = match encoding_from_whatwg_label(encoding) {
            Some(codec) => codec,
            None => return None,
        };
        let sent = self.writer.lock().unwrap();
        codec.decode(&sent, DecoderTrap::Replace).ok()
    }

    /// There is nothing to re-establish for a mock, so this always succeeds.
    fn reconnect(&self) -> Result<()> {
        Ok(())
    }
}
/// Shared send/receive logic used by every `Connection` implementation in this file.
mod imp {
    use std::io::{Error, ErrorKind, Result};
    use std::sync::Mutex;
    use encoding::{DecoderTrap, EncoderTrap};
    use encoding::label::encoding_from_whatwg_label;
    use client::data::kinds::{IrcRead, IrcWrite};

    /// Encodes `msg` using the encoding named by `encoding`, writes it to `writer`, and flushes.
    ///
    /// Fails with `ErrorKind::InvalidInput` when the label is unknown or the message cannot be
    /// encoded; write and flush errors are passed through unchanged.
    pub fn send<T: IrcWrite>(writer: &Mutex<T>, msg: &str, encoding: &str) -> Result<()> {
        let codec = encoding_from_whatwg_label(encoding).ok_or_else(|| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Failed to find encoder. ({})", encoding),
            )
        })?;
        let payload = codec.encode(msg, EncoderTrap::Replace).map_err(|detail| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Failed to encode {} as {}.", detail, codec.name()),
            )
        })?;

        let mut sink = writer.lock().unwrap();
        sink.write_all(&payload)?;
        sink.flush()
    }

    /// Reads bytes from `reader` up to and including the next `\n` and decodes them using the
    /// encoding named by `encoding`.
    ///
    /// Fails with `ErrorKind::InvalidInput` when the label is unknown or decoding fails, and
    /// with `ErrorKind::Other` ("EOF") when no bytes at all could be read.
    pub fn recv<T: IrcRead>(reader: &Mutex<T>, encoding: &str) -> Result<String> {
        let codec = encoding_from_whatwg_label(encoding).ok_or_else(|| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Failed to find decoder. ({})", encoding),
            )
        })?;

        let mut line = Vec::new();
        reader.lock().unwrap().read_until(b'\n', &mut line)?;
        if line.is_empty() {
            // A successful read that produced nothing means the source is exhausted.
            return Err(Error::new(ErrorKind::Other, "EOF"));
        }

        codec.decode(&line, DecoderTrap::Replace).map_err(|detail| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("Failed to decode {} as {}.", detail, codec.name()),
            )
        })
    }
}
/// An abstraction over different networked streams.
///
/// `Read` and `Write` are implemented by forwarding to the wrapped stream, so a `NetStream`
/// can be used wherever a plain stream is expected.
pub enum NetStream {
    /// An unsecured TcpStream.
    Unsecured(TcpStream),
}

impl Read for NetStream {
    /// Reads from the wrapped stream into `buf`, returning the number of bytes read.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        // Only the variants the enum actually declares are matched here; an arm naming an
        // undeclared variant would stop the crate from compiling once its `cfg` is enabled.
        match *self {
            NetStream::Unsecured(ref mut stream) => stream.read(buf),
        }
    }
}

impl Write for NetStream {
    /// Writes `buf` to the wrapped stream, returning the number of bytes accepted.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        match *self {
            NetStream::Unsecured(ref mut stream) => stream.write(buf),
        }
    }

    /// Flushes the wrapped stream.
    fn flush(&mut self) -> Result<()> {
        match *self {
            NetStream::Unsecured(ref mut stream) => stream.flush(),
        }
    }
}
#[cfg(test)]
mod test {
    use super::{Connection, MockConnection};
    use std::io::Result;
    use client::data::Message;
    use client::data::Command::PRIVMSG;

    /// Converts `msg` into a `Message`, renders it as a string, and sends it over `conn`.
    fn send_to<C: Connection, M: Into<Message>>(conn: &C, msg: M, encoding: &str) -> Result<()> {
        let rendered = msg.into().to_string();
        conn.send(&rendered, encoding)
    }

    #[test]
    fn send_utf8() {
        let conn = MockConnection::empty();
        let outgoing = PRIVMSG("test".to_owned(), "€ŠšŽžŒœŸ".to_owned());
        assert!(send_to(&conn, outgoing, "UTF-8").is_ok());
        let data = conn.written("UTF-8").unwrap();
        assert_eq!(&data[..], "PRIVMSG test :€ŠšŽžŒœŸ\r\n");
    }

    #[test]
    fn send_utf8_str() {
        let exp = "PRIVMSG test :€ŠšŽžŒœŸ\r\n";
        let conn = MockConnection::empty();
        assert!(send_to(&conn, exp, "UTF-8").is_ok());
        let data = conn.written("UTF-8").unwrap();
        assert_eq!(&data[..], exp);
    }

    #[test]
    fn send_iso885915() {
        let conn = MockConnection::empty();
        let outgoing = PRIVMSG("test".to_owned(), "€ŠšŽžŒœŸ".to_owned());
        assert!(send_to(&conn, outgoing, "l9").is_ok());
        let data = conn.written("l9").unwrap();
        assert_eq!(&data[..], "PRIVMSG test :€ŠšŽžŒœŸ\r\n");
    }

    #[test]
    fn send_iso885915_str() {
        let exp = "PRIVMSG test :€ŠšŽžŒœŸ\r\n";
        let conn = MockConnection::empty();
        assert!(send_to(&conn, exp, "l9").is_ok());
        let data = conn.written("l9").unwrap();
        assert_eq!(&data[..], exp);
    }

    #[test]
    fn recv_utf8() {
        let conn = MockConnection::new("PRIVMSG test :Testing!\r\n");
        let line = conn.recv("UTF-8").unwrap();
        assert_eq!(&line[..], "PRIVMSG test :Testing!\r\n");
    }

    #[test]
    fn recv_iso885915() {
        // The ISO-8859-15 bytes for "€ŠšŽžŒœŸ", framed as a complete PRIVMSG line.
        let mut raw = b"PRIVMSG test :".to_vec();
        raw.extend_from_slice(&[0xA4, 0xA6, 0xA8, 0xB4, 0xB8, 0xBC, 0xBD, 0xBE]);
        raw.extend_from_slice(b"\r\n");
        let conn = MockConnection::from_byte_vec(raw);
        let line = conn.recv("l9").unwrap();
        assert_eq!(&line[..], "PRIVMSG test :€ŠšŽžŒœŸ\r\n");
    }
}