Implemented internal auto-reconnection for servers.
This commit is contained in:
parent
33004abc12
commit
9ac625c091
2 changed files with 121 additions and 53 deletions
|
@ -1,7 +1,7 @@
|
|||
//! Thread-safe connections on IrcStreams.
|
||||
#[cfg(feature = "ssl")] use std::error::Error as StdError;
|
||||
use std::io::prelude::*;
|
||||
use std::io::{BufReader, BufWriter, Result};
|
||||
use std::io::{BufReader, BufWriter, Cursor, Empty, Result, Sink};
|
||||
use std::io::Error;
|
||||
use std::io::ErrorKind;
|
||||
use std::net::TcpStream;
|
||||
|
@ -62,24 +62,6 @@ impl Connection<BufReader<NetStream>, BufWriter<NetStream>> {
|
|||
panic!("Cannot connect to {}:{} over SSL without compiling with SSL support.", host, port)
|
||||
}
|
||||
|
||||
/// Reconnects to the specified server, dropping the current connection.
|
||||
pub fn reconnect(&self, host: &str, port: u16) -> Result<()> {
|
||||
let use_ssl = match self.reader.lock().unwrap().get_ref() {
|
||||
&NetStream::UnsecuredTcpStream(_) => false,
|
||||
#[cfg(feature = "ssl")]
|
||||
&NetStream::SslTcpStream(_) => true,
|
||||
};
|
||||
let (reader, writer) = if use_ssl {
|
||||
try!(Connection::connect_ssl_internal(host, port))
|
||||
} else {
|
||||
try!(Connection::connect_internal(host, port))
|
||||
};
|
||||
*self.reader.lock().unwrap() = reader;
|
||||
*self.writer.lock().unwrap() = writer;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
FIXME: removed until set_keepalive is stabilized.
|
||||
/// Sets the keepalive for the network stream.
|
||||
|
@ -193,6 +175,46 @@ fn ssl_to_io<T>(res: StdResult<T, SslError>) -> Result<T> {
|
|||
}
|
||||
}
|
||||
|
||||
/// A trait defining the ability to reconnect.
|
||||
pub trait Reconnect {
|
||||
/// Reconnects to the specified host and port, dropping the current connection if necessary.
|
||||
fn reconnect(&self, host: &str, port: u16) -> Result<()>;
|
||||
}
|
||||
|
||||
macro_rules! noop_reconnect {
|
||||
($T:ty, $U:ty) => {
|
||||
impl Reconnect for Connection<$T, $U> {
|
||||
fn reconnect(&self, _: &str, _: u16) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Reconnect for NetConnection {
|
||||
fn reconnect(&self, host: &str, port: u16) -> Result<()> {
|
||||
let use_ssl = match self.reader.lock().unwrap().get_ref() {
|
||||
&NetStream::UnsecuredTcpStream(_) => false,
|
||||
#[cfg(feature = "ssl")]
|
||||
&NetStream::SslTcpStream(_) => true,
|
||||
};
|
||||
let (reader, writer) = if use_ssl {
|
||||
try!(Connection::connect_ssl_internal(host, port))
|
||||
} else {
|
||||
try!(Connection::connect_internal(host, port))
|
||||
};
|
||||
*self.reader.lock().unwrap() = reader;
|
||||
*self.writer.lock().unwrap() = writer;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: replace all this with specialization when possible. :\
|
||||
noop_reconnect!(Cursor<Vec<u8>>, Vec<u8>);
|
||||
noop_reconnect!(Cursor<Vec<u8>>, Sink);
|
||||
noop_reconnect!(BufReader<Empty>, Vec<u8>);
|
||||
noop_reconnect!(BufReader<Empty>, Sink);
|
||||
|
||||
/// An abstraction over different networked streams.
|
||||
pub enum NetStream {
|
||||
/// An unsecured TcpStream.
|
||||
|
|
|
@ -2,15 +2,16 @@
|
|||
//!
|
||||
//! There are currently two recommended ways to work
|
||||
use std::borrow::ToOwned;
|
||||
use std::cell::Cell;
|
||||
use std::collections::HashMap;
|
||||
use std::error::Error as StdError;
|
||||
use std::io::{BufReader, BufWriter, Error, ErrorKind, Result};
|
||||
use std::iter::Map;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::sync::mpsc::{Sender, channel};
|
||||
use std::sync::mpsc::{Sender, Receiver, channel};
|
||||
use std::thread::{JoinHandle, spawn};
|
||||
use client::conn::{Connection, NetStream};
|
||||
use client::conn::{Connection, NetStream, Reconnect};
|
||||
use client::data::{Command, Config, Message, Response, User};
|
||||
use client::data::Command::{JOIN, NICK, NICKSERV, PONG, MODE};
|
||||
use client::data::kinds::{IrcRead, IrcWrite};
|
||||
|
@ -36,8 +37,12 @@ pub trait Server<'a, T: IrcRead, U: IrcWrite> {
|
|||
|
||||
/// A thread-safe implementation of an IRC Server connection.
|
||||
pub struct IrcServer<T: IrcRead, U: IrcWrite> {
|
||||
/// The channel for sending messages to write.
|
||||
tx: Sender<Message>,
|
||||
/// The internal, thread-safe server state.
|
||||
state: Arc<ServerState<T, U>>,
|
||||
/// A thread-local count of reconnection attempts used for synchronization.
|
||||
reconnect_count: Cell<u32>,
|
||||
}
|
||||
|
||||
/// Thread-safe internal state for an IRC server connection.
|
||||
|
@ -52,6 +57,14 @@ struct ServerState<T: IrcRead, U: IrcWrite> {
|
|||
chanlists: Mutex<HashMap<String, Vec<User>>>,
|
||||
/// A thread-safe index to track the current alternative nickname being used.
|
||||
alt_nick_index: RwLock<usize>,
|
||||
/// A thread-safe count of reconnection attempts used for synchronization.
|
||||
reconnect_count: Mutex<u32>,
|
||||
}
|
||||
|
||||
impl<T: IrcRead, U: IrcWrite> ServerState<T, U> where Connection<T, U>: Reconnect {
|
||||
fn reconnect(&self) -> Result<()> {
|
||||
self.conn.reconnect(self.config.server(), self.config.port())
|
||||
}
|
||||
}
|
||||
|
||||
/// An IrcServer over a buffered NetStream.
|
||||
|
@ -72,18 +85,37 @@ impl IrcServer<BufReader<NetStream>, BufWriter<NetStream>> {
|
|||
} else {
|
||||
Connection::connect(config.server(), config.port())
|
||||
});
|
||||
Ok(IrcServer::from_connection(config, conn))
|
||||
}
|
||||
|
||||
/// Reconnects to the IRC server.
|
||||
pub fn reconnect(&self) -> Result<()> {
|
||||
self.state.conn.reconnect(self.config().server(), self.config().port())
|
||||
let (tx, rx): (Sender<Message>, Receiver<Message>) = channel();
|
||||
let state = Arc::new(ServerState {
|
||||
conn: conn,
|
||||
write_handle: Mutex::new(None),
|
||||
config: config,
|
||||
chanlists: Mutex::new(HashMap::new()),
|
||||
alt_nick_index: RwLock::new(0),
|
||||
reconnect_count: Mutex::new(0),
|
||||
});
|
||||
let weak = Arc::downgrade(&state);
|
||||
let write_handle = spawn(move || while let Ok(msg) = rx.recv() {
|
||||
if let Some(strong) = weak.upgrade() {
|
||||
while let Err(_) = IrcServer::write(&strong, msg.clone()) {
|
||||
let _ = strong.reconnect();
|
||||
}
|
||||
}
|
||||
});
|
||||
let state2 = state.clone();
|
||||
let mut handle = state2.write_handle.lock().unwrap();
|
||||
*handle = Some(write_handle);
|
||||
Ok(IrcServer { tx: tx, state: state, reconnect_count: Cell::new(0) })
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: IrcRead, U: IrcWrite> Clone for IrcServer<T, U> {
|
||||
fn clone(&self) -> IrcServer<T, U> {
|
||||
IrcServer { tx: self.tx.clone(), state: self.state.clone() }
|
||||
IrcServer {
|
||||
tx: self.tx.clone(),
|
||||
state: self.state.clone(),
|
||||
reconnect_count: self.reconnect_count.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,7 +128,7 @@ impl<T: IrcRead, U: IrcWrite> Drop for ServerState<T, U> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
|
||||
impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> where Connection<T, U>: Reconnect {
|
||||
fn config(&self) -> &Config {
|
||||
&self.state.config
|
||||
}
|
||||
|
@ -131,7 +163,7 @@ impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
|
||||
impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> where Connection<T, U>: Reconnect {
|
||||
/// Creates an IRC server from the specified configuration, and any arbitrary Connection.
|
||||
pub fn from_connection(config: Config, conn: Connection<T, U>) -> IrcServer<T, U> {
|
||||
let (tx, rx) = channel();
|
||||
|
@ -141,17 +173,18 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
|
|||
config: config,
|
||||
chanlists: Mutex::new(HashMap::new()),
|
||||
alt_nick_index: RwLock::new(0),
|
||||
reconnect_count: Mutex::new(0),
|
||||
});
|
||||
let weak = Arc::downgrade(&state);
|
||||
let write_handle = spawn(move || while let Ok(msg) = rx.recv() {
|
||||
if let Some(strong) = weak.upgrade() {
|
||||
IrcServer::write(&strong, msg);
|
||||
let _ = IrcServer::write(&strong, msg);
|
||||
}
|
||||
});
|
||||
let state2 = state.clone();
|
||||
let mut handle = state2.write_handle.lock().unwrap();
|
||||
*handle = Some(write_handle);
|
||||
IrcServer { tx: tx, state: state }
|
||||
IrcServer { tx: tx, state: state, reconnect_count: Cell::new(0) }
|
||||
}
|
||||
|
||||
/// Gets a reference to the IRC server's connection.
|
||||
|
@ -159,6 +192,19 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
|
|||
&self.state.conn
|
||||
}
|
||||
|
||||
/// Reconnects to the IRC server.
|
||||
pub fn reconnect(&self) -> Result<()> {
|
||||
let mut reconnect_count = self.state.reconnect_count.lock().unwrap();
|
||||
let res = if self.reconnect_count.get() == *reconnect_count {
|
||||
*reconnect_count += 1;
|
||||
self.state.reconnect()
|
||||
} else {
|
||||
Ok(())
|
||||
};
|
||||
self.reconnect_count.set(*reconnect_count);
|
||||
res
|
||||
}
|
||||
|
||||
#[cfg(feature = "encode")]
|
||||
fn write<M: Into<Message>>(state: &Arc<ServerState<T, U>>, msg: M) -> Result<()> {
|
||||
state.conn.send(msg, state.config.encoding())
|
||||
|
@ -169,7 +215,6 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
|
|||
state.conn.send(msg)
|
||||
}
|
||||
|
||||
|
||||
/// Handles messages internally for basic bot functionality.
|
||||
fn handle_message(&self, msg: &Message) {
|
||||
if let Some(resp) = Response::from_message(msg) {
|
||||
|
@ -316,7 +361,7 @@ pub struct ServerIterator<'a, T: IrcRead + 'a, U: IrcWrite + 'a> {
|
|||
pub type ServerCmdIterator<'a, T, U> =
|
||||
Map<ServerIterator<'a, T, U>, fn(Result<Message>) -> Result<Command>>;
|
||||
|
||||
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> {
|
||||
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> where Connection<T, U>: Reconnect {
|
||||
/// Creates a new ServerIterator for the desired IrcServer.
|
||||
pub fn new(server: &IrcServer<T, U>) -> ServerIterator<T, U> {
|
||||
ServerIterator { server: server }
|
||||
|
@ -325,35 +370,35 @@ impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> {
|
|||
/// Gets the next line from the connection.
|
||||
#[cfg(feature = "encode")]
|
||||
fn get_next_line(&self) -> Result<String> {
|
||||
self.server.state.conn.recv(self.server.config().encoding())
|
||||
self.server.conn().recv(self.server.config().encoding())
|
||||
}
|
||||
|
||||
/// Gets the next line from the connection.
|
||||
#[cfg(not(feature = "encode"))]
|
||||
fn get_next_line(&self) -> Result<String> {
|
||||
self.server.state.conn.recv()
|
||||
self.server.conn().recv()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> Iterator for ServerIterator<'a, T, U> {
|
||||
impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> Iterator for ServerIterator<'a, T, U> where Connection<T, U>: Reconnect {
|
||||
type Item = Result<Message>;
|
||||
fn next(&mut self) -> Option<Result<Message>> {
|
||||
let res = self.get_next_line().and_then(|msg|
|
||||
match msg.parse() {
|
||||
Ok(msg) => {
|
||||
self.server.handle_message(&msg);
|
||||
Ok(msg)
|
||||
loop {
|
||||
match self.get_next_line() {
|
||||
Ok(msg) => match msg.parse() {
|
||||
Ok(res) => {
|
||||
self.server.handle_message(&res);
|
||||
return Some(Ok(res))
|
||||
},
|
||||
Err(_) => return Some(Err(Error::new(ErrorKind::InvalidInput,
|
||||
&format!("Failed to parse message. (Message: {})", msg)[..]
|
||||
)))
|
||||
},
|
||||
Err(_) => Err(Error::new(ErrorKind::InvalidInput,
|
||||
&format!("Failed to parse message. (Message: {})", msg)[..]
|
||||
))
|
||||
Err(ref err) if err.description() == "EOF" => return None,
|
||||
Err(_) => {
|
||||
let _ = self.server.reconnect();
|
||||
}
|
||||
}
|
||||
);
|
||||
match res {
|
||||
Err(ref err) if err.kind() == ErrorKind::ConnectionAborted => None,
|
||||
Err(ref err) if err.kind() == ErrorKind::ConnectionReset => None,
|
||||
Err(ref err) if err.description() == "EOF" => None,
|
||||
_ => Some(res)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -363,7 +408,7 @@ mod test {
|
|||
use super::{IrcServer, Server};
|
||||
use std::default::Default;
|
||||
use std::io::{Cursor, sink};
|
||||
use client::conn::Connection;
|
||||
use client::conn::{Connection, Reconnect};
|
||||
use client::data::{Config, Message, User};
|
||||
use client::data::command::Command::PRIVMSG;
|
||||
use client::data::kinds::IrcRead;
|
||||
|
@ -381,7 +426,8 @@ mod test {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn get_server_value<T: IrcRead>(server: IrcServer<T, Vec<u8>>) -> String {
|
||||
pub fn get_server_value<T: IrcRead>(server: IrcServer<T, Vec<u8>>) -> String
|
||||
where Connection<T, Vec<u8>>: Reconnect {
|
||||
let vec = server.conn().writer().clone();
|
||||
String::from_utf8(vec).unwrap()
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue