2FA + Add/Remove Auth Providers (#652)

* 2FA + Add/Remove Auth Providers

* fix fmt issue
This commit is contained in:
Geometrically
2023-07-11 19:13:07 -07:00
committed by GitHub
parent 7fbb8838e7
commit 4bdf9bff3a
20 changed files with 1483 additions and 667 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -38,6 +38,8 @@ pub enum AuthenticationError {
InvalidAuthMethod,
#[error("GitHub Token from incorrect Client ID")]
InvalidClientId,
#[error("User email/account is already registered on Modrinth")]
DuplicateUser,
#[error("Invalid callback URL specified")]
Url,
}
@@ -56,6 +58,7 @@ impl actix_web::ResponseError for AuthenticationError {
AuthenticationError::InvalidClientId => StatusCode::UNAUTHORIZED,
AuthenticationError::Url => StatusCode::BAD_REQUEST,
AuthenticationError::FileHosting(..) => StatusCode::INTERNAL_SERVER_ERROR,
AuthenticationError::DuplicateUser => StatusCode::BAD_REQUEST,
}
}
@@ -73,6 +76,7 @@ impl actix_web::ResponseError for AuthenticationError {
AuthenticationError::InvalidClientId => "invalid_client_id",
AuthenticationError::Url => "url_error",
AuthenticationError::FileHosting(..) => "file_hosting",
AuthenticationError::DuplicateUser => "duplicate_user",
},
description: &self.to_string(),
})

View File

