OAuth 2.0 Authorization Server [MOD-559] (#733)

* WIP end-of-day push

* Authorize endpoint, accept endpoints, DB stuff for oauth clients, their redirects, and client authorizations

* OAuth Client create route

* Get user clients

* Client delete

* Edit oauth client

* Include redirects in edit client route

* Database stuff for tokens

* Reorg oauth stuff out of auth/flows and into its own module

* Impl OAuth get access token endpoint

* Accept oauth access tokens as auth and update through AuthQueue

* User OAuth authorization management routes

* Forgot to actually add the routes lol

* Bit o cleanup

* Happy path test for OAuth and minor fixes for things it found

* Add dummy data oauth client (and detect/handle dummy data version changes)

* More tests

* Another test

* More tests and reject endpoint

* Test oauth client and authorization management routes

* cargo sqlx prepare

* dead code warning

* Auto clippy fixes

* Uri refactoring

* minor name improvement

* Don't compile-time check the test sqlx queries

* Trying to fix db concurrency problem to get tests to pass

* Try fix from test PR

* Fixes for updated sqlx

* Prevent restricted scopes from being requested or issued

* Get OAuth client(s)

* Remove joined oauth client info from authorization returns

* Add default conversion to OAuthError::error so we can use ?

* Rework routes

* Consolidate scopes into SESSION_ACCESS

* Cargo sqlx prepare

* Parse to OAuthClientId automatically through serde and actix

* Cargo clippy

* Remove validation requiring 1 redirect URI on oauth client creation

* Use serde(flatten) on OAuthClientCreationResult
This commit is contained in:
Jackson Kruger
2023-10-30 11:14:38 -05:00
committed by GitHub
parent 8803e11945
commit 6cfd4637db
54 changed files with 3658 additions and 135 deletions

View File

@@ -8,6 +8,25 @@ use crate::routes::ApiError;
use actix_web::web;
use sqlx::PgPool;
pub trait ValidateAuthorized {
fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError>;
}
pub trait ValidateAllAuthorized {
fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError>;
}
impl<'a, T, A> ValidateAllAuthorized for T
where
T: IntoIterator<Item = &'a A>,
A: ValidateAuthorized + 'a,
{
fn validate_all_authorized(self, user_option: Option<&User>) -> Result<(), ApiError> {
self.into_iter()
.try_for_each(|c| c.validate_authorized(user_option))
}
}
pub async fn is_authorized(
project_data: &Project,
user_option: &Option<User>,
@@ -156,6 +175,23 @@ pub async fn is_authorized_version(
Ok(authorized)
}
impl ValidateAuthorized for crate::database::models::OAuthClient {
fn validate_authorized(&self, user_option: Option<&User>) -> Result<(), ApiError> {
if let Some(user) = user_option {
if user.role.is_mod() || user.id == self.created_by.into() {
return Ok(());
} else {
return Err(crate::routes::ApiError::CustomAuthentication(
"You don't have sufficient permissions to interact with this OAuth application"
.to_string(),
));
}
}
Ok(())
}
}
pub async fn filter_authorized_versions(
versions: Vec<QueryVersion>,
user_option: &Option<User>,

View File

@@ -1,11 +1,11 @@
pub mod checks;
pub mod email;
pub mod flows;
pub mod oauth;
pub mod pats;
pub mod session;
mod templates;
pub mod validate;
pub use checks::{
filter_authorized_projects, filter_authorized_versions, is_authorized, is_authorized_version,
};

176
src/auth/oauth/errors.rs Normal file
View File

@@ -0,0 +1,176 @@
use super::ValidatedRedirectUri;
use crate::auth::AuthenticationError;
use crate::models::error::ApiError;
use crate::models::ids::DecodingError;
use actix_web::http::StatusCode;
use actix_web::HttpResponse;
#[derive(thiserror::Error, Debug)]
#[error("{}", .error_type)]
pub struct OAuthError {
#[source]
pub error_type: OAuthErrorType,
pub state: Option<String>,
pub valid_redirect_uri: Option<ValidatedRedirectUri>,
}
impl<T> From<T> for OAuthError
where
T: Into<OAuthErrorType>,
{
fn from(value: T) -> Self {
OAuthError::error(value.into())
}
}
impl OAuthError {
/// The OAuth request failed either because of an invalid redirection URI
/// or before we could validate the one we were given, so return an error
/// directly to the caller
///
/// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1)
pub fn error(error_type: impl Into<OAuthErrorType>) -> Self {
Self {
error_type: error_type.into(),
valid_redirect_uri: None,
state: None,
}
}
/// The OAuth request failed for a reason other than an invalid redirection URI
/// So send the error in url-encoded form to the redirect URI
///
/// See: IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1)
pub fn redirect(
err: impl Into<OAuthErrorType>,
state: &Option<String>,
valid_redirect_uri: &ValidatedRedirectUri,
) -> Self {
Self {
error_type: err.into(),
state: state.clone(),
valid_redirect_uri: Some(valid_redirect_uri.clone()),
}
}
}
impl actix_web::ResponseError for OAuthError {
fn status_code(&self) -> StatusCode {
match self.error_type {
OAuthErrorType::AuthenticationError(_)
| OAuthErrorType::FailedScopeParse(_)
| OAuthErrorType::ScopesTooBroad
| OAuthErrorType::AccessDenied => {
if self.valid_redirect_uri.is_some() {
StatusCode::FOUND
} else {
StatusCode::INTERNAL_SERVER_ERROR
}
}
OAuthErrorType::RedirectUriNotConfigured(_)
| OAuthErrorType::ClientMissingRedirectURI { client_id: _ }
| OAuthErrorType::InvalidAcceptFlowId
| OAuthErrorType::MalformedId(_)
| OAuthErrorType::InvalidClientId(_)
| OAuthErrorType::InvalidAuthCode
| OAuthErrorType::OnlySupportsAuthorizationCodeGrant(_)
| OAuthErrorType::RedirectUriChanged(_)
| OAuthErrorType::UnauthorizedClient => StatusCode::BAD_REQUEST,
OAuthErrorType::ClientAuthenticationFailed => StatusCode::UNAUTHORIZED,
}
}
fn error_response(&self) -> HttpResponse {
if let Some(ValidatedRedirectUri(mut redirect_uri)) = self.valid_redirect_uri.clone() {
redirect_uri = format!(
"{}?error={}&error_description={}",
redirect_uri,
self.error_type.error_name(),
self.error_type,
);
if let Some(state) = self.state.as_ref() {
redirect_uri = format!("{}&state={}", redirect_uri, state);
}
redirect_uri = urlencoding::encode(&redirect_uri).to_string();
HttpResponse::Found()
.append_header(("Location".to_string(), redirect_uri))
.finish()
} else {
HttpResponse::build(self.status_code()).json(ApiError {
error: &self.error_type.error_name(),
description: &self.error_type.to_string(),
})
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum OAuthErrorType {
#[error(transparent)]
AuthenticationError(#[from] AuthenticationError),
#[error("Client {} has no redirect URIs specified", .client_id.0)]
ClientMissingRedirectURI {
client_id: crate::database::models::OAuthClientId,
},
#[error("The provided redirect URI did not match any configured in the client")]
RedirectUriNotConfigured(String),
#[error("The provided scope was malformed or did not correspond to known scopes ({0})")]
FailedScopeParse(bitflags::parser::ParseError),
#[error(
"The provided scope requested scopes broader than the developer app is configured with"
)]
ScopesTooBroad,
#[error("The provided flow id was invalid")]
InvalidAcceptFlowId,
#[error("The provided client id was invalid")]
InvalidClientId(crate::database::models::OAuthClientId),
#[error("The provided ID could not be decoded: {0}")]
MalformedId(#[from] DecodingError),
#[error("Failed to authenticate client")]
ClientAuthenticationFailed,
#[error("The provided authorization grant code was invalid")]
InvalidAuthCode,
#[error("The provided client id did not match the id this authorization code was granted to")]
UnauthorizedClient,
#[error("The provided redirect URI did not exactly match the uri originally provided when this flow began")]
RedirectUriChanged(Option<String>),
#[error("The provided grant type ({0}) must be \"authorization_code\"")]
OnlySupportsAuthorizationCodeGrant(String),
#[error("The resource owner denied the request")]
AccessDenied,
}
impl From<crate::database::models::DatabaseError> for OAuthErrorType {
fn from(value: crate::database::models::DatabaseError) -> Self {
OAuthErrorType::AuthenticationError(value.into())
}
}
impl From<sqlx::Error> for OAuthErrorType {
fn from(value: sqlx::Error) -> Self {
OAuthErrorType::AuthenticationError(value.into())
}
}
impl OAuthErrorType {
pub fn error_name(&self) -> String {
// IETF RFC 6749 4.1.2.1 (https://datatracker.ietf.org/doc/html/rfc6749#autoid-38)
// And 5.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.2)
match self {
Self::RedirectUriNotConfigured(_) | Self::ClientMissingRedirectURI { client_id: _ } => {
"invalid_uri"
}
Self::AuthenticationError(_) | Self::InvalidAcceptFlowId => "server_error",
Self::RedirectUriChanged(_) | Self::MalformedId(_) => "invalid_request",
Self::FailedScopeParse(_) | Self::ScopesTooBroad => "invalid_scope",
Self::InvalidClientId(_) | Self::ClientAuthenticationFailed => "invalid_client",
Self::InvalidAuthCode | Self::OnlySupportsAuthorizationCodeGrant(_) => "invalid_grant",
Self::UnauthorizedClient => "unauthorized_client",
Self::AccessDenied => "access_denied",
}
.to_string()
}
}

430
src/auth/oauth/mod.rs Normal file
View File

@@ -0,0 +1,430 @@
use crate::auth::get_user_from_headers;
use crate::auth::oauth::uris::{OAuthRedirectUris, ValidatedRedirectUri};
use crate::auth::validate::extract_authorization_header;
use crate::database::models::flow_item::Flow;
use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization;
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
use crate::database::models::oauth_token_item::OAuthAccessToken;
use crate::database::models::{
generate_oauth_access_token_id, generate_oauth_client_authorization_id,
OAuthClientAuthorizationId, OAuthClientId,
};
use crate::database::redis::RedisPool;
use crate::models;
use crate::models::pats::Scopes;
use crate::queue::session::AuthQueue;
use actix_web::web::{scope, Data, Query, ServiceConfig};
use actix_web::{get, post, web, HttpRequest, HttpResponse};
use chrono::Duration;
use rand::distributions::Alphanumeric;
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use reqwest::header::{CACHE_CONTROL, LOCATION, PRAGMA};
use serde::{Deserialize, Serialize};
use sqlx::postgres::PgPool;
use self::errors::{OAuthError, OAuthErrorType};
use super::AuthenticationError;
pub mod errors;
pub mod uris;
pub fn config(cfg: &mut ServiceConfig) {
cfg.service(
scope("auth/oauth")
.service(init_oauth)
.service(accept_client_scopes)
.service(reject_client_scopes)
.service(request_token),
);
}
#[derive(Serialize, Deserialize)]
pub struct OAuthInit {
pub client_id: OAuthClientId,
pub redirect_uri: Option<String>,
pub scope: Option<String>,
pub state: Option<String>,
}
#[derive(Serialize, Deserialize)]
pub struct OAuthClientAccessRequest {
pub flow_id: String,
pub client_id: OAuthClientId,
pub client_name: String,
pub client_icon: Option<String>,
pub requested_scopes: Scopes,
}
#[get("authorize")]
pub async fn init_oauth(
req: HttpRequest,
Query(oauth_info): Query<OAuthInit>,
pool: Data<PgPool>,
redis: Data<RedisPool>,
session_queue: Data<AuthQueue>,
) -> Result<HttpResponse, OAuthError> {
let user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::USER_AUTH_WRITE]),
)
.await?
.1;
let client_id = oauth_info.client_id;
let client = DBOAuthClient::get(client_id, &**pool).await?;
if let Some(client) = client {
let redirect_uri = ValidatedRedirectUri::validate(
&oauth_info.redirect_uri,
client.redirect_uris.iter().map(|r| r.uri.as_ref()),
client.id,
)?;
let requested_scopes = oauth_info
.scope
.as_ref()
.map_or(Ok(client.max_scopes), |s| {
Scopes::parse_from_oauth_scopes(s).map_err(|e| {
OAuthError::redirect(
OAuthErrorType::FailedScopeParse(e),
&oauth_info.state,
&redirect_uri,
)
})
})?;
if !client.max_scopes.contains(requested_scopes) {
return Err(OAuthError::redirect(
OAuthErrorType::ScopesTooBroad,
&oauth_info.state,
&redirect_uri,
));
}
let existing_authorization =
OAuthClientAuthorization::get(client.id, user.id.into(), &**pool)
.await
.map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?;
let redirect_uris =
OAuthRedirectUris::new(oauth_info.redirect_uri.clone(), redirect_uri.clone());
match existing_authorization {
Some(existing_authorization)
if existing_authorization.scopes.contains(requested_scopes) =>
{
init_oauth_code_flow(
user.id.into(),
client.id,
existing_authorization.id,
requested_scopes,
redirect_uris,
oauth_info.state,
&redis,
)
.await
}
_ => {
let flow_id = Flow::InitOAuthAppApproval {
user_id: user.id.into(),
client_id: client.id,
existing_authorization_id: existing_authorization.map(|a| a.id),
scopes: requested_scopes,
redirect_uris,
state: oauth_info.state.clone(),
}
.insert(Duration::minutes(30), &redis)
.await
.map_err(|e| OAuthError::redirect(e, &oauth_info.state, &redirect_uri))?;
let access_request = OAuthClientAccessRequest {
client_id: client.id,
client_name: client.name,
client_icon: client.icon_url,
flow_id,
requested_scopes,
};
Ok(HttpResponse::Ok().json(access_request))
}
}
} else {
Err(OAuthError::error(OAuthErrorType::InvalidClientId(
client_id,
)))
}
}
#[derive(Serialize, Deserialize)]
pub struct RespondToOAuthClientScopes {
pub flow: String,
}
#[post("accept")]
pub async fn accept_client_scopes(
req: HttpRequest,
accept_body: web::Json<RespondToOAuthClientScopes>,
pool: Data<PgPool>,
redis: Data<RedisPool>,
session_queue: Data<AuthQueue>,
) -> Result<HttpResponse, OAuthError> {
accept_or_reject_client_scopes(true, req, accept_body, pool, redis, session_queue).await
}
#[post("reject")]
pub async fn reject_client_scopes(
req: HttpRequest,
body: web::Json<RespondToOAuthClientScopes>,
pool: Data<PgPool>,
redis: Data<RedisPool>,
session_queue: Data<AuthQueue>,
) -> Result<HttpResponse, OAuthError> {
accept_or_reject_client_scopes(false, req, body, pool, redis, session_queue).await
}
#[derive(Serialize, Deserialize)]
pub struct TokenRequest {
pub grant_type: String,
pub code: String,
pub redirect_uri: Option<String>,
pub client_id: models::ids::OAuthClientId,
}
#[derive(Serialize, Deserialize)]
pub struct TokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: i64,
}
#[post("token")]
/// Params should be in the urlencoded request body
/// And client secret should be in the HTTP basic authorization header
/// Per IETF RFC6749 Section 4.1.3 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3)
pub async fn request_token(
req: HttpRequest,
req_params: web::Form<TokenRequest>,
pool: Data<PgPool>,
redis: Data<RedisPool>,
) -> Result<HttpResponse, OAuthError> {
let req_client_id = req_params.client_id;
let client = DBOAuthClient::get(req_client_id.into(), &**pool).await?;
if let Some(client) = client {
authenticate_client_token_request(&req, &client)?;
// Ensure auth code is single use
// per IETF RFC6749 Section 10.5 (https://datatracker.ietf.org/doc/html/rfc6749#section-10.5)
let flow = Flow::take_if(
&req_params.code,
|f| matches!(f, Flow::OAuthAuthorizationCodeSupplied { .. }),
&redis,
)
.await?;
if let Some(Flow::OAuthAuthorizationCodeSupplied {
user_id,
client_id,
authorization_id,
scopes,
original_redirect_uri,
}) = flow
{
// https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
if req_client_id != client_id.into() {
return Err(OAuthError::error(OAuthErrorType::UnauthorizedClient));
}
if original_redirect_uri != req_params.redirect_uri {
return Err(OAuthError::error(OAuthErrorType::RedirectUriChanged(
req_params.redirect_uri.clone(),
)));
}
if req_params.grant_type != "authorization_code" {
return Err(OAuthError::error(
OAuthErrorType::OnlySupportsAuthorizationCodeGrant(
req_params.grant_type.clone(),
),
));
}
let scopes = scopes - Scopes::restricted();
let mut transaction = pool.begin().await?;
let token_id = generate_oauth_access_token_id(&mut transaction).await?;
let token = generate_access_token();
let token_hash = OAuthAccessToken::hash_token(&token);
let time_until_expiration = OAuthAccessToken {
id: token_id,
authorization_id,
token_hash,
scopes,
created: Default::default(),
expires: Default::default(),
last_used: None,
client_id,
user_id,
}
.insert(&mut *transaction)
.await?;
transaction.commit().await?;
// IETF RFC6749 Section 5.1 (https://datatracker.ietf.org/doc/html/rfc6749#section-5.1)
Ok(HttpResponse::Ok()
.append_header((CACHE_CONTROL, "no-store"))
.append_header((PRAGMA, "no-cache"))
.json(TokenResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: time_until_expiration.num_seconds(),
}))
} else {
Err(OAuthError::error(OAuthErrorType::InvalidAuthCode))
}
} else {
Err(OAuthError::error(OAuthErrorType::InvalidClientId(
req_client_id.into(),
)))
}
}
pub async fn accept_or_reject_client_scopes(
accept: bool,
req: HttpRequest,
body: web::Json<RespondToOAuthClientScopes>,
pool: Data<PgPool>,
redis: Data<RedisPool>,
session_queue: Data<AuthQueue>,
) -> Result<HttpResponse, OAuthError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
let flow = Flow::take_if(
&body.flow,
|f| matches!(f, Flow::InitOAuthAppApproval { .. }),
&redis,
)
.await?;
if let Some(Flow::InitOAuthAppApproval {
user_id,
client_id,
existing_authorization_id,
scopes,
redirect_uris,
state,
}) = flow
{
if current_user.id != user_id.into() {
return Err(OAuthError::error(AuthenticationError::InvalidCredentials));
}
if accept {
let mut transaction = pool.begin().await?;
let auth_id = match existing_authorization_id {
Some(id) => id,
None => generate_oauth_client_authorization_id(&mut transaction).await?,
};
OAuthClientAuthorization::upsert(auth_id, client_id, user_id, scopes, &mut transaction)
.await?;
transaction.commit().await?;
init_oauth_code_flow(
user_id,
client_id,
auth_id,
scopes,
redirect_uris,
state,
&redis,
)
.await
} else {
Err(OAuthError::redirect(
OAuthErrorType::AccessDenied,
&state,
&redirect_uris.validated,
))
}
} else {
Err(OAuthError::error(OAuthErrorType::InvalidAcceptFlowId))
}
}
fn authenticate_client_token_request(
req: &HttpRequest,
client: &DBOAuthClient,
) -> Result<(), OAuthError> {
let client_secret = extract_authorization_header(req)?;
let hashed_client_secret = DBOAuthClient::hash_secret(client_secret);
if client.secret_hash != hashed_client_secret {
Err(OAuthError::error(
OAuthErrorType::ClientAuthenticationFailed,
))
} else {
Ok(())
}
}
fn generate_access_token() -> String {
let random = ChaCha20Rng::from_entropy()
.sample_iter(&Alphanumeric)
.take(60)
.map(char::from)
.collect::<String>();
format!("mro_{}", random)
}
async fn init_oauth_code_flow(
user_id: crate::database::models::UserId,
client_id: OAuthClientId,
authorization_id: OAuthClientAuthorizationId,
scopes: Scopes,
redirect_uris: OAuthRedirectUris,
state: Option<String>,
redis: &RedisPool,
) -> Result<HttpResponse, OAuthError> {
let code = Flow::OAuthAuthorizationCodeSupplied {
user_id,
client_id,
authorization_id,
scopes,
original_redirect_uri: redirect_uris.original.clone(),
}
.insert(Duration::minutes(10), redis)
.await
.map_err(|e| OAuthError::redirect(e, &state, &redirect_uris.validated.clone()))?;
let mut redirect_params = vec![format!("code={code}")];
if let Some(state) = state {
redirect_params.push(format!("state={state}"));
}
let redirect_uri = append_params_to_uri(&redirect_uris.validated.0, &redirect_params);
// IETF RFC 6749 Section 4.1.2 (https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2)
Ok(HttpResponse::Found()
.append_header((LOCATION, redirect_uri))
.finish())
}
fn append_params_to_uri(uri: &str, params: &[impl AsRef<str>]) -> String {
let mut uri = uri.to_string();
let mut connector = if uri.contains('?') { "&" } else { "?" };
for param in params {
uri.push_str(&format!("{}{}", connector, param.as_ref()));
connector = "&";
}
uri
}

