2FA + Add/Remove Auth Providers (#652)

* 2FA + Add/Remove Auth Providers

* fix fmt issue
This commit is contained in:
Geometrically
2023-07-11 19:13:07 -07:00
committed by GitHub
parent 7fbb8838e7
commit 4bdf9bff3a
20 changed files with 1483 additions and 667 deletions

View File

@@ -0,0 +1,88 @@
use super::ids::*;
use crate::auth::flows::AuthProvider;
use crate::database::models::DatabaseError;
use chrono::{DateTime, Timelike, Utc};
use rand::distributions::Alphanumeric;
use rand::Rng;
use rand_chacha::rand_core::SeedableRng;
use rand_chacha::ChaCha20Rng;
use redis::cmd;
use serde::{Deserialize, Serialize};
const FLOWS_NAMESPACE: &str = "flows";
#[derive(Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Flow {
OAuth {
user_id: Option<UserId>,
url: String,
provider: AuthProvider,
},
Login2FA {
user_id: UserId,
},
Initialize2FA {
user_id: UserId,
secret: String,
},
ForgotPassword {
user_id: UserId,
},
ConfirmEmail {
user_id: UserId,
confirm_email: String,
},
}
impl Flow {
pub async fn insert(
&self,
expires: DateTime<Utc>,
redis: &deadpool_redis::Pool,
) -> Result<String, DatabaseError> {
let mut redis = redis.get().await?;
let flow = ChaCha20Rng::from_entropy()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect::<String>();
cmd("SET")
.arg(format!("{}:{}", FLOWS_NAMESPACE, flow))
.arg(serde_json::to_string(&self)?)
.arg("EX")
.arg(expires.second())
.query_async::<_, ()>(&mut redis)
.await?;
Ok(flow)
}
pub async fn get(
id: &str,
redis: &deadpool_redis::Pool,
) -> Result<Option<Flow>, DatabaseError> {
let mut redis = redis.get().await?;
let res = cmd("GET")
.arg(format!("{}:{}", FLOWS_NAMESPACE, id))
.query_async::<_, Option<String>>(&mut redis)
.await?;
Ok(res.and_then(|x| serde_json::from_str(&x).ok()))
}
pub async fn remove(
id: &str,
redis: &deadpool_redis::Pool,
) -> Result<Option<()>, DatabaseError> {
let mut redis = redis.get().await?;
let mut cmd = cmd("DEL");
cmd.arg(format!("{}:{}", FLOWS_NAMESPACE, id));
cmd.query_async::<_, ()>(&mut redis).await?;
Ok(Some(()))
}
}

View File

