diff --git a/sqlx-data.json b/sqlx-data.json index 53c06fcc5..8194f9290 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -1387,6 +1387,27 @@ }, "query": "\n DELETE FROM notifications_actions\n WHERE notification_id = ANY($1)\n " }, + "281e3faffa65b51fadc93108ccc93d3d19934c8f26efb568f4794e4c6f16cefe": { + "describe": { + "columns": [ + { + "name": "thread_id", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + true + ], + "parameters": { + "Left": [ + "Int8Array", + "Int8" + ] + } + }, + "query": "\n SELECT m.thread_id FROM mods m\n INNER JOIN team_members tm ON tm.team_id = m.team_id AND user_id = $2\n WHERE m.thread_id = ANY($1)\n " + }, "28b9d32b6d200f34e86f890ce477be0b8717f7ad92dc9cffa56eda4b12ee0df2": { "describe": { "columns": [ @@ -6670,6 +6691,27 @@ }, "query": "\n SELECT id, version_number, version_type\n FROM versions\n WHERE mod_id = $1 AND status = ANY($2)\n ORDER BY date_published ASC\n " }, + "f44572d8ef6ff10fb27a72233792f48cbf825bc58ecf1bc84dcc0aeeba3c12a0": { + "describe": { + "columns": [ + { + "name": "thread_id", + "ordinal": 0, + "type_info": "Int8" + } + ], + "nullable": [ + true + ], + "parameters": { + "Left": [ + "Int8Array", + "Int8" + ] + } + }, + "query": "\n SELECT thread_id FROM reports\n WHERE thread_id = ANY($1) AND reporter = $2\n " + }, "f453b43772c4d2d9d09dc389eb95482cc75e7f0eaf9dc7ff48cf40f22f1497cc": { "describe": { "columns": [], diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index dbc52e5d2..66e4e2c99 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -184,7 +184,7 @@ pub struct NotificationId(pub i64); #[sqlx(transparent)] pub struct NotificationActionId(pub i32); -#[derive(Copy, Clone, Debug, Type, Deserialize)] +#[derive(Copy, Clone, Debug, Type, Deserialize, Eq, PartialEq)] #[sqlx(transparent)] pub struct ThreadId(pub i64); #[derive(Copy, Clone, Debug, Type, Deserialize)] diff --git a/src/database/models/thread_item.rs b/src/database/models/thread_item.rs index aea93e3ed..d3bfaf4d4 100644 --- a/src/database/models/thread_item.rs +++ b/src/database/models/thread_item.rs @@ -9,6 +9,7 @@ pub struct ThreadBuilder { pub members: Vec, } +#[derive(Clone)] pub struct Thread { pub id: ThreadId, pub type_: ThreadType, @@ -23,7 +24,7 @@ pub struct ThreadMessageBuilder { pub show_in_mod_inbox: bool, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct ThreadMessage { pub id: ThreadMessageId, pub thread_id: ThreadId, diff --git a/src/models/threads.rs b/src/models/threads.rs index 20943732d..2892a2cdd 100644 --- a/src/models/threads.rs +++ b/src/models/threads.rs @@ -36,6 +36,9 @@ pub struct ThreadMessage { pub enum MessageBody { Text { body: String, + #[serde(default)] + private: bool, + replying_to: Option, }, StatusChange { new_status: ProjectStatus, diff --git a/src/routes/v2/threads.rs b/src/routes/v2/threads.rs index b712d1cc2..47bf4a8a0 100644 --- a/src/routes/v2/threads.rs +++ b/src/routes/v2/threads.rs @@ -22,6 +22,7 @@ pub fn config(cfg: &mut web::ServiceConfig) { .service(thread_send_message), ); cfg.service(web::scope("message").service(message_delete)); + cfg.service(threads_get); } pub async fn is_authorized_thread( @@ -63,6 +64,193 @@ pub async fn is_authorized_thread( }) } +pub async fn filter_authorized_threads( + threads: Vec, + user: &User, + pool: &web::Data, +) -> Result, ApiError> { + let user_id: database::models::UserId = user.id.into(); + + let mut return_threads = Vec::new(); + let mut check_threads = Vec::new(); + + for thread in threads { + if user.role.is_mod() + || (thread.type_ == ThreadType::DirectMessage + && thread.members.contains(&user_id)) + { + return_threads.push(thread); + } else { + check_threads.push(thread); + } + } + + if !check_threads.is_empty() { + use futures::TryStreamExt; + + let project_thread_ids = check_threads + .iter() + .filter(|x| x.type_ == ThreadType::Project) + .map(|x| x.id.0) + .collect::>(); + + if !project_thread_ids.is_empty() { + sqlx::query!( + " + SELECT m.thread_id FROM mods m + INNER JOIN team_members tm ON tm.team_id = m.team_id AND user_id = $2 + WHERE m.thread_id = ANY($1) + ", + &*project_thread_ids, + user_id as database::models::ids::UserId, + ) + .fetch_many(&***pool) + .try_for_each(|e| { + if let Some(row) = e.right() { + check_threads.retain(|x| { + let bool = Some(x.id.0) == row.thread_id; + + if bool { + return_threads.push(x.clone()); + } + + !bool + }); + } + + futures::future::ready(Ok(())) + }) + .await?; + } + + let report_thread_ids = check_threads + .iter() + .filter(|x| x.type_ == ThreadType::Report) + .map(|x| x.id.0) + .collect::>(); + + if !report_thread_ids.is_empty() { + sqlx::query!( + " + SELECT thread_id FROM reports + WHERE thread_id = ANY($1) AND reporter = $2 + ", + &*report_thread_ids, + user_id as database::models::ids::UserId, + ) + .fetch_many(&***pool) + .try_for_each(|e| { + if let Some(row) = e.right() { + check_threads.retain(|x| { + let bool = Some(x.id.0) == row.thread_id; + + if bool { + return_threads.push(x.clone()); + } + + !bool + }); + } + + futures::future::ready(Ok(())) + }) + .await?; + } + } + + let mut user_ids = return_threads + .iter() + .flat_map(|x| x.members.clone()) + .collect::>(); + user_ids.append( + &mut return_threads + .iter() + .flat_map(|x| { + x.messages + .iter() + .filter_map(|x| x.author_id) + .collect::>() + }) + .collect::>(), + ); + + let users: Vec = + database::models::User::get_many(&user_ids, &***pool) + .await? + .into_iter() + .map(From::from) + .collect(); + + let mut final_threads = Vec::new(); + + for thread in return_threads { + let mut authors = thread.members.clone(); + + authors.append( + &mut thread + .messages + .iter() + .filter_map(|x| x.author_id) + .collect::>(), + ); + + final_threads.push(convert_thread( + thread, + users + .iter() + .filter(|x| authors.contains(&x.id.into())) + .cloned() + .collect(), + user, + )); + } + + Ok(final_threads) +} + +fn convert_thread( + data: database::models::Thread, + users: Vec, + user: &User, +) -> Thread { + let thread_type = data.type_; + + Thread { + id: data.id.into(), + type_: thread_type, + messages: data + .messages + .into_iter() + .filter(|x| { + if let MessageBody::Text { private, .. } = x.body { + !private || user.role.is_mod() + } else { + true + } + }) + .map(|x| ThreadMessage { + id: x.id.into(), + author_id: if users + .iter() + .find(|y| x.author_id == Some(y.id.into())) + .map(|x| x.role.is_mod() && !user.role.is_mod()) + .unwrap_or(false) + { + None + } else { + x.author_id.map(|x| x.into()) + }, + body: x.body, + created: x.created, + }) + .collect(), + members: users + .into_iter() + .filter(|x| !x.role.is_mod() || user.role.is_mod()) + .collect(), + } +} + #[get("{id}")] pub async fn thread_get( req: HttpRequest, @@ -75,52 +263,60 @@ pub async fn thread_get( let user = get_user_from_headers(req.headers(), &**pool).await?; - if let Some(data) = thread_data { + if let Some(mut data) = thread_data { if is_authorized_thread(&data, &user, &pool).await? { - let users: Vec = database::models::User::get_many( - &data + let authors = &mut data.members; + + authors.append( + &mut data .messages .iter() .filter_map(|x| x.author_id) .collect::>(), - &**pool, - ) - .await? - .into_iter() - .map(From::from) - .collect(); + ); - let thread_type = data.type_; - - return Ok(HttpResponse::Ok().json(Thread { - id: data.id.into(), - type_: thread_type, - messages: data - .messages + let users: Vec = + database::models::User::get_many(authors, &**pool) + .await? .into_iter() - .map(|x| ThreadMessage { - id: x.id.into(), - author_id: if users - .iter() - .find(|y| x.author_id == Some(y.id.into())) - .map(|x| x.role.is_mod()) - .unwrap_or(false) - { - None - } else { - x.author_id.map(|x| x.into()) - }, - body: x.body, - created: x.created, - }) - .collect(), - members: users, - })); + .map(From::from) + .collect(); + + return Ok( + HttpResponse::Ok().json(convert_thread(data, users, &user)) + ); } } Ok(HttpResponse::NotFound().body("")) } +#[derive(Deserialize)] +pub struct ThreadIds { + pub ids: String, +} + +#[get("threads")] +pub async fn threads_get( + req: HttpRequest, + web::Query(ids): web::Query, + pool: web::Data, +) -> Result { + let user = get_user_from_headers(req.headers(), &**pool).await?; + + let thread_ids: Vec = + serde_json::from_str::>(&ids.ids)? + .into_iter() + .map(|x| x.into()) + .collect(); + + let threads_data = + database::models::Thread::get_many(&thread_ids, &**pool).await?; + + let threads = filter_authorized_threads(threads_data, &user, &pool).await?; + + Ok(HttpResponse::Ok().json(threads)) +} + #[derive(Deserialize)] pub struct NewThreadMessage { pub body: MessageBody, @@ -135,15 +331,52 @@ pub async fn thread_send_message( ) -> Result { let user = get_user_from_headers(req.headers(), &**pool).await?; - if let MessageBody::Text { body } = &new_message.body { + let string: database::models::ThreadId = info.into_inner().0.into(); + + if let MessageBody::Text { + body, + replying_to, + private, + } = &new_message.body + { if body.len() > 65536 { return Err(ApiError::InvalidInput( "Input body is too long!".to_string(), )); } + + if *private && !user.role.is_mod() { + return Err(ApiError::InvalidInput( + "You are not allowed to send private messages!".to_string(), + )); + } + + if let Some(replying_to) = replying_to { + let thread_message = database::models::ThreadMessage::get( + (*replying_to).into(), + &**pool, + ) + .await?; + + if let Some(thread_message) = thread_message { + if thread_message.thread_id != string { + return Err(ApiError::InvalidInput( + "Message replied to is from another thread!" + .to_string(), + )); + } + } else { + return Err(ApiError::InvalidInput( + "Message replied to does not exist!".to_string(), + )); + } + } + } else { + return Err(ApiError::InvalidInput( + "You may only send text messages through this route!".to_string(), + )); } - let string: database::models::ThreadId = info.into_inner().0.into(); let result = database::models::Thread::get(string, &**pool).await?; if let Some(thread) = result { @@ -151,16 +384,6 @@ pub async fn thread_send_message( return Ok(HttpResponse::NotFound().body("")); } - match &new_message.body { - MessageBody::Text { .. } => {} - _ => { - return Err(ApiError::InvalidInput( - "You may only send text messages through this route!" - .to_string(), - )) - } - } - let mod_notif = if thread.type_ == ThreadType::Project { let status = sqlx::query!( "SELECT m.status FROM mods m WHERE thread_id = $1",