94
src/auth/oauth/uris.rs Normal file
View File

@@ -0,0 +1,94 @@
use super::errors::OAuthError;
use crate::auth::oauth::OAuthErrorType;
use crate::database::models::OAuthClientId;
use serde::{Deserialize, Serialize};
#[derive(derive_new::new, Serialize, Deserialize)]
pub struct OAuthRedirectUris {
pub original: Option<String>,
pub validated: ValidatedRedirectUri,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ValidatedRedirectUri(pub String);
impl ValidatedRedirectUri {
pub fn validate<'a>(
to_validate: &Option<String>,
validate_against: impl IntoIterator<Item = &'a str> + Clone,
client_id: OAuthClientId,
) -> Result<Self, OAuthError> {
if let Some(first_client_redirect_uri) = validate_against.clone().into_iter().next() {
if let Some(to_validate) = to_validate {
if validate_against
.into_iter()
.any(|uri| same_uri_except_query_components(uri, to_validate))
{
Ok(ValidatedRedirectUri(to_validate.clone()))
} else {
Err(OAuthError::error(OAuthErrorType::RedirectUriNotConfigured(
to_validate.clone(),
)))
}
} else {
Ok(ValidatedRedirectUri(first_client_redirect_uri.to_string()))
}
} else {
Err(OAuthError::error(
OAuthErrorType::ClientMissingRedirectURI { client_id },
))
}
}
}
fn same_uri_except_query_components(a: &str, b: &str) -> bool {
let mut a_components = a.split('?');
let mut b_components = b.split('?');
a_components.next() == b_components.next()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn validate_for_none_returns_first_valid_uri() {
let validate_against = vec!["https://modrinth.com/a"];
let validated =
ValidatedRedirectUri::validate(&None, validate_against.clone(), OAuthClientId(0))
.unwrap();
assert_eq!(validate_against[0], validated.0);
}
#[test]
fn validate_for_valid_uri_returns_first_matching_uri_ignoring_query_params() {
let validate_against = vec![
"https://modrinth.com/a?q3=p3&q4=p4",
"https://modrinth.com/a/b/c?q1=p1&q2=p2",
];
let to_validate = "https://modrinth.com/a/b/c?query0=param0&query1=param1".to_string();
let validated = ValidatedRedirectUri::validate(
&Some(to_validate.clone()),
validate_against,
OAuthClientId(0),
)
.unwrap();
assert_eq!(to_validate, validated.0);
}
#[test]
fn validate_for_invalid_uri_returns_err() {
let validate_against = vec!["https://modrinth.com/a"];
let to_validate = "https://modrinth.com/a/b".to_string();
let validated =
ValidatedRedirectUri::validate(&Some(to_validate), validate_against, OAuthClientId(0));
assert!(validated
.is_err_and(|e| matches!(e.error_type, OAuthErrorType::RedirectUriNotConfigured(_))));
}
}