@@ -19,34 +19,38 @@ pub async fn get_user_from_headers<'a, E>(
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let headers = req.headers();
let token: Option<&HeaderValue> = headers.get(AUTHORIZATION);
// Fetch DB user record and minos user from headers
let (scopes, db_user) = get_user_record_from_bearer_token(
req,
token
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
.to_str()
.map_err(|_| AuthenticationError::InvalidCredentials)?,
executor,
redis,
session_queue,
)
.await?
.ok_or_else(|| AuthenticationError::InvalidCredentials)?;
let (scopes, db_user) =
get_user_record_from_bearer_token(req, None, executor, redis, session_queue)
.await?
.ok_or_else(|| AuthenticationError::InvalidCredentials)?;
let mut auth_providers = Vec::new();
if db_user.github_id.is_some() {
auth_providers.push(AuthProvider::GitHub)
}
if db_user.gitlab_id.is_some() {
auth_providers.push(AuthProvider::GitLab)
}
if db_user.discord_id.is_some() {
auth_providers.push(AuthProvider::Discord)
}
if db_user.google_id.is_some() {
auth_providers.push(AuthProvider::Google)
}
if db_user.microsoft_id.is_some() {
auth_providers.push(AuthProvider::Microsoft)
}
if db_user.steam_id.is_some() {
auth_providers.push(AuthProvider::Steam)
}
let user = User {
id: UserId::from(db_user.id),
github_id: db_user.github_id.map(|x| x as u64),
// discord_id: minos_user.discord_id,
// google_id: minos_user.google_id,
// microsoft_id: minos_user.microsoft_id,
// apple_id: minos_user.apple_id,
// gitlab_id: minos_user.gitlab_id,
username: db_user.username,
name: db_user.name,
email: db_user.email,
email_verified: Some(db_user.email_verified),
avatar_url: db_user.avatar_url,
bio: db_user.bio,
created: db_user.created,
@@ -58,6 +62,10 @@ where
payout_wallet_type: db_user.payout_wallet_type,
payout_address: db_user.payout_address,
}),
auth_providers: Some(auth_providers),
has_password: Some(db_user.password.is_some()),
has_totp: Some(db_user.totp_secret.is_some()),
github_id: None,
};
if let Some(required_scopes) = required_scopes {
@@ -73,7 +81,7 @@ where
pub async fn get_user_record_from_bearer_token<'a, 'b, E>(
req: &HttpRequest,
token: &str,
token: Option<&str>,
executor: E,
redis: &deadpool_redis::Pool,
session_queue: &AuthQueue,
@@ -81,6 +89,17 @@ pub async fn get_user_record_from_bearer_token<'a, 'b, E>(
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let token = if let Some(token) = token {
token
} else {
let headers = req.headers();
let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION);
token_val
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
.to_str()
.map_err(|_| AuthenticationError::InvalidCredentials)?
};
let possible_user = match token.split_once('_') {
Some(("mrp", _)) => {
let pat =

View File

@@ -0,0 +1,88 @@
use super::ids::*;
use crate::auth::flows::AuthProvider;
use crate::database::models::DatabaseError;
use chrono::{DateTime, Timelike, Utc};
use rand::distributions::Alphanumeric;
use rand::Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha20Rng;
use redis::cmd;
use serde::{Deserialize, Serialize};
const FLOWS_NAMESPACE: &str = "flows";
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Flow {
OAuth {
user_id: Option<UserId>,
url: String,
provider: AuthProvider,
},
Login2FA {
user_id: UserId,
},
Initialize2FA {
user_id: UserId,
secret: String,
},
ForgotPassword {
user_id: UserId,
},
ConfirmEmail {
user_id: UserId,
confirm_email: String,
},
}
impl Flow {
pub async fn insert(
&self,
expires: DateTime<Utc>,
redis: &deadpool_redis::Pool,
) -> Result<String, DatabaseError> {
let mut redis = redis.get().await?;
let flow = ChaCha20Rng::from_entropy()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect::<String>();
cmd("SET")
.arg(format!("{}:{}", FLOWS_NAMESPACE, flow))
.arg(serde_json::to_string(&self)?)
.arg("EX")
.arg(expires.second())
.query_async::<_, ()>(&mut redis)
.await?;
Ok(flow)
}
pub async fn get(
id: &str,
redis: &deadpool_redis::Pool,
) -> Result<Option<Flow>, DatabaseError> {
let mut redis = redis.get().await?;
let res = cmd("GET")
.arg(format!("{}:{}", FLOWS_NAMESPACE, id))
.query_async::<_, Option<String>>(&mut redis)
.await?;
Ok(res.and_then(|x| serde_json::from_str(&x).ok()))
}
pub async fn remove(
id: &str,
redis: &deadpool_redis::Pool,
) -> Result<Option<()>, DatabaseError> {
let mut redis = redis.get().await?;
let mut cmd = cmd("DEL");
cmd.arg(format!("{}:{}", FLOWS_NAMESPACE, id));
cmd.query_async::<_, ()>(&mut redis).await?;
Ok(Some(()))
}
}

View File

@@ -76,13 +76,6 @@ generate_ids!(
"SELECT EXISTS(SELECT 1 FROM team_members WHERE id=$1)",
TeamMemberId
);
generate_ids!(
pub generate_state_id,
StateId,
8,
"SELECT EXISTS(SELECT 1 FROM states WHERE id=$1)",
StateId
);
generate_ids!(
pub generate_pat_id,
PatId,
@@ -189,10 +182,6 @@ pub struct ReportTypeId(pub i32);
#[sqlx(transparent)]
pub struct FileId(pub i64);
#[derive(Copy, Clone, Debug, Type)]
#[sqlx(transparent)]
pub struct StateId(pub i64);
#[derive(Copy, Clone, Debug, Type, Deserialize, Serialize)]
#[sqlx(transparent)]
pub struct PatId(pub i64);

View File

@@ -1,6 +1,7 @@
use thiserror::Error;
pub mod categories;
pub mod flow_item;
pub mod ids;
pub mod notification_item;
pub mod pat_item;

View File

@@ -135,7 +135,8 @@ impl Report {
.await?;
if let Some(thread_id) = thread_id {
crate::database::models::Thread::remove_full(ThreadId(thread_id.id), transaction).await?;
crate::database::models::Thread::remove_full(ThreadId(thread_id.id), transaction)
.await?;
}
sqlx::query!(

View File

@@ -24,6 +24,8 @@ pub struct User {
pub microsoft_id: Option<String>,
pub password: Option<String>,
pub totp_secret: Option<String>,
pub username: String,
pub name: Option<String>,
pub email: Option<String>,
@@ -204,7 +206,7 @@ impl User {
created, role, badges,
balance, payout_wallet, payout_wallet_type, payout_address,
github_id, discord_id, gitlab_id, google_id, steam_id, microsoft_id,
email_verified, password
email_verified, password, totp_secret
FROM users
WHERE id = ANY($1) OR LOWER(username) = ANY($2)
",
@@ -240,6 +242,7 @@ impl User {
.map(|x| RecipientType::from_string(&x)),
payout_address: u.payout_address,
password: u.password,
totp_secret: u.totp_secret,
}))
})
.try_collect::<Vec<User>>()
@@ -272,6 +275,23 @@ impl User {
Ok(found_users)
}
pub async fn get_email<'a, E>(email: &str, exec: E) -> Result<Option<UserId>, sqlx::Error>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let user_pass = sqlx::query!(
"
SELECT id FROM users
WHERE email = $1
",
email
)
.fetch_optional(exec)
.await?;
Ok(user_pass.map(|x| UserId(x.id)))
}
pub async fn get_projects<'a, E>(
user_id: UserId,
exec: E,
@@ -298,6 +318,30 @@ impl User {
Ok(projects)
}
pub async fn get_backup_codes<'a, E>(
user_id: UserId,
exec: E,
) -> Result<Vec<String>, sqlx::Error>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
use futures::stream::TryStreamExt;
let codes = sqlx::query!(
"
SELECT code FROM user_backup_codes
WHERE user_id = $1
",
user_id as UserId,
)
.fetch_many(exec)
.try_filter_map(|e| async { Ok(e.right().map(|m| to_base62(m.code as u64))) })
.try_collect::<Vec<String>>()
.await?;
Ok(codes)
}
pub async fn clear_caches(
user_ids: &[(UserId, Option<String>)],
redis: &deadpool_redis::Pool,
@@ -486,6 +530,36 @@ impl User {
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM sessions
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM pats
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM user_backup_codes
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM users

View File

@@ -1,6 +1,5 @@
use super::ids::*;
use super::DatabaseError;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::projects::{FileType, VersionStatus};
use chrono::{DateTime, Utc};
use itertools::Itertools;
@@ -751,40 +750,6 @@ impl Version {
Ok(())
}
// TODO: Needs to be cached
pub async fn get_full_from_id_slug<'a, 'b, E>(
project_id_or_slug: &str,
slug: &str,
executor: E,
redis: &deadpool_redis::Pool,
) -> Result<Option<QueryVersion>, DatabaseError>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let project_id_opt = parse_base62(project_id_or_slug).ok().map(|x| x as i64);
let id_opt = parse_base62(slug).ok().map(|x| x as i64);
let id = sqlx::query!(
"
SELECT v.id FROM versions v
INNER JOIN mods m ON mod_id = m.id
WHERE (m.id = $1 OR m.slug = $2) AND (v.id = $3 OR v.version_number = $4)
ORDER BY date_published ASC
",
project_id_opt,
project_id_or_slug,
id_opt,
slug
)
.fetch_optional(executor)
.await?;
if let Some(version_id) = id {
Ok(Version::get(VersionId(version_id.id), executor, redis).await?)
} else {
Ok(None)
}
}
}
#[derive(Clone, Deserialize, Serialize)]

View File

@@ -130,34 +130,6 @@ async fn main() -> std::io::Result<()> {
}
});
// Deleting old authentication states from the database every 15 minutes
let pool_ref = pool.clone();
scheduler.run(std::time::Duration::from_secs(15 * 60), move || {
let pool_ref = pool_ref.clone();
// Use sqlx to delete records more than an hour old
info!("Deleting old records from temporary tables");
async move {
let states_result = sqlx::query!(
"
DELETE FROM states
WHERE expires < CURRENT_DATE
"
)
.execute(&pool_ref)
.await;
if let Err(e) = states_result {
warn!(
"Deleting old records from temporary table states failed: {:?}",
e
);
}
info!("Finished deleting old records from temporary tables");
}
});
// Changes statuses of scheduled projects/versions
let pool_ref = pool.clone();
// TODO: Clear cache when these are run

