OAuth 2.0 Authorization Server [MOD-559] (#733)

* WIP end-of-day push

* Authorize endpoint, accept endpoints, DB stuff for oauth clients, their redirects, and client authorizations

* OAuth Client create route

* Get user clients

* Client delete

* Edit oauth client

* Include redirects in edit client route

* Database stuff for tokens

* Reorg oauth stuff out of auth/flows and into its own module

* Impl OAuth get access token endpoint

* Accept oauth access tokens as auth and update through AuthQueue

* User OAuth authorization management routes

* Forgot to actually add the routes lol

* Bit o cleanup

* Happy path test for OAuth and minor fixes for things it found

* Add dummy data oauth client (and detect/handle dummy data version changes)

* More tests

* Another test

* More tests and reject endpoint

* Test oauth client and authorization management routes

* cargo sqlx prepare

* dead code warning

* Auto clippy fixes

* Uri refactoring

* minor name improvement

* Don't compile-time check the test sqlx queries

* Trying to fix db concurrency problem to get tests to pass

* Try fix from test PR

* Fixes for updated sqlx

* Prevent restricted scopes from being requested or issued

* Get OAuth client(s)

* Remove joined oauth client info from authorization returns

* Add default conversion to OAuthError::error so we can use ?

* Rework routes

* Consolidate scopes into SESSION_ACCESS

* Cargo sqlx prepare

* Parse to OAuthClientId automatically through serde and actix

* Cargo clippy

* Remove validation requiring 1 redirect URI on oauth client creation

* Use serde(flatten) on OAuthClientCreationResult
This commit is contained in:
Jackson Kruger
2023-10-30 11:14:38 -05:00
committed by GitHub
parent 8803e11945
commit 6cfd4637db
54 changed files with 3658 additions and 135 deletions

View File

@@ -1,7 +1,8 @@
use super::ids::*;
use crate::auth::flows::AuthProvider;
use crate::auth::oauth::uris::OAuthRedirectUris;
use crate::database::models::DatabaseError;
use crate::database::redis::RedisPool;
use crate::{auth::flows::AuthProvider, models::pats::Scopes};
use chrono::Duration;
use rand::distributions::Alphanumeric;
use rand::Rng;
@@ -34,6 +35,21 @@ pub enum Flow {
confirm_email: String,
},
MinecraftAuth,
InitOAuthAppApproval {
user_id: UserId,
client_id: OAuthClientId,
existing_authorization_id: Option<OAuthClientAuthorizationId>,
scopes: Scopes,
redirect_uris: OAuthRedirectUris,
state: Option<String>,
},
OAuthAuthorizationCodeSupplied {
user_id: UserId,
client_id: OAuthClientId,
authorization_id: OAuthClientAuthorizationId,
scopes: Scopes,
original_redirect_uri: Option<String>, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
},
}
impl Flow {
@@ -58,6 +74,22 @@ impl Flow {
redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await
}
/// Gets the flow and removes it from the cache, but only removes if the flow was present and the predicate returned true
/// The predicate should validate that the flow being removed is the correct one, as a security measure
pub async fn take_if(
id: &str,
predicate: impl FnOnce(&Flow) -> bool,
redis: &RedisPool,
) -> Result<Option<Flow>, DatabaseError> {
let flow = Self::get(id, redis).await?;
if let Some(flow) = flow.as_ref() {
if predicate(flow) {
Self::remove(id, redis).await?;
}
}
Ok(flow)
}
pub async fn remove(id: &str, redis: &RedisPool) -> Result<Option<()>, DatabaseError> {
redis.delete(FLOWS_NAMESPACE, id).await?;
Ok(Some(()))

View File

@@ -152,6 +152,38 @@ generate_ids!(
ImageId
);
generate_ids!(
pub generate_oauth_client_authorization_id,
OAuthClientAuthorizationId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)",
OAuthClientAuthorizationId
);
generate_ids!(
pub generate_oauth_client_id,
OAuthClientId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)",
OAuthClientId
);
generate_ids!(
pub generate_oauth_redirect_id,
OAuthRedirectUriId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)",
OAuthRedirectUriId
);
generate_ids!(
pub generate_oauth_access_token_id,
OAuthAccessTokenId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)",
OAuthAccessTokenId
);
#[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)]
#[sqlx(transparent)]
pub struct UserId(pub i64);
@@ -238,6 +270,22 @@ pub struct SessionId(pub i64);
#[sqlx(transparent)]
pub struct ImageId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthClientId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthClientAuthorizationId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthRedirectUriId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthAccessTokenId(pub i64);
use crate::models::ids;
impl From<ids::ProjectId> for ProjectId {
@@ -360,3 +408,23 @@ impl From<PatId> for ids::PatId {
ids::PatId(id.0 as u64)
}
}
impl From<OAuthClientId> for ids::OAuthClientId {
fn from(id: OAuthClientId) -> Self {
ids::OAuthClientId(id.0 as u64)
}
}
impl From<ids::OAuthClientId> for OAuthClientId {
fn from(id: ids::OAuthClientId) -> Self {
Self(id.0 as i64)
}
}
impl From<OAuthRedirectUriId> for ids::OAuthRedirectUriId {
fn from(id: OAuthRedirectUriId) -> Self {
ids::OAuthRedirectUriId(id.0 as u64)
}
}
impl From<OAuthClientAuthorizationId> for ids::OAuthClientAuthorizationId {
fn from(id: OAuthClientAuthorizationId) -> Self {
ids::OAuthClientAuthorizationId(id.0 as u64)
}
}

