Implemented internal auto-reconnection for servers.

This commit is contained in:
Aaron Weiss 2016-01-13 17:02:39 -05:00
parent 33004abc12
commit 9ac625c091
2 changed files with 121 additions and 53 deletions

View file

@ -1,7 +1,7 @@
//! Thread-safe connections on IrcStreams. //! Thread-safe connections on IrcStreams.
#[cfg(feature = "ssl")] use std::error::Error as StdError; #[cfg(feature = "ssl")] use std::error::Error as StdError;
use std::io::prelude::*; use std::io::prelude::*;
use std::io::{BufReader, BufWriter, Result}; use std::io::{BufReader, BufWriter, Cursor, Empty, Result, Sink};
use std::io::Error; use std::io::Error;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::net::TcpStream; use std::net::TcpStream;
@ -62,24 +62,6 @@ impl Connection<BufReader<NetStream>, BufWriter<NetStream>> {
panic!("Cannot connect to {}:{} over SSL without compiling with SSL support.", host, port) panic!("Cannot connect to {}:{} over SSL without compiling with SSL support.", host, port)
} }
/// Reconnects to the specified server, dropping the current connection.
pub fn reconnect(&self, host: &str, port: u16) -> Result<()> {
let use_ssl = match self.reader.lock().unwrap().get_ref() {
&NetStream::UnsecuredTcpStream(_) => false,
#[cfg(feature = "ssl")]
&NetStream::SslTcpStream(_) => true,
};
let (reader, writer) = if use_ssl {
try!(Connection::connect_ssl_internal(host, port))
} else {
try!(Connection::connect_internal(host, port))
};
*self.reader.lock().unwrap() = reader;
*self.writer.lock().unwrap() = writer;
Ok(())
}
/* /*
FIXME: removed until set_keepalive is stabilized. FIXME: removed until set_keepalive is stabilized.
/// Sets the keepalive for the network stream. /// Sets the keepalive for the network stream.
@ -193,6 +175,46 @@ fn ssl_to_io<T>(res: StdResult<T, SslError>) -> Result<T> {
} }
} }
/// A trait defining the ability to reconnect.
pub trait Reconnect {
/// Reconnects to the specified host and port, dropping the current connection if necessary.
fn reconnect(&self, host: &str, port: u16) -> Result<()>;
}
macro_rules! noop_reconnect {
($T:ty, $U:ty) => {
impl Reconnect for Connection<$T, $U> {
fn reconnect(&self, _: &str, _: u16) -> Result<()> {
Ok(())
}
}
}
}
impl Reconnect for NetConnection {
fn reconnect(&self, host: &str, port: u16) -> Result<()> {
let use_ssl = match self.reader.lock().unwrap().get_ref() {
&NetStream::UnsecuredTcpStream(_) => false,
#[cfg(feature = "ssl")]
&NetStream::SslTcpStream(_) => true,
};
let (reader, writer) = if use_ssl {
try!(Connection::connect_ssl_internal(host, port))
} else {
try!(Connection::connect_internal(host, port))
};
*self.reader.lock().unwrap() = reader;
*self.writer.lock().unwrap() = writer;
Ok(())
}
}
// TODO: replace all this with specialization when possible. :\
noop_reconnect!(Cursor<Vec<u8>>, Vec<u8>);
noop_reconnect!(Cursor<Vec<u8>>, Sink);
noop_reconnect!(BufReader<Empty>, Vec<u8>);
noop_reconnect!(BufReader<Empty>, Sink);
/// An abstraction over different networked streams. /// An abstraction over different networked streams.
pub enum NetStream { pub enum NetStream {
/// An unsecured TcpStream. /// An unsecured TcpStream.

View file

@ -2,15 +2,16 @@
//! //!
//! There are currently two recommended ways to work //! There are currently two recommended ways to work
use std::borrow::ToOwned; use std::borrow::ToOwned;
use std::cell::Cell;
use std::collections::HashMap; use std::collections::HashMap;
use std::error::Error as StdError; use std::error::Error as StdError;
use std::io::{BufReader, BufWriter, Error, ErrorKind, Result}; use std::io::{BufReader, BufWriter, Error, ErrorKind, Result};
use std::iter::Map; use std::iter::Map;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use std::sync::mpsc::{Sender, channel}; use std::sync::mpsc::{Sender, Receiver, channel};
use std::thread::{JoinHandle, spawn}; use std::thread::{JoinHandle, spawn};
use client::conn::{Connection, NetStream}; use client::conn::{Connection, NetStream, Reconnect};
use client::data::{Command, Config, Message, Response, User}; use client::data::{Command, Config, Message, Response, User};
use client::data::Command::{JOIN, NICK, NICKSERV, PONG, MODE}; use client::data::Command::{JOIN, NICK, NICKSERV, PONG, MODE};
use client::data::kinds::{IrcRead, IrcWrite}; use client::data::kinds::{IrcRead, IrcWrite};
@ -36,8 +37,12 @@ pub trait Server<'a, T: IrcRead, U: IrcWrite> {
/// A thread-safe implementation of an IRC Server connection. /// A thread-safe implementation of an IRC Server connection.
pub struct IrcServer<T: IrcRead, U: IrcWrite> { pub struct IrcServer<T: IrcRead, U: IrcWrite> {
/// The channel for sending messages to write.
tx: Sender<Message>, tx: Sender<Message>,
/// The internal, thread-safe server state.
state: Arc<ServerState<T, U>>, state: Arc<ServerState<T, U>>,
/// A thread-local count of reconnection attempts used for synchronization.
reconnect_count: Cell<u32>,
} }
/// Thread-safe internal state for an IRC server connection. /// Thread-safe internal state for an IRC server connection.
@ -52,6 +57,14 @@ struct ServerState<T: IrcRead, U: IrcWrite> {
chanlists: Mutex<HashMap<String, Vec<User>>>, chanlists: Mutex<HashMap<String, Vec<User>>>,
/// A thread-safe index to track the current alternative nickname being used. /// A thread-safe index to track the current alternative nickname being used.
alt_nick_index: RwLock<usize>, alt_nick_index: RwLock<usize>,
/// A thread-safe count of reconnection attempts used for synchronization.
reconnect_count: Mutex<u32>,
}
impl<T: IrcRead, U: IrcWrite> ServerState<T, U> where Connection<T, U>: Reconnect {
fn reconnect(&self) -> Result<()> {
self.conn.reconnect(self.config.server(), self.config.port())
}
} }
/// An IrcServer over a buffered NetStream. /// An IrcServer over a buffered NetStream.
@ -72,18 +85,37 @@ impl IrcServer<BufReader<NetStream>, BufWriter<NetStream>> {
} else { } else {
Connection::connect(config.server(), config.port()) Connection::connect(config.server(), config.port())
}); });
Ok(IrcServer::from_connection(config, conn)) let (tx, rx): (Sender<Message>, Receiver<Message>) = channel();
} let state = Arc::new(ServerState {
conn: conn,
/// Reconnects to the IRC server. write_handle: Mutex::new(None),
pub fn reconnect(&self) -> Result<()> { config: config,
self.state.conn.reconnect(self.config().server(), self.config().port()) chanlists: Mutex::new(HashMap::new()),
alt_nick_index: RwLock::new(0),
reconnect_count: Mutex::new(0),
});
let weak = Arc::downgrade(&state);
let write_handle = spawn(move || while let Ok(msg) = rx.recv() {
if let Some(strong) = weak.upgrade() {
while let Err(_) = IrcServer::write(&strong, msg.clone()) {
let _ = strong.reconnect();
}
}
});
let state2 = state.clone();
let mut handle = state2.write_handle.lock().unwrap();
*handle = Some(write_handle);
Ok(IrcServer { tx: tx, state: state, reconnect_count: Cell::new(0) })
} }
} }
impl<T: IrcRead, U: IrcWrite> Clone for IrcServer<T, U> { impl<T: IrcRead, U: IrcWrite> Clone for IrcServer<T, U> {
fn clone(&self) -> IrcServer<T, U> { fn clone(&self) -> IrcServer<T, U> {
IrcServer { tx: self.tx.clone(), state: self.state.clone() } IrcServer {
tx: self.tx.clone(),
state: self.state.clone(),
reconnect_count: self.reconnect_count.clone()
}
} }
} }
@ -96,7 +128,7 @@ impl<T: IrcRead, U: IrcWrite> Drop for ServerState<T, U> {
} }
} }
impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> { impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> where Connection<T, U>: Reconnect {
fn config(&self) -> &Config { fn config(&self) -> &Config {
&self.state.config &self.state.config
} }
@ -131,7 +163,7 @@ impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
} }
} }
impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> { impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> where Connection<T, U>: Reconnect {
/// Creates an IRC server from the specified configuration, and any arbitrary Connection. /// Creates an IRC server from the specified configuration, and any arbitrary Connection.
pub fn from_connection(config: Config, conn: Connection<T, U>) -> IrcServer<T, U> { pub fn from_connection(config: Config, conn: Connection<T, U>) -> IrcServer<T, U> {
let (tx, rx) = channel(); let (tx, rx) = channel();
@ -141,17 +173,18 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
config: config, config: config,
chanlists: Mutex::new(HashMap::new()), chanlists: Mutex::new(HashMap::new()),
alt_nick_index: RwLock::new(0), alt_nick_index: RwLock::new(0),
reconnect_count: Mutex::new(0),
}); });
let weak = Arc::downgrade(&state); let weak = Arc::downgrade(&state);
let write_handle = spawn(move || while let Ok(msg) = rx.recv() { let write_handle = spawn(move || while let Ok(msg) = rx.recv() {
if let Some(strong) = weak.upgrade() { if let Some(strong) = weak.upgrade() {
IrcServer::write(&strong, msg); let _ = IrcServer::write(&strong, msg);
} }
}); });
let state2 = state.clone(); let state2 = state.clone();
let mut handle = state2.write_handle.lock().unwrap(); let mut handle = state2.write_handle.lock().unwrap();
*handle = Some(write_handle); *handle = Some(write_handle);
IrcServer { tx: tx, state: state } IrcServer { tx: tx, state: state, reconnect_count: Cell::new(0) }
} }
/// Gets a reference to the IRC server's connection. /// Gets a reference to the IRC server's connection.
@ -159,6 +192,19 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
&self.state.conn &self.state.conn
} }
/// Reconnects to the IRC server.
pub fn reconnect(&self) -> Result<()> {
let mut reconnect_count = self.state.reconnect_count.lock().unwrap();
let res = if self.reconnect_count.get() == *reconnect_count {
*reconnect_count += 1;
self.state.reconnect()
} else {
Ok(())
};
self.reconnect_count.set(*reconnect_count);
res
}
#[cfg(feature = "encode")] #[cfg(feature = "encode")]
fn write<M: Into<Message>>(state: &Arc<ServerState<T, U>>, msg: M) -> Result<()> { fn write<M: Into<Message>>(state: &Arc<ServerState<T, U>>, msg: M) -> Result<()> {
state.conn.send(msg, state.config.encoding()) state.conn.send(msg, state.config.encoding())
@ -169,7 +215,6 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
state.conn.send(msg) state.conn.send(msg)
} }
/// Handles messages internally for basic bot functionality. /// Handles messages internally for basic bot functionality.
fn handle_message(&self, msg: &Message) { fn handle_message(&self, msg: &Message) {
if let Some(resp) = Response::from_message(msg) { if let Some(resp) = Response::from_message(msg) {
@ -316,7 +361,7 @@ pub struct ServerIterator<'a, T: IrcRead + 'a, U: IrcWrite + 'a> {
pub type ServerCmdIterator<'a, T, U> = pub type ServerCmdIterator<'a, T, U> =
Map<ServerIterator<'a, T, U>, fn(Result<Message>) -> Result<Command>>; Map<ServerIterator<'a, T, U>, fn(Result<Message>) -> Result<Command>>;
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> { impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> where Connection<T, U>: Reconnect {
/// Creates a new ServerIterator for the desired IrcServer. /// Creates a new ServerIterator for the desired IrcServer.
pub fn new(server: &IrcServer<T, U>) -> ServerIterator<T, U> { pub fn new(server: &IrcServer<T, U>) -> ServerIterator<T, U> {
ServerIterator { server: server } ServerIterator { server: server }
@ -325,35 +370,35 @@ impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> {
/// Gets the next line from the connection. /// Gets the next line from the connection.
#[cfg(feature = "encode")] #[cfg(feature = "encode")]
fn get_next_line(&self) -> Result<String> { fn get_next_line(&self) -> Result<String> {
self.server.state.conn.recv(self.server.config().encoding()) self.server.conn().recv(self.server.config().encoding())
} }
/// Gets the next line from the connection. /// Gets the next line from the connection.
#[cfg(not(feature = "encode"))] #[cfg(not(feature = "encode"))]
fn get_next_line(&self) -> Result<String> { fn get_next_line(&self) -> Result<String> {
self.server.state.conn.recv() self.server.conn().recv()
} }
} }
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> Iterator for ServerIterator<'a, T, U> { impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> Iterator for ServerIterator<'a, T, U> where Connection<T, U>: Reconnect {
type Item = Result<Message>; type Item = Result<Message>;
fn next(&mut self) -> Option<Result<Message>> { fn next(&mut self) -> Option<Result<Message>> {
let res = self.get_next_line().and_then(|msg| loop {
match msg.parse() { match self.get_next_line() {
Ok(msg) => { Ok(msg) => match msg.parse() {
self.server.handle_message(&msg); Ok(res) => {
Ok(msg) self.server.handle_message(&res);
return Some(Ok(res))
},
Err(_) => return Some(Err(Error::new(ErrorKind::InvalidInput,
&format!("Failed to parse message. (Message: {})", msg)[..]
)))
}, },
Err(_) => Err(Error::new(ErrorKind::InvalidInput, Err(ref err) if err.description() == "EOF" => return None,
&format!("Failed to parse message. (Message: {})", msg)[..] Err(_) => {
)) let _ = self.server.reconnect();
}
} }
);
match res {
Err(ref err) if err.kind() == ErrorKind::ConnectionAborted => None,
Err(ref err) if err.kind() == ErrorKind::ConnectionReset => None,
Err(ref err) if err.description() == "EOF" => None,
_ => Some(res)
} }
} }
} }
@ -363,7 +408,7 @@ mod test {
use super::{IrcServer, Server}; use super::{IrcServer, Server};
use std::default::Default; use std::default::Default;
use std::io::{Cursor, sink}; use std::io::{Cursor, sink};
use client::conn::Connection; use client::conn::{Connection, Reconnect};
use client::data::{Config, Message, User}; use client::data::{Config, Message, User};
use client::data::command::Command::PRIVMSG; use client::data::command::Command::PRIVMSG;
use client::data::kinds::IrcRead; use client::data::kinds::IrcRead;
@ -381,7 +426,8 @@ mod test {
} }
} }
pub fn get_server_value<T: IrcRead>(server: IrcServer<T, Vec<u8>>) -> String { pub fn get_server_value<T: IrcRead>(server: IrcServer<T, Vec<u8>>) -> String
where Connection<T, Vec<u8>>: Reconnect {
let vec = server.conn().writer().clone(); let vec = server.conn().writer().clone();
String::from_utf8(vec).unwrap() String::from_utf8(vec).unwrap()
} }