Moved Arc abstraction internally for IrcServer.

This commit is contained in:
Aaron Weiss 2016-01-03 05:35:00 -05:00
parent 8104894c28
commit a6cd761e51

View file

@ -6,7 +6,7 @@ use std::collections::HashMap;
use std::error::Error as StdError; use std::error::Error as StdError;
use std::io::{BufReader, BufWriter, Error, ErrorKind, Result}; use std::io::{BufReader, BufWriter, Error, ErrorKind, Result};
use std::path::Path; use std::path::Path;
use std::sync::{Mutex, RwLock}; use std::sync::{Arc, Mutex, RwLock};
use std::iter::Map; use std::iter::Map;
use client::conn::{Connection, NetStream}; use client::conn::{Connection, NetStream};
use client::data::{Command, Config, Message, Response, User}; use client::data::{Command, Config, Message, Response, User};
@ -34,6 +34,10 @@ pub trait Server<'a, T: IrcRead, U: IrcWrite> {
/// A thread-safe implementation of an IRC Server connection. /// A thread-safe implementation of an IRC Server connection.
pub struct IrcServer<T: IrcRead, U: IrcWrite> { pub struct IrcServer<T: IrcRead, U: IrcWrite> {
state: Arc<ServerState<T, U>>,
}
struct ServerState<T: IrcRead, U: IrcWrite> {
/// The thread-safe IRC connection. /// The thread-safe IRC connection.
conn: Connection<T, U>, conn: Connection<T, U>,
/// The configuration used with this connection. /// The configuration used with this connection.
@ -62,29 +66,34 @@ impl IrcServer<BufReader<NetStream>, BufWriter<NetStream>> {
} else { } else {
Connection::connect(config.server(), config.port()) Connection::connect(config.server(), config.port())
}); });
Ok(IrcServer { config: config, conn: conn, chanlists: Mutex::new(HashMap::new()), let state = ServerState {
alt_nick_index: RwLock::new(0) }) config: config,
conn: conn,
chanlists: Mutex::new(HashMap::new()),
alt_nick_index: RwLock::new(0),
};
Ok(IrcServer { state: Arc::new(state) })
} }
/// Reconnects to the IRC server. /// Reconnects to the IRC server.
pub fn reconnect(&self) -> Result<()> { pub fn reconnect(&self) -> Result<()> {
self.conn.reconnect(self.config().server(), self.config.port()) self.state.conn.reconnect(self.config().server(), self.config().port())
} }
} }
impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> { impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
fn config(&self) -> &Config { fn config(&self) -> &Config {
&self.config &self.state.config
} }
#[cfg(feature = "encode")] #[cfg(feature = "encode")]
fn send<M: Into<Message>>(&self, msg: M) -> Result<()> { fn send<M: Into<Message>>(&self, msg: M) -> Result<()> {
self.conn.send(msg, self.config.encoding()) self.state.conn.send(msg, self.config().encoding())
} }
#[cfg(not(feature = "encode"))] #[cfg(not(feature = "encode"))]
fn send<M: Into<Message>>(&self, msg: M) -> Result<()> where Self: Sized { fn send<M: Into<Message>>(&self, msg: M) -> Result<()> where Self: Sized {
self.conn.send(msg) self.state.conn.send(msg)
} }
fn iter(&'a self) -> ServerIterator<'a, T, U> { fn iter(&'a self) -> ServerIterator<'a, T, U> {
@ -97,7 +106,7 @@ impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
#[cfg(not(feature = "nochanlists"))] #[cfg(not(feature = "nochanlists"))]
fn list_users(&self, chan: &str) -> Option<Vec<User>> { fn list_users(&self, chan: &str) -> Option<Vec<User>> {
self.chanlists.lock().unwrap().get(&chan.to_owned()).cloned() self.state.chanlists.lock().unwrap().get(&chan.to_owned()).cloned()
} }
@ -110,13 +119,18 @@ impl<'a, T: IrcRead, U: IrcWrite> Server<'a, T, U> for IrcServer<T, U> {
impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> { impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
/// Creates an IRC server from the specified configuration, and any arbitrary Connection. /// Creates an IRC server from the specified configuration, and any arbitrary Connection.
pub fn from_connection(config: Config, conn: Connection<T, U>) -> IrcServer<T, U> { pub fn from_connection(config: Config, conn: Connection<T, U>) -> IrcServer<T, U> {
IrcServer { conn: conn, config: config, chanlists: Mutex::new(HashMap::new()), let state = ServerState {
alt_nick_index: RwLock::new(0) } conn: conn,
config: config,
chanlists: Mutex::new(HashMap::new()),
alt_nick_index: RwLock::new(0),
};
IrcServer { state: Arc::new(state) }
} }
/// Gets a reference to the IRC server's connection. /// Gets a reference to the IRC server's connection.
pub fn conn(&self) -> &Connection<T, U> { pub fn conn(&self) -> &Connection<T, U> {
&self.conn &self.state.conn
} }
/// Handles messages internally for basic bot functionality. /// Handles messages internally for basic bot functionality.
@ -129,34 +143,34 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
// TODO: replace with slice pattern matching when/if stable // TODO: replace with slice pattern matching when/if stable
let ref chan = msg.args[2]; let ref chan = msg.args[2];
for user in users.split(" ") { for user in users.split(" ") {
if match self.chanlists.lock().unwrap().get_mut(chan) { if match self.state.chanlists.lock().unwrap().get_mut(chan) {
Some(vec) => { vec.push(User::new(user)); false }, Some(vec) => { vec.push(User::new(user)); false },
None => true, None => true,
} { } {
self.chanlists.lock().unwrap().insert(chan.clone(), self.state.chanlists.lock().unwrap()
vec!(User::new(user))); .insert(chan.clone(), vec!(User::new(user)));
} }
} }
} }
} }
} }
} else if resp == Response::RPL_ENDOFMOTD || resp == Response::ERR_NOMOTD { } else if resp == Response::RPL_ENDOFMOTD || resp == Response::ERR_NOMOTD {
if self.config.nick_password() != "" { if self.config().nick_password() != "" {
self.send(NICKSERV( self.send(NICKSERV(
format!("IDENTIFY {}", self.config.nick_password()) format!("IDENTIFY {}", self.config().nick_password())
)).unwrap(); )).unwrap();
} }
if self.config.umodes() != "" { if self.config().umodes() != "" {
self.send(MODE(self.config.nickname().to_owned(), self.send(MODE(self.config().nickname().to_owned(),
self.config.umodes().to_owned(), None)).unwrap(); self.config().umodes().to_owned(), None)).unwrap();
} }
for chan in self.config.channels().into_iter() { for chan in self.config().channels().into_iter() {
self.send(JOIN(chan.to_owned(), None, None)).unwrap(); self.send(JOIN(chan.to_owned(), None, None)).unwrap();
} }
} else if resp == Response::ERR_NICKNAMEINUSE || } else if resp == Response::ERR_NICKNAMEINUSE ||
resp == Response::ERR_ERRONEOUSNICKNAME { resp == Response::ERR_ERRONEOUSNICKNAME {
let alt_nicks = self.config.get_alternate_nicknames(); let alt_nicks = self.config().get_alternate_nicknames();
let mut index = self.alt_nick_index.write().unwrap(); let mut index = self.state.alt_nick_index.write().unwrap();
if *index >= alt_nicks.len() { if *index >= alt_nicks.len() {
panic!("All specified nicknames were in use.") panic!("All specified nicknames were in use.")
} else { } else {
@ -174,7 +188,7 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
Some(ref suffix) => &suffix[..], Some(ref suffix) => &suffix[..],
None => &msg.args[0][..], None => &msg.args[0][..],
}; };
if let Some(vec) = self.chanlists.lock().unwrap().get_mut(&chan.to_string()) { if let Some(vec) = self.state.chanlists.lock().unwrap().get_mut(&chan.to_string()) {
if let Some(ref src) = msg.prefix { if let Some(ref src) = msg.prefix {
if let Some(i) = src.find('!') { if let Some(i) = src.find('!') {
if &msg.command[..] == "JOIN" { if &msg.command[..] == "JOIN" {
@ -192,7 +206,7 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
let ref mode = msg.args[1]; let ref mode = msg.args[1];
let ref user = msg.args[2]; let ref user = msg.args[2];
if cfg!(not(feature = "nochanlists")) { if cfg!(not(feature = "nochanlists")) {
if let Some(vec) = self.chanlists.lock().unwrap().get_mut(chan) { if let Some(vec) = self.state.chanlists.lock().unwrap().get_mut(chan) {
if let Some(n) = vec.iter().position(|x| &x.get_nickname() == user) { if let Some(n) = vec.iter().position(|x| &x.get_nickname() == user) {
vec[n].update_access_level(&mode); vec[n].update_access_level(&mode);
} }
@ -226,8 +240,8 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
}; };
match tokens[0] { match tokens[0] {
"FINGER" => self.send_ctcp_internal(resp, &format!("FINGER :{} ({})", "FINGER" => self.send_ctcp_internal(resp, &format!("FINGER :{} ({})",
self.config.real_name(), self.config().real_name(),
self.config.username())), self.config().username())),
"VERSION" => self.send_ctcp_internal(resp, "VERSION irc:git:Rust"), "VERSION" => self.send_ctcp_internal(resp, "VERSION irc:git:Rust"),
"SOURCE" => { "SOURCE" => {
self.send_ctcp_internal(resp, "SOURCE https://github.com/aatxe/irc"); self.send_ctcp_internal(resp, "SOURCE https://github.com/aatxe/irc");
@ -237,7 +251,7 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> {
"TIME" => self.send_ctcp_internal(resp, &format!("TIME :{}", "TIME" => self.send_ctcp_internal(resp, &format!("TIME :{}",
now().rfc822z())), now().rfc822z())),
"USERINFO" => self.send_ctcp_internal(resp, &format!("USERINFO :{}", "USERINFO" => self.send_ctcp_internal(resp, &format!("USERINFO :{}",
self.config.user_info())), self.config().user_info())),
_ => {} _ => {}
} }
}, },
@ -274,13 +288,13 @@ impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> ServerIterator<'a, T, U> {
/// Gets the next line from the connection. /// Gets the next line from the connection.
#[cfg(feature = "encode")] #[cfg(feature = "encode")]
fn get_next_line(&self) -> Result<String> { fn get_next_line(&self) -> Result<String> {
self.server.conn.recv(self.server.config.encoding()) self.server.state.conn.recv(self.server.config().encoding())
} }
/// Gets the next line from the connection. /// Gets the next line from the connection.
#[cfg(not(feature = "encode"))] #[cfg(not(feature = "encode"))]
fn get_next_line(&self) -> Result<String> { fn get_next_line(&self) -> Result<String> {
self.server.conn.recv() self.server.state.conn.recv()
} }
} }