View File

@@ -6,6 +6,9 @@ pub mod flow_item;
pub mod ids;
pub mod image_item;
pub mod notification_item;
pub mod oauth_client_authorization_item;
pub mod oauth_client_item;
pub mod oauth_token_item;
pub mod organization_item;
pub mod pat_item;
pub mod project_item;
@@ -19,6 +22,7 @@ pub mod version_item;
pub use collection_item::Collection;
pub use ids::*;
pub use image_item::Image;
pub use oauth_client_item::OAuthClient;
pub use organization_item::Organization;
pub use project_item::Project;
pub use team_item::Team;

View File

@@ -0,0 +1,126 @@
use chrono::{DateTime, Utc};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::models::pats::Scopes;
use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId};
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthClientAuthorization {
pub id: OAuthClientAuthorizationId,
pub client_id: OAuthClientId,
pub user_id: UserId,
pub scopes: Scopes,
pub created: DateTime<Utc>,
}
struct AuthorizationQueryResult {
id: i64,
client_id: i64,
user_id: i64,
scopes: i64,
created: DateTime<Utc>,
}
impl From<AuthorizationQueryResult> for OAuthClientAuthorization {
fn from(value: AuthorizationQueryResult) -> Self {
OAuthClientAuthorization {
id: OAuthClientAuthorizationId(value.id),
client_id: OAuthClientId(value.client_id),
user_id: UserId(value.user_id),
scopes: Scopes::from_postgres(value.scopes),
created: value.created,
}
}
}
impl OAuthClientAuthorization {
pub async fn get(
client_id: OAuthClientId,
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthClientAuthorization>, DatabaseError> {
let value = sqlx::query_as!(
AuthorizationQueryResult,
"
SELECT id, client_id, user_id, scopes, created
FROM oauth_client_authorizations
WHERE client_id=$1 AND user_id=$2
",
client_id.0,
user_id.0,
)
.fetch_optional(exec)
.await?;
Ok(value.map(|r| r.into()))
}
pub async fn get_all_for_user(
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClientAuthorization>, DatabaseError> {
let results = sqlx::query_as!(
AuthorizationQueryResult,
"
SELECT id, client_id, user_id, scopes, created
FROM oauth_client_authorizations
WHERE user_id=$1
",
user_id.0
)
.fetch_all(exec)
.await?;
Ok(results.into_iter().map(|r| r.into()).collect_vec())
}
pub async fn upsert(
id: OAuthClientAuthorizationId,
client_id: OAuthClientId,
user_id: UserId,
scopes: Scopes,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
INSERT INTO oauth_client_authorizations (
id, client_id, user_id, scopes
)
VALUES (
$1, $2, $3, $4
)
ON CONFLICT (id)
DO UPDATE SET scopes = EXCLUDED.scopes
",
id.0,
client_id.0,
user_id.0,
scopes.bits() as i64,
)
.execute(&mut **transaction)
.await?;
Ok(())
}
pub async fn remove(
client_id: OAuthClientId,
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
DELETE FROM oauth_client_authorizations
WHERE client_id=$1 AND user_id=$2
",
client_id.0,
user_id.0
)
.execute(exec)
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,245 @@
use chrono::{DateTime, Utc};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId};
use crate::models::pats::Scopes;
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthRedirectUri {
pub id: OAuthRedirectUriId,
pub client_id: OAuthClientId,
pub uri: String,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthClient {
pub id: OAuthClientId,
pub name: String,
pub icon_url: Option<String>,
pub max_scopes: Scopes,
pub secret_hash: String,
pub redirect_uris: Vec<OAuthRedirectUri>,
pub created: DateTime<Utc>,
pub created_by: UserId,
}
struct ClientQueryResult {
id: i64,
name: String,
icon_url: Option<String>,
max_scopes: i64,
secret_hash: String,
created: DateTime<Utc>,
created_by: i64,
uri_ids: Option<Vec<i64>>,
uri_vals: Option<Vec<String>>,
}
macro_rules! select_clients_with_predicate {
($predicate:tt, $param:ident) => {
// The columns in this query have nullability type hints, because for some reason
// the combination of the JOIN and filter using ANY makes sqlx think all columns are nullable
// https://docs.rs/sqlx/latest/sqlx/macro.query.html#force-nullable
sqlx::query_as!(
ClientQueryResult,
r#"
SELECT
clients.id as "id!",
clients.name as "name!",
clients.icon_url as "icon_url?",
clients.max_scopes as "max_scopes!",
clients.secret_hash as "secret_hash!",
clients.created as "created!",
clients.created_by as "created_by!",
uris.uri_ids as "uri_ids?",
uris.uri_vals as "uri_vals?"
FROM oauth_clients clients
LEFT JOIN (
SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals
FROM oauth_client_redirect_uris
GROUP BY client_id
) uris ON clients.id = uris.client_id
"#
+ $predicate,
$param
)
};
}
impl OAuthClient {
pub async fn get(
id: OAuthClientId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthClient>, DatabaseError> {
Ok(Self::get_many(&[id], exec).await?.into_iter().next())
}
pub async fn get_many(
ids: &[OAuthClientId],
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClient>, DatabaseError> {
let ids = ids.iter().map(|id| id.0).collect_vec();
let ids_ref: &[i64] = &ids;
let results =
select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref)
.fetch_all(exec)
.await?;
Ok(results.into_iter().map(|r| r.into()).collect_vec())
}
pub async fn get_all_user_clients(
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClient>, DatabaseError> {
let user_id_param = user_id.0;
let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param)
.fetch_all(exec)
.await?;
Ok(clients.into_iter().map(|r| r.into()).collect())
}
pub async fn remove(
id: OAuthClientId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
// Cascades to oauth_client_redirect_uris, oauth_client_authorizations
sqlx::query!(
"
DELETE FROM oauth_clients
WHERE id = $1
",
id.0
)
.execute(exec)
.await?;
Ok(())
}
pub async fn insert(
&self,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
INSERT INTO oauth_clients (
id, name, icon_url, max_scopes, secret_hash, created_by
)
VALUES (
$1, $2, $3, $4, $5, $6
)
",
self.id.0,
self.name,
self.icon_url,
self.max_scopes.to_postgres(),
self.secret_hash,
self.created_by.0
)
.execute(&mut **transaction)
.await?;
Self::insert_redirect_uris(&self.redirect_uris, &mut **transaction).await?;
Ok(())
}
pub async fn update_editable_fields(
&self,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
UPDATE oauth_clients
SET name = $1, icon_url = $2, max_scopes = $3
WHERE (id = $4)
",
self.name,
self.icon_url,
self.max_scopes.to_postgres(),
self.id.0,
)
.execute(exec)
.await?;
Ok(())
}
pub async fn remove_redirect_uris(
ids: impl IntoIterator<Item = OAuthRedirectUriId>,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let ids = ids.into_iter().map(|id| id.0).collect_vec();
sqlx::query!(
"
DELETE FROM oauth_client_redirect_uris
WHERE id IN
(SELECT * FROM UNNEST($1::bigint[]))
",
&ids[..]
)
.execute(exec)
.await?;
Ok(())
}
pub async fn insert_redirect_uris(
uris: &[OAuthRedirectUri],
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = uris
.iter()
.map(|r| (r.id.0, r.client_id.0, r.uri.clone()))
.multiunzip();
sqlx::query!(
"
INSERT INTO oauth_client_redirect_uris (id, client_id, uri)
SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])
",
&ids[..],
&client_ids[..],
&uris[..],
)
.execute(exec)
.await?;
Ok(())
}
pub fn hash_secret(secret: &str) -> String {
format!("{:x}", sha2::Sha512::digest(secret.as_bytes()))
}
}
impl From<ClientQueryResult> for OAuthClient {
fn from(r: ClientQueryResult) -> Self {
let redirects = if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) {
ids.iter()
.zip(uris.iter())
.map(|(id, uri)| OAuthRedirectUri {
id: OAuthRedirectUriId(*id),
client_id: OAuthClientId(r.id),
uri: uri.to_string(),
})
.collect()
} else {
vec![]
};
OAuthClient {
id: OAuthClientId(r.id),
name: r.name,
icon_url: r.icon_url,
max_scopes: Scopes::from_postgres(r.max_scopes),
secret_hash: r.secret_hash,
redirect_uris: redirects,
created: r.created,
created_by: UserId(r.created_by),
}
}
}

