cleanner admin auth

This commit is contained in:
catvayor 2024-06-17 09:11:01 +02:00
parent 0c1f29ac5c
commit 6dd6e46aab
3 changed files with 54 additions and 55 deletions

View file

@ -1,4 +1,6 @@
use rocket::{ use rocket::{
http::Status,
request::{self, FromRequest, Outcome, Request},
response::stream::{Event, EventStream}, response::stream::{Event, EventStream},
serde::json::Json, serde::json::Json,
tokio::{ tokio::{
@ -11,58 +13,62 @@ use rocket_dyn_templates::{context, Template};
use crate::global::*; use crate::global::*;
#[get("/?<tok>&<dbg>")] struct AdminAuth(String);
fn admin_page(
tok: Option<AdminKey>, #[rocket::async_trait]
dbg: Option<bool>, impl<'r> FromRequest<'r> for AdminAuth {
admin_key: &State<AdminKey>, type Error = ();
) -> Option<Template> {
if tok == Some(admin_key.to_string()) { async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
Some(Template::render( if let Some(query) = req.uri().query() {
"admin", for (key, value) in query.segments() {
context! { tok: tok.unwrap(), dbg: dbg.unwrap_or(false) }, if key == "tok" {
)) return if value == req.rocket().state::<Config>().unwrap().admin_token {
Outcome::Success(AdminAuth(value.to_string()))
} else { } else {
None //498 Token expired/invalid
Outcome::Error((Status::new(498), ()))
};
}
}
}
Outcome::Error((Status::Unauthorized, ()))
} }
} }
#[patch("/<id>?<tok>", data = "<nstate>")] #[get("/?<dbg>")]
fn admin_page(auth: AdminAuth, dbg: Option<bool>) -> Template {
Template::render("admin", context! { tok: auth.0, dbg: dbg.unwrap_or(false) })
}
#[patch("/<id>", data = "<nstate>")]
fn admin_set_state( fn admin_set_state(
tok: Option<AdminKey>, _auth: AdminAuth,
id: &str, id: &str,
admin_key: &State<AdminKey>,
nstate: Json<TrackedState>, nstate: Json<TrackedState>,
tracking: &State<Tracking>, tracking: &State<Tracking>,
evt_queue: &State<TrackingEventQueue>, evt_queue: &State<TrackingEventQueue>,
admin_queue: &State<AdminEventQueue>, admin_queue: &State<AdminEventQueue>,
) -> Option<()> { ) {
if tok == Some(admin_key.to_string()) { let tracked = tracking.get(&id.to_string()).unwrap(); //TODO: clean error
let tracked = tracking.get(&id.to_string()).unwrap();
tracked.write().unwrap().state = nstate.into_inner(); tracked.write().unwrap().state = nstate.into_inner();
state_update(&tracked.read().unwrap(), &evt_queue, &admin_queue); state_update(&tracked.read().unwrap(), &evt_queue, &admin_queue);
Some(())
} else {
None
}
} }
#[get("/events?<tok>")] #[get("/events")]
fn admin_events<'a>( fn admin_events<'a>(
tok: Option<AdminKey>, _auth: AdminAuth,
admin_key: &State<AdminKey>,
admin_queue: &'a State<AdminEventQueue>, admin_queue: &'a State<AdminEventQueue>,
tracking: &State<Tracking>, tracking: &State<Tracking>,
config: &State<Config>, config: &State<Config>,
mut shutdown: Shutdown, mut shutdown: Shutdown,
) -> Option<EventStream![Event + 'a]> { ) -> EventStream![Event + 'a] {
if tok == Some(admin_key.to_string()) {
let full_info: Vec<AdminTrackedInfo> = tracking let full_info: Vec<AdminTrackedInfo> = tracking
.iter() .iter()
.map(|(_, tracked)| admin_view(&tracked.read().unwrap())) .map(|(_, tracked)| admin_view(&tracked.read().unwrap()))
.collect(); .collect();
let timeout = Duration::from_millis(config.event_timeout); let timeout = Duration::from_millis(config.event_timeout);
Some(EventStream! { EventStream! {
yield Event::json(&full_info).event("full_update"); yield Event::json(&full_info).event("full_update");
let mut interval = time::interval(timeout); let mut interval = time::interval(timeout);
loop { loop {
@ -76,9 +82,6 @@ fn admin_events<'a>(
_ = &mut shutdown => break _ = &mut shutdown => break
} }
} }
})
} else {
None
} }
} }

View file

@ -171,7 +171,6 @@ impl From<QueuedEvent> for Event {
pub type Tracking = Arc<HashMap<String, RwLock<Tracked>>>; pub type Tracking = Arc<HashMap<String, RwLock<Tracked>>>;
pub type TrackingEventQueue = Arc<HashMap<String, RwLock<VecDeque<QueuedEvent>>>>; pub type TrackingEventQueue = Arc<HashMap<String, RwLock<VecDeque<QueuedEvent>>>>;
pub type AdminEventQueue = Arc<RwLock<VecDeque<QueuedEvent>>>; pub type AdminEventQueue = Arc<RwLock<VecDeque<QueuedEvent>>>;
pub type AdminKey = String;
#[derive(Serialize)] #[derive(Serialize)]
#[serde(crate = "rocket::serde")] #[serde(crate = "rocket::serde")]

View file

@ -98,15 +98,12 @@ async fn rocket() -> _ {
.collect(), .collect(),
); );
let admin_evt_queue: AdminEventQueue = Arc::new(RwLock::new(VecDeque::new())); let admin_evt_queue: AdminEventQueue = Arc::new(RwLock::new(VecDeque::new()));
let key: AdminKey = config.admin_token.clone();
println!("Admin token: {}", key);
let mut rocket = rocket let mut rocket = rocket
.attach(Template::fairing()) .attach(Template::fairing())
.attach(AdHoc::config::<Config>()) .attach(AdHoc::config::<Config>())
.manage(tracking.clone()) .manage(tracking.clone())
.manage(evt_queue.clone()) .manage(evt_queue.clone())
.manage(admin_evt_queue.clone()) .manage(admin_evt_queue.clone())
.manage(key)
.mount("/", routes![index]) .mount("/", routes![index])
.mount("/track", track::routes()) .mount("/track", track::routes())
.mount("/admin", admin::routes()); .mount("/admin", admin::routes());