View File

@@ -91,12 +91,7 @@ where
let token = if let Some(token) = token {
token
} else {
let headers = req.headers();
let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION);
token_val
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
.to_str()
.map_err(|_| AuthenticationError::InvalidCredentials)?
extract_authorization_header(req)?
};
let possible_user = match token.split_once('_') {
@@ -142,6 +137,25 @@ where
user.map(|x| (Scopes::all(), x))
}
Some(("mro", _)) => {
use crate::database::models::oauth_token_item::OAuthAccessToken;
let hash = OAuthAccessToken::hash_token(token);
let access_token =
crate::database::models::oauth_token_item::OAuthAccessToken::get(hash, executor)
.await?
.ok_or(AuthenticationError::InvalidCredentials)?;
if access_token.expires < Utc::now() {
return Err(AuthenticationError::InvalidCredentials);
}
let user = user_item::User::get_id(access_token.user_id, executor, redis).await?;
session_queue.add_oauth_access_token(access_token.id).await;
user.map(|u| (access_token.scopes, u))
}
Some(("github", _)) | Some(("gho", _)) | Some(("ghp", _)) => {
let user = AuthProvider::GitHub.get_user(token).await?;
let id = AuthProvider::GitHub.get_user_id(&user.id, executor).await?;
@@ -160,6 +174,15 @@ where
Ok(possible_user)
}
pub fn extract_authorization_header(req: &HttpRequest) -> Result<&str, AuthenticationError> {
let headers = req.headers();
let token_val: Option<&HeaderValue> = headers.get(AUTHORIZATION);
token_val
.ok_or_else(|| AuthenticationError::InvalidAuthMethod)?
.to_str()
.map_err(|_| AuthenticationError::InvalidCredentials)
}
pub async fn check_is_moderator_from_headers<'a, 'b, E>(
req: &HttpRequest,
executor: E,

View File

@@ -1,7 +1,8 @@
use super::ids::*;
use crate::auth::flows::AuthProvider;
use crate::auth::oauth::uris::OAuthRedirectUris;
use crate::database::models::DatabaseError;
use crate::database::redis::RedisPool;
use crate::{auth::flows::AuthProvider, models::pats::Scopes};
use chrono::Duration;
use rand::distributions::Alphanumeric;
use rand::Rng;
@@ -34,6 +35,21 @@ pub enum Flow {
confirm_email: String,
},
MinecraftAuth,
InitOAuthAppApproval {
user_id: UserId,
client_id: OAuthClientId,
existing_authorization_id: Option<OAuthClientAuthorizationId>,
scopes: Scopes,
redirect_uris: OAuthRedirectUris,
state: Option<String>,
},
OAuthAuthorizationCodeSupplied {
user_id: UserId,
client_id: OAuthClientId,
authorization_id: OAuthClientAuthorizationId,
scopes: Scopes,
original_redirect_uri: Option<String>, // Needed for https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
},
}
impl Flow {
@@ -58,6 +74,22 @@ impl Flow {
redis.get_deserialized_from_json(FLOWS_NAMESPACE, id).await
}
/// Gets the flow and removes it from the cache, but only removes if the flow was present and the predicate returned true
/// The predicate should validate that the flow being removed is the correct one, as a security measure
pub async fn take_if(
id: &str,
predicate: impl FnOnce(&Flow) -> bool,
redis: &RedisPool,
) -> Result<Option<Flow>, DatabaseError> {
let flow = Self::get(id, redis).await?;
if let Some(flow) = flow.as_ref() {
if predicate(flow) {
Self::remove(id, redis).await?;
}
}
Ok(flow)
}
pub async fn remove(id: &str, redis: &RedisPool) -> Result<Option<()>, DatabaseError> {
redis.delete(FLOWS_NAMESPACE, id).await?;
Ok(Some(()))

View File

@@ -152,6 +152,38 @@ generate_ids!(
ImageId
);
generate_ids!(
pub generate_oauth_client_authorization_id,
OAuthClientAuthorizationId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_client_authorizations WHERE id=$1)",
OAuthClientAuthorizationId
);
generate_ids!(
pub generate_oauth_client_id,
OAuthClientId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_clients WHERE id=$1)",
OAuthClientId
);
generate_ids!(
pub generate_oauth_redirect_id,
OAuthRedirectUriId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_client_redirect_uris WHERE id=$1)",
OAuthRedirectUriId
);
generate_ids!(
pub generate_oauth_access_token_id,
OAuthAccessTokenId,
8,
"SELECT EXISTS(SELECT 1 FROM oauth_access_tokens WHERE id=$1)",
OAuthAccessTokenId
);
#[derive(Copy, Clone, Debug, PartialEq, Eq, Type, Hash, Serialize, Deserialize)]
#[sqlx(transparent)]
pub struct UserId(pub i64);
@@ -238,6 +270,22 @@ pub struct SessionId(pub i64);
#[sqlx(transparent)]
pub struct ImageId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthClientId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthClientAuthorizationId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthRedirectUriId(pub i64);
#[derive(Copy, Clone, Debug, Type, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[sqlx(transparent)]
pub struct OAuthAccessTokenId(pub i64);
use crate::models::ids;
impl From<ids::ProjectId> for ProjectId {
@@ -360,3 +408,23 @@ impl From<PatId> for ids::PatId {
ids::PatId(id.0 as u64)
}
}
impl From<OAuthClientId> for ids::OAuthClientId {
fn from(id: OAuthClientId) -> Self {
ids::OAuthClientId(id.0 as u64)
}
}
impl From<ids::OAuthClientId> for OAuthClientId {
fn from(id: ids::OAuthClientId) -> Self {
Self(id.0 as i64)
}
}
impl From<OAuthRedirectUriId> for ids::OAuthRedirectUriId {
fn from(id: OAuthRedirectUriId) -> Self {
ids::OAuthRedirectUriId(id.0 as u64)
}
}
impl From<OAuthClientAuthorizationId> for ids::OAuthClientAuthorizationId {
fn from(id: OAuthClientAuthorizationId) -> Self {
ids::OAuthClientAuthorizationId(id.0 as u64)
}
}