View File

@@ -0,0 +1,95 @@
use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuthClientId, UserId};
use crate::models::pats::Scopes;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::Digest;
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthAccessToken {
pub id: OAuthAccessTokenId,
pub authorization_id: OAuthClientAuthorizationId,
pub token_hash: String,
pub scopes: Scopes,
pub created: DateTime<Utc>,
pub expires: DateTime<Utc>,
pub last_used: Option<DateTime<Utc>>,
// Stored separately inside oauth_client_authorizations table
pub client_id: OAuthClientId,
pub user_id: UserId,
}
impl OAuthAccessToken {
pub async fn get(
token_hash: String,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthAccessToken>, DatabaseError> {
let value = sqlx::query!(
"
SELECT
tokens.id,
tokens.authorization_id,
tokens.token_hash,
tokens.scopes,
tokens.created,
tokens.expires,
tokens.last_used,
auths.client_id,
auths.user_id
FROM oauth_access_tokens tokens
JOIN oauth_client_authorizations auths
ON tokens.authorization_id = auths.id
WHERE tokens.token_hash = $1
",
token_hash
)
.fetch_optional(exec)
.await?;
Ok(value.map(|r| OAuthAccessToken {
id: OAuthAccessTokenId(r.id),
authorization_id: OAuthClientAuthorizationId(r.authorization_id),
token_hash: r.token_hash,
scopes: Scopes::from_postgres(r.scopes),
created: r.created,
expires: r.expires,
last_used: r.last_used,
client_id: OAuthClientId(r.client_id),
user_id: UserId(r.user_id),
}))
}
/// Inserts and returns the time until the token expires
pub async fn insert(
&self,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<chrono::Duration, DatabaseError> {
let r = sqlx::query!(
"
INSERT INTO oauth_access_tokens (
id, authorization_id, token_hash, scopes, last_used
)
VALUES (
$1, $2, $3, $4, $5
)
RETURNING created, expires
",
self.id.0,
self.authorization_id.0,
self.token_hash,
self.scopes.to_postgres(),
Option::<DateTime<Utc>>::None
)
.fetch_one(exec)
.await?;
let (created, expires) = (r.created, r.expires);
let time_until_expiration = expires - created;
Ok(time_until_expiration)
}
pub fn hash_token(token: &str) -> String {
format!("{:x}", sha2::Sha512::digest(token.as_bytes()))
}
}