diff --git a/examples/repeater.rs b/examples/repeater.rs index ea5b7a7..322d263 100644 --- a/examples/repeater.rs +++ b/examples/repeater.rs @@ -21,7 +21,7 @@ fn main() { client.for_each_incoming(|message| { print!("{}", message); if let Command::PRIVMSG(ref target, ref msg) = message.command { - if msg.starts_with(client.current_nickname()) { + if msg.starts_with(&*client.current_nickname()) { let tokens: Vec<_> = msg.split(' ').collect(); if tokens.len() > 2 { let n = tokens[0].len() + tokens[1].len() + 2; diff --git a/src/client/mod.rs b/src/client/mod.rs index 0a05f56..1e863d9 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -37,7 +37,7 @@ //! # client.identify().unwrap(); //! client.for_each_incoming(|irc_msg| { //! if let Command::PRIVMSG(channel, message) = irc_msg.command { -//! if message.contains(client.current_nickname()) { +//! if message.contains(&*client.current_nickname()) { //! client.send_privmsg(&channel, "beep boop").unwrap(); //! } //! } @@ -49,7 +49,7 @@ use std::ascii::AsciiExt; use std::collections::HashMap; use std::path::Path; -use std::sync::{Arc, Mutex, RwLock}; +use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard}; use std::thread; #[cfg(feature = "ctcp")] @@ -94,7 +94,7 @@ pub mod transport; /// # client.identify().unwrap(); /// client.stream().for_each_incoming(|irc_msg| { /// match irc_msg.command { -/// Command::PRIVMSG(channel, message) => if message.contains(client.current_nickname()) { +/// Command::PRIVMSG(channel, message) => if message.contains(&*client.current_nickname()) { /// client.send_privmsg(&channel, "beep boop").unwrap(); /// } /// _ => () @@ -159,7 +159,7 @@ pub trait Client { /// # client.identify().unwrap(); /// client.for_each_incoming(|irc_msg| { /// if let Command::PRIVMSG(channel, message) = irc_msg.command { - /// if message.contains(client.current_nickname()) { + /// if message.contains(&*client.current_nickname()) { /// client.send_privmsg(&channel, "beep boop").unwrap(); /// } /// } @@ -231,6 +231,9 @@ struct ClientState { chanlists: Mutex>>, /// A thread-safe index to track the current alternative nickname being used. alt_nick_index: RwLock, + /// The current nickname in use by this client, which may differ from the one implied by + /// `alt_nick_index`. This can be the case if, for example, a new `NICK` command is sent. + current_nickname: RwLock, /// A thread-safe internal IRC stream used for the reading API. incoming: Mutex>>, /// A thread-safe copy of the outgoing channel. @@ -289,26 +292,20 @@ impl ClientState { incoming: SplitStream, outgoing: UnboundedSender, config: Config, - ) -> ClientState { - ClientState { - config: config, + ) -> error::Result { + Ok(ClientState { chanlists: Mutex::new(HashMap::new()), alt_nick_index: RwLock::new(0), + current_nickname: RwLock::new(config.nickname()?.to_owned()), incoming: Mutex::new(Some(incoming)), outgoing: outgoing, - } + config: config, + }) } /// Gets the current nickname in use. - fn current_nickname(&self) -> &str { - let alt_nicks = self.config().alternate_nicknames(); - let index = self.alt_nick_index.read().unwrap(); - match *index { - 0 => self.config().nickname().expect( - "current_nickname should not be callable if nickname is not defined." - ), - i => alt_nicks[i - 1], - } + fn current_nickname(&self) -> RwLockReadGuard { + self.current_nickname.read().unwrap() } /// Handles sent messages internally for basic client functionality. @@ -332,6 +329,7 @@ impl ClientState { KICK(ref chan, ref user, _) => self.handle_part(user, chan), QUIT(_) => self.handle_quit(msg.source_nickname().unwrap_or("")), NICK(ref new_nick) => { + self.handle_current_nick_change(msg.source_nickname().unwrap_or(""), new_nick); self.handle_nick_change(msg.source_nickname().unwrap_or(""), new_nick) } ChannelMODE(ref chan, ref modes) => self.handle_mode(chan, modes), @@ -475,6 +473,14 @@ impl ClientState { } } + fn handle_current_nick_change(&self, old_nick: &str, new_nick: &str) { + if old_nick.is_empty() || new_nick.is_empty() || old_nick != &*self.current_nickname() { + return; + } + let mut nick = self.current_nickname.write().unwrap(); + *nick = new_nick.to_owned(); + } + #[cfg(feature = "nochanlists")] fn handle_nick_change(&self, _: &str, _: &str) {} @@ -715,7 +721,7 @@ impl IrcClient { }); Ok(IrcClient { - state: Arc::new(ClientState::new(rx_incoming.wait()?, tx_outgoing, config)), + state: Arc::new(ClientState::new(rx_incoming.wait()?, tx_outgoing, config)?), view: rx_view.wait()?, }) } @@ -774,7 +780,7 @@ impl IrcClient { /// Gets the current nickname in use. This may be the primary username set in the configuration, /// or it could be any of the alternative nicknames listed as well. As a result, this is the /// preferred way to refer to the client's nickname. - pub fn current_nickname(&self) -> &str { + pub fn current_nickname(&self) -> RwLockReadGuard { self.state.current_nickname() } @@ -820,7 +826,7 @@ impl<'a> Future for IrcClientFuture<'a> { let server = IrcClient { state: Arc::new(ClientState::new( stream, self.tx_outgoing.take().unwrap(), self.config.clone() - )), + )?), view: view, }; Ok(Async::Ready(PackedIrcClient(server, Box::new(outgoing_future)))) @@ -1048,6 +1054,22 @@ mod test { } } + #[test] + fn current_nickname_tracking() { + let value = ":test!test@test NICK :t3st\r\n\ + :t3st!test@test NICK :t35t\r\n"; + let client = IrcClient::from_config(Config { + mock_initial_value: Some(value.to_owned()), + ..test_config() + }).unwrap(); + + assert_eq!(&*client.current_nickname(), "test"); + client.for_each_incoming(|message| { + println!("{:?}", message); + }).unwrap(); + assert_eq!(&*client.current_nickname(), "t35t"); + } + #[test] fn send() { let client = IrcClient::from_config(test_config()).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index ec1c275..27c66f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ //! client.for_each_incoming(|irc_msg| { //! // irc_msg is a Message //! if let Command::PRIVMSG(channel, message) = irc_msg.command { -//! if message.contains(client.current_nickname()) { +//! if message.contains(&*client.current_nickname()) { //! // send_privmsg comes from ClientExt //! client.send_privmsg(&channel, "beep boop").unwrap(); //! }