View File

@@ -6,6 +6,9 @@ pub mod flow_item;
pub mod ids;
pub mod image_item;
pub mod notification_item;
pub mod oauth_client_authorization_item;
pub mod oauth_client_item;
pub mod oauth_token_item;
pub mod organization_item;
pub mod pat_item;
pub mod project_item;
@@ -19,6 +22,7 @@ pub mod version_item;
pub use collection_item::Collection;
pub use ids::*;
pub use image_item::Image;
pub use oauth_client_item::OAuthClient;
pub use organization_item::Organization;
pub use project_item::Project;
pub use team_item::Team;

View File

@@ -0,0 +1,126 @@
use chrono::{DateTime, Utc};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use crate::models::pats::Scopes;
use super::{DatabaseError, OAuthClientAuthorizationId, OAuthClientId, UserId};
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthClientAuthorization {
pub id: OAuthClientAuthorizationId,
pub client_id: OAuthClientId,
pub user_id: UserId,
pub scopes: Scopes,
pub created: DateTime<Utc>,
}
struct AuthorizationQueryResult {
id: i64,
client_id: i64,
user_id: i64,
scopes: i64,
created: DateTime<Utc>,
}
impl From<AuthorizationQueryResult> for OAuthClientAuthorization {
fn from(value: AuthorizationQueryResult) -> Self {
OAuthClientAuthorization {
id: OAuthClientAuthorizationId(value.id),
client_id: OAuthClientId(value.client_id),
user_id: UserId(value.user_id),
scopes: Scopes::from_postgres(value.scopes),
created: value.created,
}
}
}
impl OAuthClientAuthorization {
pub async fn get(
client_id: OAuthClientId,
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthClientAuthorization>, DatabaseError> {
let value = sqlx::query_as!(
AuthorizationQueryResult,
"
SELECT id, client_id, user_id, scopes, created
FROM oauth_client_authorizations
WHERE client_id=$1 AND user_id=$2
",
client_id.0,
user_id.0,
)
.fetch_optional(exec)
.await?;
Ok(value.map(|r| r.into()))
}
pub async fn get_all_for_user(
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClientAuthorization>, DatabaseError> {
let results = sqlx::query_as!(
AuthorizationQueryResult,
"
SELECT id, client_id, user_id, scopes, created
FROM oauth_client_authorizations
WHERE user_id=$1
",
user_id.0
)
.fetch_all(exec)
.await?;
Ok(results.into_iter().map(|r| r.into()).collect_vec())
}
pub async fn upsert(
id: OAuthClientAuthorizationId,
client_id: OAuthClientId,
user_id: UserId,
scopes: Scopes,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
INSERT INTO oauth_client_authorizations (
id, client_id, user_id, scopes
)
VALUES (
$1, $2, $3, $4
)
ON CONFLICT (id)
DO UPDATE SET scopes = EXCLUDED.scopes
",
id.0,
client_id.0,
user_id.0,
scopes.bits() as i64,
)
.execute(&mut **transaction)
.await?;
Ok(())
}
pub async fn remove(
client_id: OAuthClientId,
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
DELETE FROM oauth_client_authorizations
WHERE client_id=$1 AND user_id=$2
",
client_id.0,
user_id.0
)
.execute(exec)
.await?;
Ok(())
}
}