View File

@@ -1,9 +1,9 @@
use super::ids::Base62Id;
use crate::models::ids::{ProjectId, ReportId};
use crate::models::projects::ProjectStatus;
use crate::models::users::{User, UserId};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::models::ids::{ProjectId, ReportId};
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "Base62Id")]

View File

@@ -1,4 +1,5 @@
use super::ids::Base62Id;
use crate::auth::flows::AuthProvider;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
@@ -39,19 +40,21 @@ pub struct User {
pub id: UserId,
pub username: String,
pub name: Option<String>,
pub email: Option<String>,
pub avatar_url: Option<String>,
pub bio: Option<String>,
pub created: DateTime<Utc>,
pub role: Role,
pub badges: Badges,
pub payout_data: Option<UserPayoutData>,
pub auth_providers: Option<Vec<AuthProvider>>,
pub email: Option<String>,
pub email_verified: Option<bool>,
pub has_password: Option<bool>,
pub has_totp: Option<bool>,
// DEPRECATED. Always returns None
pub github_id: Option<u64>,
// pub discord_id: Option<u64>,
// pub google_id: Option<u128>,
// pub microsoft_id: Option<u64>,
// pub apple_id: Option<u64>,
// pub gitlab_id: Option<u64>,
}
#[derive(Serialize, Deserialize, Clone)]
@@ -138,18 +141,17 @@ impl From<DBUser> for User {
username: data.username,
name: data.name,
email: None,
email_verified: None,
avatar_url: data.avatar_url,
bio: data.bio,
created: data.created,
role: Role::from_string(&data.role),
badges: data.badges,
payout_data: None,
auth_providers: None,
has_password: None,
has_totp: None,
github_id: None,
// discord_id: None,
// google_id: None,
// microsoft_id: None,
// apple_id: None,
// gitlab_id: None,
}
}
}

