Rewrote internal message handling in more modern Rust.

This commit is contained in:
Aaron Weiss 2016-01-18 00:05:03 -05:00
parent 8a2ce65e71
commit 81518bf244

View file

@ -13,7 +13,7 @@ use std::sync::mpsc::{Receiver, Sender, TryRecvError, channel};
use std::thread::{JoinHandle, spawn}; use std::thread::{JoinHandle, spawn};
use client::conn::{Connection, NetStream, Reconnect}; use client::conn::{Connection, NetStream, Reconnect};
use client::data::{Command, Config, Message, Response, User}; use client::data::{Command, Config, Message, Response, User};
use client::data::Command::{JOIN, NICK, NICKSERV, PING, PONG, MODE}; use client::data::Command::{JOIN, NICK, NICKSERV, PART, PING, PRIVMSG, MODE};
use client::data::kinds::{IrcRead, IrcWrite}; use client::data::kinds::{IrcRead, IrcWrite};
use client::server::utils::ServerExt; use client::server::utils::ServerExt;
use time::{Duration, Timespec, Tm, now}; use time::{Duration, Timespec, Tm, now};
@ -285,156 +285,140 @@ impl<T: IrcRead, U: IrcWrite> IrcServer<T, U> where Connection<T, U>: Reconnect
} }
#[cfg(not(feature = "encode"))] #[cfg(not(feature = "encode"))]
fn write<M: Into<Message>>(state: &Arc<ServerState<T, U>>, msg: M) -> Result<()> where Self: Sized { fn write<M: Into<Message>>(state: &Arc<ServerState<T, U>>, msg: M) -> Result<()> {
state.conn.send(msg) state.conn.send(msg)
} }
/// Handles messages internally for basic bot functionality. /// Returns a reference to the server state's channel lists.
fn handle_message(&self, msg: &Message) { fn chanlists(&self) -> &Mutex<HashMap<String, Vec<User>>> {
&self.state.chanlists
}
/// Handles messages internally for basic client functionality.
fn handle_message(&self, msg: &Message) -> Result<()> {
if let Some(resp) = Response::from_message(msg) { if let Some(resp) = Response::from_message(msg) {
if resp == Response::RPL_NAMREPLY { match resp {
if cfg!(not(feature = "nochanlists")) { Response::RPL_NAMREPLY => if cfg!(not(feature = "nochanlists")) {
if let Some(users) = msg.suffix.clone() { if let Some(users) = msg.suffix.clone() {
if msg.args.len() == 3 { if msg.args.len() == 3 {
// TODO: replace with slice pattern matching when/if stable
let ref chan = msg.args[2]; let ref chan = msg.args[2];
for user in users.split(" ") { for user in users.split(" ") {
if match self.state.chanlists.lock().unwrap().get_mut(chan) { let mut chanlists = self.state.chanlists.lock().unwrap();
Some(vec) => { vec.push(User::new(user)); false }, chanlists.entry(chan.clone()).or_insert(Vec::new())
None => true, .push(User::new(user))
} {
self.state.chanlists.lock().unwrap()
.insert(chan.clone(), vec!(User::new(user)));
}
} }
} }
} }
} },
} else if resp == Response::RPL_ENDOFMOTD || resp == Response::ERR_NOMOTD { Response::RPL_ENDOFMOTD | Response::ERR_NOMOTD => { // On connection behavior.
if self.config().nick_password() != "" { if self.config().nick_password() != "" {
self.send(NICKSERV( try!(self.send(NICKSERV(
format!("IDENTIFY {}", self.config().nick_password()) format!("IDENTIFY {}", self.config().nick_password())
)).unwrap(); )))
}
if self.config().umodes() != "" {
self.send(MODE(self.config().nickname().to_owned(),
self.config().umodes().to_owned(), None)).unwrap();
}
for chan in self.config().channels().into_iter() {
self.send(JOIN(chan.to_owned(), None, None)).unwrap();
}
} else if resp == Response::ERR_NICKNAMEINUSE ||
resp == Response::ERR_ERRONEOUSNICKNAME {
let alt_nicks = self.config().get_alternate_nicknames();
let mut index = self.state.alt_nick_index.write().unwrap();
if *index >= alt_nicks.len() {
panic!("All specified nicknames were in use.")
} else {
self.send(NICK(alt_nicks[*index].to_owned())).unwrap();
*index += 1;
}
}
return
}
if &msg.command[..] == "PING" {
self.send(PONG(msg.suffix.as_ref().unwrap().to_owned(), None)).unwrap();
} else if &msg.command[..] == "PONG" {
if let Ok(data) = msg.suffix.as_ref().unwrap().parse() {
if let Some(timespec) = self.state.last_ping_data() {
if timespec.sec == data {
let mut ping_data = self.state.last_ping_data.lock().unwrap();
ping_data.take();
} }
} if self.config().umodes() != "" {
try!(self.send_mode(self.config().nickname(), self.config().umodes(), ""))
}
for chan in self.config().channels().into_iter() {
try!(self.send_join(chan))
}
},
Response::ERR_NICKNAMEINUSE | Response::ERR_ERRONEOUSNICKNAME => {
let alt_nicks = self.config().get_alternate_nicknames();
let mut index = self.state.alt_nick_index.write().unwrap();
if *index >= alt_nicks.len() {
panic!("All specified nicknames were in use.")
} else {
try!(self.send(NICK(alt_nicks[*index].to_owned())));
*index += 1;
}
},
_ => ()
} }
} else if cfg!(not(feature = "nochanlists")) && Ok(())
(&msg.command[..] == "JOIN" || &msg.command[..] == "PART") { } else if let Ok(cmd) = msg.into() {
let chan = match msg.suffix { match cmd {
Some(ref suffix) => &suffix[..], PING(data, _) => try!(self.send_pong(&data)),
None => &msg.args[0][..], JOIN(chan, _, _) => {
}; if let Some(vec) = self.chanlists().lock().unwrap().get_mut(&chan.to_owned()) {
if let Some(vec) = self.state.chanlists.lock().unwrap().get_mut(&chan.to_string()) { if let Some(src) = msg.get_source_nickname() {
if let Some(ref src) = msg.prefix { vec.push(User::new(src))
if let Some(i) = src.find('!') { }
if &msg.command[..] == "JOIN" { }
vec.push(User::new(&src[..i])); },
} else { PART(chan, _) => {
if let Some(n) = vec.iter().position(|x| x.get_nickname() == &src[..i]) { if let Some(vec) = self.chanlists().lock().unwrap().get_mut(&chan.to_owned()) {
if let Some(src) = msg.get_source_nickname() {
if let Some(n) = vec.iter().position(|x| x.get_nickname() == src) {
vec.swap_remove(n); vec.swap_remove(n);
} }
} }
} }
} },
} MODE(chan, mode, Some(user)) => if cfg!(not(feature = "nochanlists")) {
} else if let ("MODE", 3) = (&msg.command[..], msg.args.len()) { if let Some(vec) = self.chanlists().lock().unwrap().get_mut(&chan) {
let ref chan = msg.args[0]; // TODO: replace with slice pattern matching when/if stable if let Some(n) = vec.iter().position(|x| x.get_nickname() == user) {
let ref mode = msg.args[1]; vec[n].update_access_level(&mode)
let ref user = msg.args[2]; }
if cfg!(not(feature = "nochanlists")) {
if let Some(vec) = self.state.chanlists.lock().unwrap().get_mut(chan) {
if let Some(n) = vec.iter().position(|x| &x.get_nickname() == user) {
vec[n].update_access_level(&mode);
} }
} },
PRIVMSG(target, body) => if body.starts_with("\u{001}") {
let tokens: Vec<_> = {
let end = if body.ends_with("\u{001}") {
body.len() - 1
} else {
body.len()
};
body[1..end].split(" ").collect()
};
if target.starts_with("#") {
try!(self.handle_ctcp(&target, tokens))
} else if let Some(user) = msg.get_source_nickname() {
try!(self.handle_ctcp(user, tokens))
}
},
_ => ()
} }
Ok(())
} else { } else {
self.handle_ctcp(msg); Ok(())
} }
} }
/// Handles CTCP requests if the CTCP feature is enabled. /// Handles CTCP requests if the CTCP feature is enabled.
#[cfg(feature = "ctcp")] #[cfg(feature = "ctcp")]
fn handle_ctcp(&self, msg: &Message) { fn handle_ctcp(&self, resp: &str, tokens: Vec<&str>) -> Result<()> {
let source = match msg.prefix { match tokens[0] {
Some(ref source) => source.find('!').map_or(&source[..], |i| &source[..i]), "FINGER" => self.send_ctcp_internal(resp, &format!(
None => "", "FINGER :{} ({})", self.config().real_name(), self.config().username()
}; )),
if let ("PRIVMSG", 1) = (&msg.command[..], msg.args.len()) { "VERSION" => self.send_ctcp_internal(resp, "VERSION irc:git:Rust"),
// TODO: replace with slice pattern matching when/if stable "SOURCE" => {
let ref target = msg.args[0]; try!(self.send_ctcp_internal(resp, "SOURCE https://github.com/aatxe/irc"));
let resp = if target.starts_with("#") { &target[..] } else { source }; self.send_ctcp_internal(resp, "SOURCE")
match msg.suffix { },
Some(ref msg) if msg.starts_with("\u{001}") => { "PING" => self.send_ctcp_internal(resp, &format!("PING {}", tokens[1])),
let tokens: Vec<_> = { "TIME" => self.send_ctcp_internal(resp, &format!(
let end = if msg.ends_with("\u{001}") { "TIME :{}", now().rfc822z()
msg.len() - 1 )),
} else { "USERINFO" => self.send_ctcp_internal(resp, &format!(
msg.len() "USERINFO :{}", self.config().user_info()
}; )),
msg[1..end].split(" ").collect() _ => Ok(())
};
match tokens[0] {
"FINGER" => self.send_ctcp_internal(resp, &format!(
"FINGER :{} ({})", self.config().real_name(), self.config().username()
)),
"VERSION" => self.send_ctcp_internal(resp, "VERSION irc:git:Rust"),
"SOURCE" => {
self.send_ctcp_internal(resp, "SOURCE https://github.com/aatxe/irc");
self.send_ctcp_internal(resp, "SOURCE");
},
"PING" => self.send_ctcp_internal(resp, &format!("PING {}", tokens[1])),
"TIME" => self.send_ctcp_internal(resp, &format!(
"TIME :{}", now().rfc822z()
)),
"USERINFO" => self.send_ctcp_internal(resp, &format!(
"USERINFO :{}", self.config().user_info()
)),
_ => {}
}
},
_ => {}
}
} }
} }
/// Sends a CTCP-escaped message. /// Sends a CTCP-escaped message.
#[cfg(feature = "ctcp")] #[cfg(feature = "ctcp")]
fn send_ctcp_internal(&self, target: &str, msg: &str) { fn send_ctcp_internal(&self, target: &str, msg: &str) -> Result<()> {
self.send(Command::NOTICE(target.to_owned(), format!("\u{001}{}\u{001}", msg))).unwrap(); self.send_notice(target, &format!("\u{001}{}\u{001}", msg))
} }
/// Handles CTCP requests if the CTCP feature is enabled. /// Handles CTCP requests if the CTCP feature is enabled.
#[cfg(not(feature = "ctcp"))] fn handle_ctcp(&self, _: &Message) {} #[cfg(not(feature = "ctcp"))]
fn handle_ctcp(&self, _: &str, _: Vec<&str>) -> Result<()> {
Ok(())
}
} }
impl<T: IrcRead, U: IrcWrite + Clone> IrcServer<T, U> where Connection<T, U>: Reconnect { impl<T: IrcRead, U: IrcWrite + Clone> IrcServer<T, U> where Connection<T, U>: Reconnect {
@ -491,7 +475,10 @@ impl<'a, T: IrcRead + 'a, U: IrcWrite + 'a> Iterator for ServerIterator<'a, T, U
match self.get_next_line() { match self.get_next_line() {
Ok(msg) => match msg.parse() { Ok(msg) => match msg.parse() {
Ok(res) => { Ok(res) => {
self.server.handle_message(&res); match self.server.handle_message(&res) {
Ok(()) => (),
Err(err) => return Some(Err(err))
}
self.server.state.action_taken(); self.server.state.action_taken();
return Some(Ok(res)) return Some(Ok(res))
}, },