View File

@@ -0,0 +1,245 @@
use chrono::{DateTime, Utc};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use sha2::Digest;
use super::{DatabaseError, OAuthClientId, OAuthRedirectUriId, UserId};
use crate::models::pats::Scopes;
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthRedirectUri {
pub id: OAuthRedirectUriId,
pub client_id: OAuthClientId,
pub uri: String,
}
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthClient {
pub id: OAuthClientId,
pub name: String,
pub icon_url: Option<String>,
pub max_scopes: Scopes,
pub secret_hash: String,
pub redirect_uris: Vec<OAuthRedirectUri>,
pub created: DateTime<Utc>,
pub created_by: UserId,
}
struct ClientQueryResult {
id: i64,
name: String,
icon_url: Option<String>,
max_scopes: i64,
secret_hash: String,
created: DateTime<Utc>,
created_by: i64,
uri_ids: Option<Vec<i64>>,
uri_vals: Option<Vec<String>>,
}
macro_rules! select_clients_with_predicate {
($predicate:tt, $param:ident) => {
// The columns in this query have nullability type hints, because for some reason
// the combination of the JOIN and filter using ANY makes sqlx think all columns are nullable
// https://docs.rs/sqlx/latest/sqlx/macro.query.html#force-nullable
sqlx::query_as!(
ClientQueryResult,
r#"
SELECT
clients.id as "id!",
clients.name as "name!",
clients.icon_url as "icon_url?",
clients.max_scopes as "max_scopes!",
clients.secret_hash as "secret_hash!",
clients.created as "created!",
clients.created_by as "created_by!",
uris.uri_ids as "uri_ids?",
uris.uri_vals as "uri_vals?"
FROM oauth_clients clients
LEFT JOIN (
SELECT client_id, array_agg(id) as uri_ids, array_agg(uri) as uri_vals
FROM oauth_client_redirect_uris
GROUP BY client_id
) uris ON clients.id = uris.client_id
"#
+ $predicate,
$param
)
};
}
impl OAuthClient {
pub async fn get(
id: OAuthClientId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthClient>, DatabaseError> {
Ok(Self::get_many(&[id], exec).await?.into_iter().next())
}
pub async fn get_many(
ids: &[OAuthClientId],
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClient>, DatabaseError> {
let ids = ids.iter().map(|id| id.0).collect_vec();
let ids_ref: &[i64] = &ids;
let results =
select_clients_with_predicate!("WHERE clients.id = ANY($1::bigint[])", ids_ref)
.fetch_all(exec)
.await?;
Ok(results.into_iter().map(|r| r.into()).collect_vec())
}
pub async fn get_all_user_clients(
user_id: UserId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Vec<OAuthClient>, DatabaseError> {
let user_id_param = user_id.0;
let clients = select_clients_with_predicate!("WHERE created_by = $1", user_id_param)
.fetch_all(exec)
.await?;
Ok(clients.into_iter().map(|r| r.into()).collect())
}
pub async fn remove(
id: OAuthClientId,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
// Cascades to oauth_client_redirect_uris, oauth_client_authorizations
sqlx::query!(
"
DELETE FROM oauth_clients
WHERE id = $1
",
id.0
)
.execute(exec)
.await?;
Ok(())
}
pub async fn insert(
&self,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
INSERT INTO oauth_clients (
id, name, icon_url, max_scopes, secret_hash, created_by
)
VALUES (
$1, $2, $3, $4, $5, $6
)
",
self.id.0,
self.name,
self.icon_url,
self.max_scopes.to_postgres(),
self.secret_hash,
self.created_by.0
)
.execute(&mut **transaction)
.await?;
Self::insert_redirect_uris(&self.redirect_uris, &mut **transaction).await?;
Ok(())
}
pub async fn update_editable_fields(
&self,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
sqlx::query!(
"
UPDATE oauth_clients
SET name = $1, icon_url = $2, max_scopes = $3
WHERE (id = $4)
",
self.name,
self.icon_url,
self.max_scopes.to_postgres(),
self.id.0,
)
.execute(exec)
.await?;
Ok(())
}
pub async fn remove_redirect_uris(
ids: impl IntoIterator<Item = OAuthRedirectUriId>,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let ids = ids.into_iter().map(|id| id.0).collect_vec();
sqlx::query!(
"
DELETE FROM oauth_client_redirect_uris
WHERE id IN
(SELECT * FROM UNNEST($1::bigint[]))
",
&ids[..]
)
.execute(exec)
.await?;
Ok(())
}
pub async fn insert_redirect_uris(
uris: &[OAuthRedirectUri],
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let (ids, client_ids, uris): (Vec<_>, Vec<_>, Vec<_>) = uris
.iter()
.map(|r| (r.id.0, r.client_id.0, r.uri.clone()))
.multiunzip();
sqlx::query!(
"
INSERT INTO oauth_client_redirect_uris (id, client_id, uri)
SELECT * FROM UNNEST($1::bigint[], $2::bigint[], $3::varchar[])
",
&ids[..],
&client_ids[..],
&uris[..],
)
.execute(exec)
.await?;
Ok(())
}
pub fn hash_secret(secret: &str) -> String {
format!("{:x}", sha2::Sha512::digest(secret.as_bytes()))
}
}
impl From<ClientQueryResult> for OAuthClient {
fn from(r: ClientQueryResult) -> Self {
let redirects = if let (Some(ids), Some(uris)) = (r.uri_ids.as_ref(), r.uri_vals.as_ref()) {
ids.iter()
.zip(uris.iter())
.map(|(id, uri)| OAuthRedirectUri {
id: OAuthRedirectUriId(*id),
client_id: OAuthClientId(r.id),
uri: uri.to_string(),
})
.collect()
} else {
vec![]
};
OAuthClient {
id: OAuthClientId(r.id),
name: r.name,
icon_url: r.icon_url,
max_scopes: Scopes::from_postgres(r.max_scopes),
secret_hash: r.secret_hash,
redirect_uris: redirects,
created: r.created,
created_by: UserId(r.created_by),
}
}
}

View File

@@ -0,0 +1,95 @@
use super::{DatabaseError, OAuthAccessTokenId, OAuthClientAuthorizationId, OAuthClientId, UserId};
use crate::models::pats::Scopes;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sha2::Digest;
#[derive(Deserialize, Serialize, Clone, Debug)]
pub struct OAuthAccessToken {
pub id: OAuthAccessTokenId,
pub authorization_id: OAuthClientAuthorizationId,
pub token_hash: String,
pub scopes: Scopes,
pub created: DateTime<Utc>,
pub expires: DateTime<Utc>,
pub last_used: Option<DateTime<Utc>>,
// Stored separately inside oauth_client_authorizations table
pub client_id: OAuthClientId,
pub user_id: UserId,
}
impl OAuthAccessToken {
pub async fn get(
token_hash: String,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<Option<OAuthAccessToken>, DatabaseError> {
let value = sqlx::query!(
"
SELECT
tokens.id,
tokens.authorization_id,
tokens.token_hash,
tokens.scopes,
tokens.created,
tokens.expires,
tokens.last_used,
auths.client_id,
auths.user_id
FROM oauth_access_tokens tokens
JOIN oauth_client_authorizations auths
ON tokens.authorization_id = auths.id
WHERE tokens.token_hash = $1
",
token_hash
)
.fetch_optional(exec)
.await?;
Ok(value.map(|r| OAuthAccessToken {
id: OAuthAccessTokenId(r.id),
authorization_id: OAuthClientAuthorizationId(r.authorization_id),
token_hash: r.token_hash,
scopes: Scopes::from_postgres(r.scopes),
created: r.created,
expires: r.expires,
last_used: r.last_used,
client_id: OAuthClientId(r.client_id),
user_id: UserId(r.user_id),
}))
}
/// Inserts and returns the time until the token expires
pub async fn insert(
&self,
exec: impl sqlx::Executor<'_, Database = sqlx::Postgres>,
) -> Result<chrono::Duration, DatabaseError> {
let r = sqlx::query!(
"
INSERT INTO oauth_access_tokens (
id, authorization_id, token_hash, scopes, last_used
)
VALUES (
$1, $2, $3, $4, $5
)
RETURNING created, expires
",
self.id.0,
self.authorization_id.0,
self.token_hash,
self.scopes.to_postgres(),
Option::<DateTime<Utc>>::None
)
.fetch_one(exec)
.await?;
let (created, expires) = (r.created, r.expires);
let time_until_expiration = expires - created;
Ok(time_until_expiration)
}
pub fn hash_token(token: &str) -> String {
format!("{:x}", sha2::Sha512::digest(token.as_bytes()))
}
}

View File

@@ -3,6 +3,8 @@ use thiserror::Error;
pub use super::collections::CollectionId;
pub use super::images::ImageId;
pub use super::notifications::NotificationId;
pub use super::oauth_clients::OAuthClientAuthorizationId;
pub use super::oauth_clients::{OAuthClientId, OAuthRedirectUriId};
pub use super::organizations::OrganizationId;
pub use super::pats::PatId;
pub use super::projects::{ProjectId, VersionId};
@@ -122,6 +124,9 @@ base62_id_impl!(ThreadMessageId, ThreadMessageId);
base62_id_impl!(SessionId, SessionId);
base62_id_impl!(PatId, PatId);
base62_id_impl!(ImageId, ImageId);
base62_id_impl!(OAuthClientId, OAuthClientId);
base62_id_impl!(OAuthRedirectUriId, OAuthRedirectUriId);
base62_id_impl!(OAuthClientAuthorizationId, OAuthClientAuthorizationId);
pub mod base62_impl {
use serde::de::{self, Deserializer, Visitor};

View File

@@ -4,6 +4,7 @@ pub mod error;
pub mod ids;
pub mod images;
pub mod notifications;
pub mod oauth_clients;
pub mod organizations;
pub mod pack;
pub mod pats;

110
src/models/oauth_clients.rs Normal file
View File

@@ -0,0 +1,110 @@
use super::{
ids::{Base62Id, UserId},
pats::Scopes,
};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::database::models::oauth_client_authorization_item::OAuthClientAuthorization as DBOAuthClientAuthorization;
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
use crate::database::models::oauth_client_item::OAuthRedirectUri as DBOAuthRedirectUri;
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "Base62Id")]
#[serde(into = "Base62Id")]
pub struct OAuthClientId(pub u64);
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "Base62Id")]
#[serde(into = "Base62Id")]
pub struct OAuthClientAuthorizationId(pub u64);
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(from = "Base62Id")]
#[serde(into = "Base62Id")]
pub struct OAuthRedirectUriId(pub u64);
#[derive(Deserialize, Serialize)]
pub struct OAuthRedirectUri {
pub id: OAuthRedirectUriId,
pub client_id: OAuthClientId,
pub uri: String,
}
#[derive(Serialize, Deserialize)]
pub struct OAuthClientCreationResult {
#[serde(flatten)]
pub client: OAuthClient,
pub client_secret: String,
}
#[derive(Deserialize, Serialize)]
pub struct OAuthClient {
pub id: OAuthClientId,
pub name: String,
pub icon_url: Option<String>,
// The maximum scopes the client can request for OAuth
pub max_scopes: Scopes,
// The valid URIs that can be redirected to during an authorization request
pub redirect_uris: Vec<OAuthRedirectUri>,
// The user that created (and thus controls) this client
pub created_by: UserId,
}
#[derive(Deserialize, Serialize)]
pub struct OAuthClientAuthorization {
pub id: OAuthClientAuthorizationId,
pub app_id: OAuthClientId,
pub user_id: UserId,
pub scopes: Scopes,
pub created: DateTime<Utc>,
}
#[derive(Deserialize, Serialize)]
pub struct GetOAuthClientsRequest {
pub ids: Vec<OAuthClientId>,
}
#[derive(Deserialize, Serialize)]
pub struct DeleteOAuthClientQueryParam {
pub client_id: OAuthClientId,
}
impl From<DBOAuthClient> for OAuthClient {
fn from(value: DBOAuthClient) -> Self {
Self {
id: value.id.into(),
name: value.name,
icon_url: value.icon_url,
max_scopes: value.max_scopes,
redirect_uris: value.redirect_uris.into_iter().map(|r| r.into()).collect(),
created_by: value.created_by.into(),
}
}
}
impl From<DBOAuthRedirectUri> for OAuthRedirectUri {
fn from(value: DBOAuthRedirectUri) -> Self {
Self {
id: value.id.into(),
client_id: value.client_id.into(),
uri: value.uri,
}
}
}
impl From<DBOAuthClientAuthorization> for OAuthClientAuthorization {
fn from(value: DBOAuthClientAuthorization) -> Self {
Self {
id: value.id.into(),
app_id: value.client_id.into(),
user_id: value.user_id.into(),
scopes: value.scopes,
created: value.created,
}
}
}