@@ -76,13 +76,6 @@ generate_ids!(
"SELECT EXISTS(SELECT 1 FROM team_members WHERE id=$1)",
TeamMemberId
);
generate_ids!(
pub generate_state_id,
StateId,
8,
"SELECT EXISTS(SELECT 1 FROM states WHERE id=$1)",
StateId
);
generate_ids!(
pub generate_pat_id,
PatId,
@@ -189,10 +182,6 @@ pub struct ReportTypeId(pub i32);
#[sqlx(transparent)]
pub struct FileId(pub i64);
#[derive(Copy, Clone, Debug, Type)]
#[sqlx(transparent)]
pub struct StateId(pub i64);
#[derive(Copy, Clone, Debug, Type, Deserialize, Serialize)]
#[sqlx(transparent)]
pub struct PatId(pub i64);

View File

@@ -1,6 +1,7 @@
use thiserror::Error;
pub mod categories;
pub mod flow_item;
pub mod ids;
pub mod notification_item;
pub mod pat_item;

View File

@@ -135,7 +135,8 @@ impl Report {
.await?;
if let Some(thread_id) = thread_id {
crate::database::models::Thread::remove_full(ThreadId(thread_id.id), transaction).await?;
crate::database::models::Thread::remove_full(ThreadId(thread_id.id), transaction)
.await?;
}
sqlx::query!(

View File

@@ -24,6 +24,8 @@ pub struct User {
pub microsoft_id: Option<String>,
pub password: Option<String>,
pub totp_secret: Option<String>,
pub username: String,
pub name: Option<String>,
pub email: Option<String>,
@@ -204,7 +206,7 @@ impl User {
created, role, badges,
balance, payout_wallet, payout_wallet_type, payout_address,
github_id, discord_id, gitlab_id, google_id, steam_id, microsoft_id,
email_verified, password
email_verified, password, totp_secret
FROM users
WHERE id = ANY($1) OR LOWER(username) = ANY($2)
",
@@ -240,6 +242,7 @@ impl User {
.map(|x| RecipientType::from_string(&x)),
payout_address: u.payout_address,
password: u.password,
totp_secret: u.totp_secret,
}))
})
.try_collect::<Vec<User>>()
@@ -272,6 +275,23 @@ impl User {
Ok(found_users)
}
pub async fn get_email<'a, E>(email: &str, exec: E) -> Result<Option<UserId>, sqlx::Error>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let user_pass = sqlx::query!(
"
SELECT id FROM users
WHERE email = $1
",
email
)
.fetch_optional(exec)
.await?;
Ok(user_pass.map(|x| UserId(x.id)))
}
pub async fn get_projects<'a, E>(
user_id: UserId,
exec: E,
@@ -298,6 +318,30 @@ impl User {
Ok(projects)
}
pub async fn get_backup_codes<'a, E>(
user_id: UserId,
exec: E,
) -> Result<Vec<String>, sqlx::Error>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
use futures::stream::TryStreamExt;
let codes = sqlx::query!(
"
SELECT code FROM user_backup_codes
WHERE user_id = $1
",
user_id as UserId,
)
.fetch_many(exec)
.try_filter_map(|e| async { Ok(e.right().map(|m| to_base62(m.code as u64))) })
.try_collect::<Vec<String>>()
.await?;
Ok(codes)
}
pub async fn clear_caches(
user_ids: &[(UserId, Option<String>)],
redis: &deadpool_redis::Pool,
@@ -486,6 +530,36 @@ impl User {
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM sessions
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM pats
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM user_backup_codes
WHERE user_id = $1
",
id as UserId,
)
.execute(&mut *transaction)
.await?;
sqlx::query!(
"
DELETE FROM users

View File

@@ -1,6 +1,5 @@
use super::ids::*;
use super::DatabaseError;
use crate::models::ids::base62_impl::parse_base62;
use crate::models::projects::{FileType, VersionStatus};
use chrono::{DateTime, Utc};
use itertools::Itertools;
@@ -751,40 +750,6 @@ impl Version {
Ok(())
}
// TODO: Needs to be cached
pub async fn get_full_from_id_slug<'a, 'b, E>(
project_id_or_slug: &str,
slug: &str,
executor: E,
redis: &deadpool_redis::Pool,
) -> Result<Option<QueryVersion>, DatabaseError>
where
E: sqlx::Executor<'a, Database = sqlx::Postgres> + Copy,
{
let project_id_opt = parse_base62(project_id_or_slug).ok().map(|x| x as i64);
let id_opt = parse_base62(slug).ok().map(|x| x as i64);
let id = sqlx::query!(
"
SELECT v.id FROM versions v
INNER JOIN mods m ON mod_id = m.id
WHERE (m.id = $1 OR m.slug = $2) AND (v.id = $3 OR v.version_number = $4)
ORDER BY date_published ASC
",
project_id_opt,
project_id_or_slug,
id_opt,
slug
)
.fetch_optional(executor)
.await?;
if let Some(version_id) = id {
Ok(Version::get(VersionId(version_id.id), executor, redis).await?)
} else {
Ok(None)
}
}
}
#[derive(Clone, Deserialize, Serialize)]