View File

@@ -785,8 +785,8 @@ async fn project_create_inner(
project_id: Some(id),
report_id: None,
}
.insert(&mut *transaction)
.await?;
.insert(&mut *transaction)
.await?;
let response = crate::models::projects::Project {
id: project_id,

View File

@@ -584,8 +584,8 @@ pub async fn project_edit(
},
thread_id: project_item.thread_id,
}
.insert(&mut transaction)
.await?;
.insert(&mut transaction)
.await?;
sqlx::query!(
"

View File

@@ -153,8 +153,8 @@ pub async fn report_create(
project_id: None,
report_id: Some(report.id),
}
.insert(&mut transaction)
.await?;
.insert(&mut transaction)
.await?;
transaction.commit().await?;
@@ -395,8 +395,8 @@ pub async fn report_edit(
},
thread_id: report.thread_id,
}
.insert(&mut transaction)
.await?;
.insert(&mut transaction)
.await?;
sqlx::query!(
"

View File

@@ -45,12 +45,14 @@ pub async fn is_authorized_thread(
report_id as database::models::ids::ReportId,
user_id as database::models::ids::UserId,
)
.fetch_one(pool)
.await?
.exists;
.fetch_one(pool)
.await?
.exists;
report_exists.unwrap_or(false)
} else { false }
} else {
false
}
}
ThreadType::Project => {
if let Some(project_id) = thread.project_id {
@@ -379,12 +381,7 @@ pub async fn thread_send_message(
.await?;
let mod_notif = if let Some(project_id) = thread.project_id {
let project = database::models::Project::get_id(
project_id,
&**pool,
&redis,
)
.await?;
let project = database::models::Project::get_id(project_id, &**pool, &redis).await?;
if let Some(project) = project {
if project.inner.status != ProjectStatus::Processing && user.role.is_mod() {
@@ -393,7 +390,7 @@ pub async fn thread_send_message(
&**pool,
&redis,
)
.await?;
.await?;
NotificationBuilder {
body: NotificationBody::ModeratorMessage {
@@ -403,11 +400,11 @@ pub async fn thread_send_message(
report_id: None,
},
}
.insert_many(
members.into_iter().map(|x| x.user_id).collect(),
&mut transaction,
)
.await?;
.insert_many(
members.into_iter().map(|x| x.user_id).collect(),
&mut transaction,
)
.await?;
}
project.inner.status == ProjectStatus::Processing && !user.role.is_mod()
@@ -415,11 +412,7 @@ pub async fn thread_send_message(
!user.role.is_mod()
}
} else if let Some(report_id) = thread.report_id {
let report = database::models::report_item::Report::get(
report_id,
&**pool,
)
.await?;
let report = database::models::report_item::Report::get(report_id, &**pool).await?;
if let Some(report) = report {
if report.closed && !user.role.is_mod() {

View File

@@ -1,3 +1,4 @@
use crate::auth::flows::AuthProvider;
use crate::auth::{get_user_from_headers, AuthenticationError};
use crate::database::models::User;
use crate::file_hosting::FileHost;
@@ -193,6 +194,7 @@ pub struct EditUser {
#[validate]
pub payout_data: Option<Option<EditPayoutData>>,
pub password: Option<(Option<String>, Option<String>)>,
pub remove_auth_providers: Option<Vec<AuthProvider>>,
}
#[derive(Serialize, Deserialize, Validate)]
@@ -412,7 +414,7 @@ pub async fn user_edit(
));
}
if let Some(pass) = actual_user.password {
if let Some(pass) = actual_user.password.as_ref() {
let old_password = old_password.as_ref().ok_or_else(|| {
ApiError::CustomAuthentication(
"You must specify the old password to change your password!"
@@ -421,7 +423,7 @@ pub async fn user_edit(
})?;
let hasher = Argon2::default();
hasher.verify_password(old_password.as_bytes(), &PasswordHash::new(&pass)?)?;
hasher.verify_password(old_password.as_bytes(), &PasswordHash::new(pass)?)?;
}
let update_password = if let Some(new_password) = new_password {
@@ -483,6 +485,116 @@ pub async fn user_edit(
.await?;
}
if let Some(remove_auth_providers) = &new_user.remove_auth_providers {
if !scopes.contains(Scopes::USER_AUTH_WRITE) {
return Err(ApiError::Authentication(
AuthenticationError::InvalidCredentials,
));
}
let mut auth_providers = Vec::new();
if actual_user.github_id.is_some() {
auth_providers.push(AuthProvider::GitHub)
}
if actual_user.gitlab_id.is_some() {
auth_providers.push(AuthProvider::GitLab)
}
if actual_user.discord_id.is_some() {
auth_providers.push(AuthProvider::Discord)
}
if actual_user.google_id.is_some() {
auth_providers.push(AuthProvider::Google)
}
if actual_user.microsoft_id.is_some() {
auth_providers.push(AuthProvider::Microsoft)
}
if actual_user.steam_id.is_some() {
auth_providers.push(AuthProvider::Steam)
}
if auth_providers.len() <= remove_auth_providers.len()
&& actual_user.password.is_none()
{
return Err(ApiError::InvalidInput(
"You must have another authentication method added to this method!"
.to_string(),
));
}
if remove_auth_providers.contains(&AuthProvider::GitHub) {
sqlx::query!(
"
UPDATE users
SET github_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
if remove_auth_providers.contains(&AuthProvider::GitLab) {
sqlx::query!(
"
UPDATE users
SET gitlab_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
if remove_auth_providers.contains(&AuthProvider::Google) {
sqlx::query!(
"
UPDATE users
SET google_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
if remove_auth_providers.contains(&AuthProvider::Steam) {
sqlx::query!(
"
UPDATE users
SET steam_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
if remove_auth_providers.contains(&AuthProvider::Discord) {
sqlx::query!(
"
UPDATE users
SET discord_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
if remove_auth_providers.contains(&AuthProvider::Microsoft) {
sqlx::query!(
"
UPDATE users
SET microsoft_id = NULL
WHERE (id = $1)
",
id as crate::database::models::ids::UserId,
)
.execute(&mut *transaction)
.await?;
}
}
User::clear_caches(&[(id, Some(actual_user.username))], &redis).await?;
transaction.commit().await?;
Ok(HttpResponse::NoContent().body(""))

View File

@@ -4,6 +4,7 @@ use crate::auth::{
};
use crate::database;
use crate::models;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::pats::Scopes;
use crate::models::projects::{Dependency, FileType, VersionStatus, VersionType};
use crate::models::teams::Permissions;
@@ -165,8 +166,8 @@ pub async fn version_project_get(
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let id = info.into_inner();
let version_data =
database::models::Version::get_full_from_id_slug(&id.0, &id.1, &**pool, &redis).await?;
let result = database::models::Project::get(&id.0, &**pool, &redis).await?;
let user_option = get_user_from_headers(
&req,
@@ -179,9 +180,23 @@ pub async fn version_project_get(
.map(|x| x.1)
.ok();
if let Some(data) = version_data {
if is_authorized_version(&data.inner, &user_option, &pool).await? {
return Ok(HttpResponse::Ok().json(models::projects::Version::from(data)));
if let Some(project) = result {
if !is_authorized(&project.inner, &user_option, &pool).await? {
return Ok(HttpResponse::NotFound().body(""));
}
let versions =
database::models::Version::get_many(&project.versions, &**pool, &redis).await?;
let id_opt = parse_base62(&id.1).ok();
let version = versions
.into_iter()
.find(|x| Some(x.inner.id.0 as u64) == id_opt || x.inner.version_number == id.1);
if let Some(version) = version {
if is_authorized_version(&version.inner, &user_option, &pool).await? {
return Ok(HttpResponse::Ok().json(models::projects::Version::from(version)));
}
}
}