View File

@@ -103,6 +103,9 @@ bitflags::bitflags! {
// delete an organization
const ORGANIZATION_DELETE = 1 << 38;
// only accessible by modrinth-issued sessions
const SESSION_ACCESS = 1 << 39;
const NONE = 0b0;
}
}
@@ -118,6 +121,7 @@ impl Scopes {
| Scopes::PAT_DELETE
| Scopes::SESSION_READ
| Scopes::SESSION_DELETE
| Scopes::SESSION_ACCESS
| Scopes::USER_AUTH_WRITE
| Scopes::USER_DELETE
| Scopes::PERFORM_ANALYTICS
@@ -126,6 +130,19 @@ impl Scopes {
pub fn is_restricted(&self) -> bool {
self.intersects(Self::restricted())
}
pub fn parse_from_oauth_scopes(scopes: &str) -> Result<Scopes, bitflags::parser::ParseError> {
let scopes = scopes.replace(' ', "|").replace("%20", "|");
bitflags::parser::from_str(&scopes)
}
pub fn to_postgres(&self) -> i64 {
self.bits() as i64
}
pub fn from_postgres(value: i64) -> Self {
Self::from_bits(value as u64).unwrap_or(Scopes::NONE)
}
}
#[derive(Serialize, Deserialize)]
@@ -161,3 +178,64 @@ impl PersonalAccessToken {
}
}
}
#[cfg(test)]
mod test {
use super::*;
use itertools::Itertools;
#[test]
fn test_parse_from_oauth_scopes_well_formed() {
let raw = "USER_READ_EMAIL SESSION_READ ORGANIZATION_CREATE";
let expected = Scopes::USER_READ_EMAIL | Scopes::SESSION_READ | Scopes::ORGANIZATION_CREATE;
let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap();
assert_same_flags(expected, parsed);
}
#[test]
fn test_parse_from_oauth_scopes_empty() {
let raw = "";
let expected = Scopes::empty();
let parsed = Scopes::parse_from_oauth_scopes(raw).unwrap();
assert_same_flags(expected, parsed);
}
#[test]
fn test_parse_from_oauth_scopes_invalid_scopes() {
let raw = "notascope";
let parsed = Scopes::parse_from_oauth_scopes(raw);
assert!(parsed.is_err());
}
#[test]
fn test_parse_from_oauth_scopes_invalid_separator() {
let raw = "USER_READ_EMAIL & SESSION_READ";
let parsed = Scopes::parse_from_oauth_scopes(raw);
assert!(parsed.is_err());
}
#[test]
fn test_parse_from_oauth_scopes_url_encoded() {
let raw = urlencoding::encode("PAT_WRITE COLLECTION_DELETE").to_string();
let expected = Scopes::PAT_WRITE | Scopes::COLLECTION_DELETE;
let parsed = Scopes::parse_from_oauth_scopes(&raw).unwrap();
assert_same_flags(expected, parsed);
}
fn assert_same_flags(expected: Scopes, actual: Scopes) {
assert_eq!(
expected.iter_names().map(|(name, _)| name).collect_vec(),
actual.iter_names().map(|(name, _)| name).collect_vec()
);
}
}

