summaryrefslogtreecommitdiff
path: root/src/api/oauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/api/oauth.rs')
-rw-r--r--src/api/oauth.rs278
1 files changed, 180 insertions, 98 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index 43ad402..ef40637 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -6,6 +6,8 @@ use actix_web::{
get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
};
use chrono::Duration;
+use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
+use raise::yeet;
use serde::{Deserialize, Serialize};
use sqlx::MySqlPool;
use tera::Tera;
@@ -20,6 +22,8 @@ use crate::scopes;
use crate::services::jwt::VerifyJwtError;
use crate::services::{authorization, db, jwt};
+const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
enum ResponseType {
@@ -111,6 +115,15 @@ impl AuthorizeError {
redirect_uri,
}
}
+
+ fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::ServerError,
+ error_description: "An unexpected error occurred".into(),
+ state,
+ redirect_uri,
+ }
+ }
}
impl ResponseError for AuthorizeError {
@@ -126,6 +139,91 @@ impl ResponseError for AuthorizeError {
}
}
+fn error_page(
+ tera: &Tera,
+ translations: &languages::Translations,
+ error: templates::ErrorPage,
+) -> Result<String, RawUnexpected> {
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.clone();
+ let page = templates::error_page(&tera, language, translations, error)?;
+ Ok(page)
+}
+
+async fn get_redirect_uri(
+ redirect_uri: &Option<Url>,
+ db: &MySqlPool,
+ client_id: Uuid,
+) -> Result<Url, Expect<templates::ErrorPage>> {
+ if let Some(uri) = &redirect_uri {
+ let redirect_uri = uri.clone();
+ if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?
+ {
+ yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
+ }
+
+ Ok(redirect_uri)
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?;
+ if redirect_uris.len() != 1 {
+ yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
+ }
+
+ Ok(redirect_uris.get(0).unwrap().clone())
+ }
+}
+
+async fn get_scope(
+ scope: &Option<Box<str>>,
+ db: &MySqlPool,
+ client_id: Uuid,
+ redirect_uri: &Url,
+ state: &Option<Box<str>>,
+) -> Result<Box<str>, Expect<AuthorizeError>> {
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
+ };
+ scope
+ };
+
+ // verify scope is valid
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
+ }
+
+ Ok(scope)
+}
+
+async fn authenticate_user(
+ db: &MySqlPool,
+ username: &str,
+ password: &str,
+) -> Result<bool, RawUnexpected> {
+ let Some(user) = db::get_user_by_username(db, username).await? else {
+ return Ok(false);
+ };
+
+ Ok(user.check_password(password)?)
+}
+
#[post("/authorize")]
async fn authorize(
db: web::Data<MySqlPool>,
@@ -134,62 +232,53 @@ async fn authorize(
tera: web::Data<Tera>,
translations: web::Data<languages::Translations>,
) -> HttpResponse {
- // TODO use sessions to verify that the request was previously validated
// TODO handle internal server error
+ // TODO protect against brute force attacks
let db = db.get_ref();
- let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else {
- // TODO find a better way of doing languages
- let language = Language::from_str("en").unwrap();
- let translations = translations.get_ref().clone();
- let page = templates::error_page(&tera, language, translations, templates::ErrorPage::ClientNotFound).unwrap();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return HttpResponse::InternalServerError().content_type("text/html").body(page);
+ };
+ let Some(client_id) = client_id else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
return HttpResponse::NotFound().content_type("text/html").body(page);
};
let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
let state = req.state.clone();
// get redirect uri
- let mut redirect_uri = if let Some(redirect_uri) = &req.redirect_uri {
- redirect_uri.clone()
- } else {
- let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
- if redirect_uris.len() != 1 {
- let language = Language::from_str("en").unwrap();
- let translations = translations.get_ref().clone();
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::MissingRedirectUri,
- )
- .unwrap();
- return HttpResponse::NotFound()
+ let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return HttpResponse::BadRequest()
.content_type("text/html")
.body(page);
}
-
- redirect_uris[0].clone()
};
// authenticate user
- let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else {
- todo!("bad username")
+ if !authenticate_user(db, &credentials.username, &credentials.password)
+ .await
+ .unwrap()
+ {
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.get_ref().clone();
+ let page = templates::login_error_page(&tera, &req, language, translations).unwrap();
+ return HttpResponse::Ok().content_type("text/html").body(page);
};
- if !user.check_password(&credentials.password).unwrap() {
- todo!("bad password")
- }
// get scope
- let scope = if let Some(scope) = &req.scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return AuthorizeError::no_scope(redirect_uri, state).error_response()
- };
- scope
+ let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.unwrap();
+ return e.error_response();
+ }
};
match req.response_type {
@@ -248,97 +337,77 @@ async fn authorize_page(
request: HttpRequest,
) -> HttpResponse {
// TODO handle internal server error
- let language = Language::from_str("en").unwrap();
+ let Ok(language) = Language::from_str("en") else {
+ let page = String::from(REALLY_BAD_ERROR_PAGE);
+ return HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page);
+ };
let translations = translations.get_ref().clone();
let params = request.query_string();
let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
let Ok(params) = params else {
- let page = templates::error_page(
+ let page = error_page(
&tera,
- language,
- translations,
+ &translations,
templates::ErrorPage::InvalidRequest,
)
- .unwrap();
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
return HttpResponse::BadRequest()
.content_type("text/html")
.body(page);
};
let db = db.get_ref();
- let Some(client_id) = db::get_client_id_by_alias(db, &params.client_id).await.unwrap() else {
+ let Ok(client_id) = db::get_client_id_by_alias(db, &params.client_id).await else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::InternalServerError,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page);
+ };
+ let Some(client_id) = client_id else {
let page = templates::error_page(
&tera,
language,
translations,
templates::ErrorPage::ClientNotFound,
)
- .unwrap();
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
return HttpResponse::NotFound()
.content_type("text/html")
.body(page);
};
- // verify scope
- let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
-
// verify redirect uri
- let redirect_uri: Url;
- if let Some(uri) = &params.redirect_uri {
- redirect_uri = uri.clone();
- if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
- .await
- .unwrap()
- {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::InvalidRedirectUri,
- )
- .unwrap();
+ let redirect_uri = match get_redirect_uri(&params.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
return HttpResponse::BadRequest()
.content_type("text/html")
.body(page);
}
- } else {
- let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap();
- if redirect_uris.len() != 1 {
- let page = templates::error_page(
- &tera,
- language,
- translations,
- templates::ErrorPage::MissingRedirectUri,
- )
- .unwrap();
- return HttpResponse::NotFound()
- .content_type("text/html")
- .body(page);
- }
-
- redirect_uri = redirect_uris.get(0).unwrap().clone();
- }
-
- let scope = if let Some(scope) = &params.scope {
- scope.clone()
- } else {
- let default_scopes = db::get_client_default_scopes(db, client_id)
- .await
- .unwrap()
- .unwrap();
- let Some(scope) = default_scopes else {
- return AuthorizeError::no_scope(redirect_uri, params.state).error_response();
- };
- scope
};
- if !scopes::is_subset_of(&scope, &allowed_scopes) {
- return AuthorizeError::invalid_scope(redirect_uri, params.state).error_response();
- }
+ // verify scope
+ let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.unwrap();
+ return e.error_response();
+ }
+ };
// verify response type
if params.response_type == ResponseType::Unsupported {
@@ -520,6 +589,14 @@ impl TokenError {
error_description: "Only trusted clients may use this grant".into(),
}
}
+
+ fn incorrect_user_credentials() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "The given credentials are incorrect".into(),
+ }
+ }
}
impl ResponseError for TokenError {
@@ -647,6 +724,11 @@ async fn token(
return TokenError::incorrect_client_secret().error_response();
}
+ // authenticate user
+ if !authenticate_user(db, &username, &password).await.unwrap() {
+ return TokenError::incorrect_user_credentials().error_response();
+ };
+
// verify scope
let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
.await