//! Thread-safe connections on IrcStreams. #![stable] #[cfg(feature = "ssl")] use std::borrow::ToOwned; #[cfg(feature = "ssl")] use std::error::Error; use std::old_io::{BufferedReader, BufferedWriter, IoResult, TcpStream}; #[cfg(any(feature = "encode", feature = "ssl"))] use std::old_io::{IoError, IoErrorKind}; use std::sync::{Mutex, MutexGuard}; #[cfg(feature = "encode")] use encoding::{DecoderTrap, EncoderTrap, Encoding}; #[cfg(feature = "encode")] use encoding::label::encoding_from_whatwg_label; use client::data::kinds::{IrcReader, IrcWriter}; use client::data::message::ToMessage; #[cfg(feature = "ssl")] use openssl::ssl::{SslContext, SslMethod, SslStream}; #[cfg(feature = "ssl")] use openssl::ssl::error::SslError; /// A thread-safe connection. #[stable] pub struct Connection { reader: Mutex, writer: Mutex, } /// A Connection over a buffered NetStream. #[stable] pub type NetConnection = Connection, BufferedWriter>; /// An internal type type NetReaderWriterPair = (BufferedReader, BufferedWriter); #[stable] impl Connection, BufferedWriter> { /// Creates a thread-safe TCP connection to the specified server. #[stable] pub fn connect(host: &str, port: u16) -> IoResult { let (reader, writer) = try!(Connection::connect_internal(host, port)); Ok(Connection::new(reader, writer)) } /// connects to the specified server and returns a reader-writer pair. fn connect_internal(host: &str, port: u16) -> IoResult { let socket = try!(TcpStream::connect(&format!("{}:{}", host, port)[])); Ok((BufferedReader::new(NetStream::UnsecuredTcpStream(socket.clone())), BufferedWriter::new(NetStream::UnsecuredTcpStream(socket)))) } /// Creates a thread-safe TCP connection to the specified server over SSL. /// If the library is compiled without SSL support, this method panics. #[stable] pub fn connect_ssl(host: &str, port: u16) -> IoResult { let (reader, writer) = try!(Connection::connect_ssl_internal(host, port)); Ok(Connection::new(reader, writer)) } /// Connects over SSL to the specified server and returns a reader-writer pair. #[cfg(feature = "ssl")] fn connect_ssl_internal(host: &str, port: u16) -> IoResult { let socket = try!(TcpStream::connect(&format!("{}:{}", host, port)[])); let ssl = try!(ssl_to_io(SslContext::new(SslMethod::Tlsv1))); let ssl_socket = try!(ssl_to_io(SslStream::new(&ssl, socket))); Ok((BufferedReader::new(NetStream::SslTcpStream(ssl_socket.clone())), BufferedWriter::new(NetStream::SslTcpStream(ssl_socket)))) } /// Panics because SSL support is not compiled in. #[cfg(not(feature = "ssl"))] fn connect_ssl_internal(host: &str, port: u16) -> IoResult { panic!("Cannot connect to {}:{} over SSL without compiling with SSL support.", host, port) } /// Reconnects to the specified server, dropping the current connection. #[unstable = "Feature is relatively new."] pub fn reconnect(&self, host: &str, port: u16) -> IoResult<()> { let use_ssl = match self.reader.lock().unwrap().get_ref() { &NetStream::UnsecuredTcpStream(_) => false, #[cfg(feature = "ssl")] &NetStream::SslTcpStream(_) => true, }; let (reader, writer) = if use_ssl { try!(Connection::connect_ssl_internal(host, port)) } else { try!(Connection::connect_internal(host, port)) }; *self.reader.lock().unwrap() = reader; *self.writer.lock().unwrap() = writer; Ok(()) } /// Sets the keepalive for the network stream. #[unstable = "Feature is relatively new."] pub fn set_keepalive(&self, delay_in_seconds: Option) -> IoResult<()> { self.mod_stream(|tcp| tcp.set_keepalive(delay_in_seconds)) } /// Sets the timeout for the network stream. #[unstable = "Feature is relatively new."] pub fn set_timeout(&self, timeout_ms: Option) { self.mod_stream(|tcp| Ok(tcp.set_timeout(timeout_ms))).unwrap(); // this cannot fail. } /// Modifies the internal TcpStream using a function. fn mod_stream(&self, f: F) -> IoResult<()> where F: FnOnce(&mut TcpStream) -> IoResult<()> { match self.reader.lock().unwrap().get_mut() { &mut NetStream::UnsecuredTcpStream(ref mut tcp) => f(tcp), #[cfg(feature = "ssl")] &mut NetStream::SslTcpStream(ref mut ssl) => f(ssl.get_mut()), } } } #[stable] impl Connection { /// Creates a new connection from an IrcReader and an IrcWriter. #[stable] pub fn new(reader: T, writer: U) -> Connection { Connection { reader: Mutex::new(reader), writer: Mutex::new(writer), } } /// Sends a Message over this connection. #[stable] #[cfg(feature = "encode")] pub fn send(&self, to_msg: M, encoding: &str) -> IoResult<()> { let encoding = match encoding_from_whatwg_label(encoding) { Some(enc) => enc, None => return Err(IoError { kind: IoErrorKind::InvalidInput, desc: "Failed to find decoder.", detail: Some(format!("Invalid decoder: {}", encoding)) }) }; let msg = to_msg.to_message(); let data = match encoding.encode(&msg.into_string()[], EncoderTrap::Replace) { Ok(data) => data, Err(data) => return Err(IoError { kind: IoErrorKind::InvalidInput, desc: "Failed to decode message.", detail: Some(format!("Failed to decode {} as {}.", data, encoding.name())), }) }; let mut writer = self.writer.lock().unwrap(); try!(writer.write_all(&data[])); writer.flush() } /// Sends a message over this connection. #[stable] #[cfg(not(feature = "encode"))] pub fn send(&self, to_msg: M) -> IoResult<()> { let mut writer = self.writer.lock().unwrap(); try!(writer.write_str(&to_msg.to_message().into_string()[])); writer.flush() } /// Receives a single line from this connection. #[stable] #[cfg(feature = "encoding")] pub fn recv(&self, encoding: &str) -> IoResult { let encoding = match encoding_from_whatwg_label(encoding) { Some(enc) => enc, None => return Err(IoError { kind: IoErrorKind::InvalidInput, desc: "Failed to find decoder.", detail: Some(format!("Invalid decoder: {}", encoding)) }) }; self.reader.lock().unwrap().read_until(b'\n').and_then(|line| match encoding.decode(&line[], DecoderTrap::Replace) { Ok(data) => Ok(data), Err(data) => Err(IoError { kind: IoErrorKind::InvalidInput, desc: "Failed to decode message.", detail: Some(format!("Failed to decode {} as {}.", data, encoding.name())), }) } ) } /// Receives a single line from this connection. #[stable] #[cfg(not(feature = "encoding"))] pub fn recv(&self) -> IoResult { self.reader.lock().unwrap().read_line() } /// Acquires the Reader lock. #[stable] pub fn reader<'a>(&'a self) -> MutexGuard<'a, T> { self.reader.lock().unwrap() } /// Acquires the Writer lock. #[stable] pub fn writer<'a>(&'a self) -> MutexGuard<'a, U> { self.writer.lock().unwrap() } } /// Converts a Result into an IoResult. #[cfg(feature = "ssl")] fn ssl_to_io(res: Result) -> IoResult { match res { Ok(x) => Ok(x), Err(e) => Err(IoError { kind: IoErrorKind::OtherIoError, desc: "An SSL error occurred.", detail: Some(e.description().to_owned()), }), } } /// An abstraction over different networked streams. #[stable] pub enum NetStream { /// An unsecured TcpStream. #[stable] UnsecuredTcpStream(TcpStream), /// An SSL-secured TcpStream. /// This is only available when compiled with SSL support. #[cfg(feature = "ssl")] #[stable] SslTcpStream(SslStream), } impl Reader for NetStream { fn read(&mut self, buf: &mut [u8]) -> IoResult { match self { &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.read(buf), #[cfg(feature = "ssl")] &mut NetStream::SslTcpStream(ref mut stream) => stream.read(buf), } } } impl Writer for NetStream { fn write_all(&mut self, buf: &[u8]) -> IoResult<()> { match self { &mut NetStream::UnsecuredTcpStream(ref mut stream) => stream.write_all(buf), #[cfg(feature = "ssl")] &mut NetStream::SslTcpStream(ref mut stream) => stream.write_all(buf), } } } #[cfg(test)] mod test { use super::Connection; use std::old_io::{MemReader, MemWriter}; use std::old_io::util::{NullReader, NullWriter}; use client::data::message::Message; #[cfg(feature = "encode")] use encoding::{DecoderTrap, Encoding}; #[cfg(feature = "encode")] use encoding::all::{ISO_8859_15, UTF_8}; #[test] #[cfg(not(feature = "encode"))] fn send() { let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send( Message::new(None, "PRIVMSG", Some(vec!["test"]), Some("Testing!")) ).is_ok()); let data = String::from_utf8(conn.writer().get_ref().to_vec()).unwrap(); assert_eq!(&data[], "PRIVMSG test :Testing!\r\n"); } #[test] #[cfg(not(feature = "encode"))] fn send_str() { let exp = "PRIVMSG test :Testing!\r\n"; let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send(exp).is_ok()); let data = String::from_utf8(conn.writer().get_ref().to_vec()).unwrap(); assert_eq!(&data[], exp); } #[test] #[cfg(feature = "encode")] fn send_utf8() { let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send( Message::new(None, "PRIVMSG", Some(vec!["test"]), Some("€ŠšŽžŒœŸ")), "UTF-8" ).is_ok()); let data = UTF_8.decode(conn.writer().get_ref(), DecoderTrap::Strict).unwrap(); assert_eq!(&data[], "PRIVMSG test :€ŠšŽžŒœŸ\r\n"); } #[test] #[cfg(feature = "encode")] fn send_utf8_str() { let exp = "PRIVMSG test :€ŠšŽžŒœŸ\r\n"; let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send(exp, "UTF-8").is_ok()); let data = UTF_8.decode(conn.writer().get_ref(), DecoderTrap::Strict).unwrap(); assert_eq!(&data[], exp); } #[test] #[cfg(feature = "encode")] fn send_iso885915() { let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send( Message::new(None, "PRIVMSG", Some(vec!["test"]), Some("€ŠšŽžŒœŸ")), "l9" ).is_ok()); let data = ISO_8859_15.decode(conn.writer().get_ref(), DecoderTrap::Strict).unwrap(); assert_eq!(&data[], "PRIVMSG test :€ŠšŽžŒœŸ\r\n"); } #[test] #[cfg(feature = "encode")] fn send_iso885915_str() { let exp = "PRIVMSG test :€ŠšŽžŒœŸ\r\n"; let conn = Connection::new(NullReader, MemWriter::new()); assert!(conn.send(exp, "l9").is_ok()); let data = ISO_8859_15.decode(conn.writer().get_ref(), DecoderTrap::Strict).unwrap(); assert_eq!(&data[], exp); } #[test] #[cfg(not(feature = "encode"))] fn recv() { let conn = Connection::new( MemReader::new("PRIVMSG test :Testing!\r\n".as_bytes().to_vec()), NullWriter ); assert_eq!(&conn.recv().unwrap()[], "PRIVMSG test :Testing!\r\n"); } #[test] #[cfg(feature = "encode")] fn recv_utf8() { let conn = Connection::new( MemReader::new(b"PRIVMSG test :Testing!\r\n".to_vec()), NullWriter ); assert_eq!(&conn.recv("UTF-8").unwrap()[], "PRIVMSG test :Testing!\r\n"); } #[test] #[cfg(feature = "encode")] fn recv_iso885915() { let conn = Connection::new( MemReader::new({ let mut vec = Vec::new(); vec.push_all(b"PRIVMSG test :"); vec.push_all(&[0xA4, 0xA6, 0xA8, 0xB4, 0xB8, 0xBC, 0xBD, 0xBE]); vec.push_all(b"\r\n"); vec }), NullWriter ); assert_eq!(&conn.recv("l9").unwrap()[], "PRIVMSG test :€ŠšŽžŒœŸ\r\n"); } }