View File

@@ -1,9 +1,10 @@
use crate::auth::session::SessionMetadata;
use crate::database::models::pat_item::PersonalAccessToken;
use crate::database::models::session_item::Session;
use crate::database::models::{DatabaseError, PatId, SessionId, UserId};
use crate::database::models::{DatabaseError, OAuthAccessTokenId, PatId, SessionId, UserId};
use crate::database::redis::RedisPool;
use chrono::Utc;
use itertools::Itertools;
use sqlx::PgPool;
use std::collections::{HashMap, HashSet};
use tokio::sync::Mutex;
@@ -11,6 +12,7 @@ use tokio::sync::Mutex;
pub struct AuthQueue {
session_queue: Mutex<HashMap<SessionId, SessionMetadata>>,
pat_queue: Mutex<HashSet<PatId>>,
oauth_access_token_queue: Mutex<HashSet<OAuthAccessTokenId>>,
}
impl Default for AuthQueue {
@@ -25,6 +27,7 @@ impl AuthQueue {
AuthQueue {
session_queue: Mutex::new(HashMap::with_capacity(1000)),
pat_queue: Mutex::new(HashSet::with_capacity(1000)),
oauth_access_token_queue: Mutex::new(HashSet::with_capacity(1000)),
}
}
pub async fn add_session(&self, id: SessionId, metadata: SessionMetadata) {
@@ -35,6 +38,10 @@ impl AuthQueue {
self.pat_queue.lock().await.insert(id);
}
pub async fn add_oauth_access_token(&self, id: crate::database::models::OAuthAccessTokenId) {
self.oauth_access_token_queue.lock().await.insert(id);
}
pub async fn take_sessions(&self) -> HashMap<SessionId, SessionMetadata> {
let mut queue = self.session_queue.lock().await;
let len = queue.len();
@@ -42,8 +49,8 @@ impl AuthQueue {
std::mem::replace(&mut queue, HashMap::with_capacity(len))
}
pub async fn take_pats(&self) -> HashSet<PatId> {
let mut queue = self.pat_queue.lock().await;
pub async fn take_hashset<T>(queue: &Mutex<HashSet<T>>) -> HashSet<T> {
let mut queue = queue.lock().await;
let len = queue.len();
std::mem::replace(&mut queue, HashSet::with_capacity(len))
@@ -51,9 +58,13 @@ impl AuthQueue {
pub async fn index(&self, pool: &PgPool, redis: &RedisPool) -> Result<(), DatabaseError> {
let session_queue = self.take_sessions().await;
let pat_queue = self.take_pats().await;
let pat_queue = Self::take_hashset(&self.pat_queue).await;
let oauth_access_token_queue = Self::take_hashset(&self.oauth_access_token_queue).await;
if !session_queue.is_empty() || !pat_queue.is_empty() {
if !session_queue.is_empty()
|| !pat_queue.is_empty()
|| !oauth_access_token_queue.is_empty()
{
let mut transaction = pool.begin().await?;
let mut clear_cache_sessions = Vec::new();
@@ -102,29 +113,51 @@ impl AuthQueue {
Session::clear_cache(clear_cache_sessions, redis).await?;
let mut clear_cache_pats = Vec::new();
for id in pat_queue {
clear_cache_pats.push((Some(id), None, None));
sqlx::query!(
"
UPDATE pats
SET last_used = $2
WHERE (id = $1)
",
id as PatId,
Utc::now(),
)
.execute(&mut *transaction)
.await?;
}
let ids = pat_queue.iter().map(|id| id.0).collect_vec();
let clear_cache_pats = pat_queue
.into_iter()
.map(|id| (Some(id), None, None))
.collect_vec();
sqlx::query!(
"
UPDATE pats
SET last_used = $2
WHERE id IN
(SELECT * FROM UNNEST($1::bigint[]))
",
&ids[..],
Utc::now(),
)
.execute(&mut *transaction)
.await?;
PersonalAccessToken::clear_cache(clear_cache_pats, redis).await?;
update_oauth_access_token_last_used(oauth_access_token_queue, &mut transaction).await?;
transaction.commit().await?;
}
Ok(())
}
}
async fn update_oauth_access_token_last_used(
oauth_access_token_queue: HashSet<OAuthAccessTokenId>,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let ids = oauth_access_token_queue.iter().map(|id| id.0).collect_vec();
sqlx::query!(
"
UPDATE oauth_access_tokens
SET last_used = $2
WHERE id IN
(SELECT * FROM UNNEST($1::bigint[]))
",
&ids[..],
Utc::now()
)
.execute(&mut **transaction)
.await?;
Ok(())
}

View File

@@ -1,13 +1,17 @@
pub use super::ApiError;
use crate::util::cors::default_cors;
use crate::{auth::oauth, util::cors::default_cors};
use actix_web::{web, HttpResponse};
use serde_json::json;
pub mod oauth_clients;
pub fn config(cfg: &mut web::ServiceConfig) {
cfg.service(
web::scope("v3")
.wrap(default_cors())
.route("", web::get().to(hello_world)),
.route("", web::get().to(hello_world))
.configure(oauth::config)
.configure(oauth_clients::config),
);
}

View File

@@ -0,0 +1,444 @@
use std::{collections::HashSet, fmt::Display};
use actix_web::{
delete, get, patch, post,
web::{self, scope},
HttpRequest, HttpResponse,
};
use chrono::Utc;
use itertools::Itertools;
use rand::{distributions::Alphanumeric, Rng, SeedableRng};
use rand_chacha::ChaCha20Rng;
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use validator::Validate;
use super::ApiError;
use crate::{
auth::checks::ValidateAllAuthorized, models::oauth_clients::DeleteOAuthClientQueryParam,
};
use crate::{
auth::{checks::ValidateAuthorized, get_user_from_headers},
database::{
models::{
generate_oauth_client_id, generate_oauth_redirect_id,
oauth_client_authorization_item::OAuthClientAuthorization,
oauth_client_item::{OAuthClient, OAuthRedirectUri},
DatabaseError, OAuthClientId, User,
},
redis::RedisPool,
},
models::{
self,
oauth_clients::{GetOAuthClientsRequest, OAuthClientCreationResult},
pats::Scopes,
},
queue::session::AuthQueue,
routes::v2::project_creation::CreateError,
util::validate::validation_errors_to_string,
};
use crate::database::models::oauth_client_item::OAuthClient as DBOAuthClient;
use crate::models::ids::OAuthClientId as ApiOAuthClientId;
pub fn config(cfg: &mut web::ServiceConfig) {
cfg.service(get_user_clients);
cfg.service(
scope("oauth")
.service(oauth_client_create)
.service(oauth_client_edit)
.service(oauth_client_delete)
.service(get_client)
.service(get_clients)
.service(get_user_oauth_authorizations)
.service(revoke_oauth_authorization),
);
}
#[get("user/{user_id}/oauth_apps")]
pub async fn get_user_clients(
req: HttpRequest,
info: web::Path<String>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
let target_user = User::get(&info.into_inner(), &**pool, &redis).await?;
if let Some(target_user) = target_user {
let clients = OAuthClient::get_all_user_clients(target_user.id, &**pool).await?;
clients
.iter()
.validate_all_authorized(Some(&current_user))?;
let response = clients
.into_iter()
.map(models::oauth_clients::OAuthClient::from)
.collect_vec();
Ok(HttpResponse::Ok().json(response))
} else {
Ok(HttpResponse::NotFound().body(""))
}
}
#[get("app/{id}")]
pub async fn get_client(
req: HttpRequest,
id: web::Path<ApiOAuthClientId>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let clients = get_clients_inner(&[id.into_inner()], req, pool, redis, session_queue).await?;
if let Some(client) = clients.into_iter().next() {
Ok(HttpResponse::Ok().json(client))
} else {
Ok(HttpResponse::NotFound().body(""))
}
}
#[get("apps")]
pub async fn get_clients(
req: HttpRequest,
info: web::Json<GetOAuthClientsRequest>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let clients =
get_clients_inner(&info.into_inner().ids, req, pool, redis, session_queue).await?;
Ok(HttpResponse::Ok().json(clients))
}
#[derive(Deserialize, Validate)]
pub struct NewOAuthApp {
#[validate(
custom(function = "crate::util::validate::validate_name"),
length(min = 3, max = 255)
)]
pub name: String,
#[validate(
custom(function = "crate::util::validate::validate_url"),
length(max = 255)
)]
pub icon_url: Option<String>,
#[validate(custom(function = "crate::util::validate::validate_no_restricted_scopes"))]
pub max_scopes: Scopes,
pub redirect_uris: Vec<String>,
}
#[post("app")]
pub async fn oauth_client_create<'a>(
req: HttpRequest,
new_oauth_app: web::Json<NewOAuthApp>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, CreateError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
new_oauth_app
.validate()
.map_err(|e| CreateError::ValidationError(validation_errors_to_string(e, None)))?;
let mut transaction = pool.begin().await?;
let client_id = generate_oauth_client_id(&mut transaction).await?;
let client_secret = generate_oauth_client_secret();
let client_secret_hash = DBOAuthClient::hash_secret(&client_secret);
let redirect_uris =
create_redirect_uris(&new_oauth_app.redirect_uris, client_id, &mut transaction).await?;
let client = OAuthClient {
id: client_id,
icon_url: new_oauth_app.icon_url.clone(),
max_scopes: new_oauth_app.max_scopes,
name: new_oauth_app.name.clone(),
redirect_uris,
created: Utc::now(),
created_by: current_user.id.into(),
secret_hash: client_secret_hash,
};
client.clone().insert(&mut transaction).await?;
transaction.commit().await?;
let client = models::oauth_clients::OAuthClient::from(client);
Ok(HttpResponse::Ok().json(OAuthClientCreationResult {
client,
client_secret,
}))
}
#[delete("app/{id}")]
pub async fn oauth_client_delete<'a>(
req: HttpRequest,
client_id: web::Path<ApiOAuthClientId>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
let client = OAuthClient::get(client_id.into_inner().into(), &**pool).await?;
if let Some(client) = client {
client.validate_authorized(Some(&current_user))?;
OAuthClient::remove(client.id, &**pool).await?;
Ok(HttpResponse::NoContent().body(""))
} else {
Ok(HttpResponse::NotFound().body(""))
}
}
#[derive(Serialize, Deserialize, Validate)]
pub struct OAuthClientEdit {
#[validate(
custom(function = "crate::util::validate::validate_name"),
length(min = 3, max = 255)
)]
pub name: Option<String>,
#[validate(
custom(function = "crate::util::validate::validate_url"),
length(max = 255)
)]
pub icon_url: Option<Option<String>>,
pub max_scopes: Option<Scopes>,
#[validate(length(min = 1))]
pub redirect_uris: Option<Vec<String>>,
}
#[patch("app/{id}")]
pub async fn oauth_client_edit(
req: HttpRequest,
client_id: web::Path<ApiOAuthClientId>,
client_updates: web::Json<OAuthClientEdit>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
client_updates
.validate()
.map_err(|e| ApiError::Validation(validation_errors_to_string(e, None)))?;
if client_updates.icon_url.is_none()
&& client_updates.name.is_none()
&& client_updates.max_scopes.is_none()
{
return Err(ApiError::InvalidInput("No changes provided".to_string()));
}
if let Some(existing_client) = OAuthClient::get(client_id.into_inner().into(), &**pool).await? {
existing_client.validate_authorized(Some(&current_user))?;
let mut updated_client = existing_client.clone();
let OAuthClientEdit {
name,
icon_url,
max_scopes,
redirect_uris,
} = client_updates.into_inner();
if let Some(name) = name {
updated_client.name = name;
}
if let Some(icon_url) = icon_url {
updated_client.icon_url = icon_url;
}
if let Some(max_scopes) = max_scopes {
updated_client.max_scopes = max_scopes;
}
let mut transaction = pool.begin().await?;
updated_client
.update_editable_fields(&mut *transaction)
.await?;
if let Some(redirects) = redirect_uris {
edit_redirects(redirects, &existing_client, &mut transaction).await?;
}
transaction.commit().await?;
Ok(HttpResponse::Ok().body(""))
} else {
Ok(HttpResponse::NotFound().body(""))
}
}
#[get("authorizations")]
pub async fn get_user_oauth_authorizations(
req: HttpRequest,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
let authorizations =
OAuthClientAuthorization::get_all_for_user(current_user.id.into(), &**pool).await?;
let mapped: Vec<models::oauth_clients::OAuthClientAuthorization> =
authorizations.into_iter().map(|a| a.into()).collect_vec();
Ok(HttpResponse::Ok().json(mapped))
}
#[delete("authorizations")]
pub async fn revoke_oauth_authorization(
req: HttpRequest,
info: web::Query<DeleteOAuthClientQueryParam>,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<HttpResponse, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
OAuthClientAuthorization::remove(info.client_id.into(), current_user.id.into(), &**pool)
.await?;
Ok(HttpResponse::Ok().body(""))
}
fn generate_oauth_client_secret() -> String {
ChaCha20Rng::from_entropy()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect::<String>()
}
async fn create_redirect_uris(
uri_strings: impl IntoIterator<Item = impl Display>,
client_id: OAuthClientId,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<Vec<OAuthRedirectUri>, DatabaseError> {
let mut redirect_uris = vec![];
for uri in uri_strings.into_iter() {
let id = generate_oauth_redirect_id(transaction).await?;
redirect_uris.push(OAuthRedirectUri {
id,
client_id,
uri: uri.to_string(),
});
}
Ok(redirect_uris)
}
async fn edit_redirects(
redirects: Vec<String>,
existing_client: &OAuthClient,
transaction: &mut sqlx::Transaction<'_, sqlx::Postgres>,
) -> Result<(), DatabaseError> {
let updated_redirects: HashSet<String> = redirects.into_iter().collect();
let original_redirects: HashSet<String> = existing_client
.redirect_uris
.iter()
.map(|r| r.uri.to_string())
.collect();
let redirects_to_add = create_redirect_uris(
updated_redirects.difference(&original_redirects),
existing_client.id,
&mut *transaction,
)
.await?;
OAuthClient::insert_redirect_uris(&redirects_to_add, &mut **transaction).await?;
let mut redirects_to_remove = existing_client.redirect_uris.clone();
redirects_to_remove.retain(|r| !updated_redirects.contains(&r.uri));
OAuthClient::remove_redirect_uris(redirects_to_remove.iter().map(|r| r.id), &mut **transaction)
.await?;
Ok(())
}
pub async fn get_clients_inner(
ids: &[ApiOAuthClientId],
req: HttpRequest,
pool: web::Data<PgPool>,
redis: web::Data<RedisPool>,
session_queue: web::Data<AuthQueue>,
) -> Result<Vec<models::oauth_clients::OAuthClient>, ApiError> {
let current_user = get_user_from_headers(
&req,
&**pool,
&redis,
&session_queue,
Some(&[Scopes::SESSION_ACCESS]),
)
.await?
.1;
let ids: Vec<OAuthClientId> = ids.iter().map(|i| (*i).into()).collect();
let clients = OAuthClient::get_many(&ids, &**pool).await?;
clients
.iter()
.validate_all_authorized(Some(&current_user))?;
Ok(clients.into_iter().map(|c| c.into()).collect_vec())
}

View File

@@ -3,6 +3,8 @@ use lazy_static::lazy_static;
use regex::Regex;
use validator::{ValidationErrors, ValidationErrorsKind};
use crate::models::pats::Scopes;
lazy_static! {
pub static ref RE_URL_SAFE: Regex = Regex::new(r#"^[a-zA-Z0-9!@$()`.+,_"-]*$"#).unwrap();
}
@@ -91,6 +93,16 @@ pub fn validate_url(value: &str) -> Result<(), validator::ValidationError> {
Ok(())
}
pub fn validate_no_restricted_scopes(value: &Scopes) -> Result<(), validator::ValidationError> {
if value.is_restricted() {
return Err(validator::ValidationError::new(
"Restricted scopes not allowed",
));
}
Ok(())
}
pub fn validate_name(value: &str) -> Result<(), validator::ValidationError> {
if value.trim().is_empty() {
return Err(validator::ValidationError::new(