diff options
Diffstat (limited to 'src/api')
| -rw-r--r-- | src/api/oauth.rs | 223 |
1 files changed, 210 insertions, 13 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs index 920f488..48c3210 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -1,8 +1,10 @@ use std::ops::Deref; use std::str::FromStr; -use actix_web::http::header; -use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope}; +use actix_web::http::{header, StatusCode}; +use actix_web::{ + get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope, +}; use chrono::Duration; use serde::{Deserialize, Serialize}; use sqlx::MySqlPool; @@ -11,8 +13,10 @@ use thiserror::Error; use unic_langid::subtags::Language; use url::Url; +use crate::models::client::ClientType; use crate::resources::{languages, templates}; use crate::scopes; +use crate::services::jwt::VerifyJwtError; use crate::services::{authorization, db, jwt}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -247,6 +251,8 @@ enum GrantType { ClientCredentials { scope: Option<Box<str>>, }, + #[serde(other)] + Unsupported, } #[derive(Clone, Deserialize)] @@ -261,10 +267,131 @@ struct TokenResponse { access_token: Box<str>, token_type: Box<str>, expires_in: i64, - refresh_token: Box<str>, + refresh_token: Option<Box<str>>, scope: Box<str>, } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +enum TokenErrorType { + InvalidRequest, + InvalidClient, + InvalidGrant, + UnauthorizedClient, + UnsupportedGrantType, + InvalidScope, +} + +#[derive(Debug, Clone, Error, Serialize)] +#[error("{error_description}")] +struct TokenError { + #[serde(skip)] + status_code: StatusCode, + error: TokenErrorType, + error_description: Box<str>, + // TODO error uri +} + +impl TokenError { + fn invalid_request() -> Self { + // TODO make this description better, and all the other ones while you're at it + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidRequest, + error_description: "Invalid request".into(), + } + } + + fn unsupported_grant_type() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::UnsupportedGrantType, + error_description: "The given grant type is not supported".into(), + } + } + + fn bad_auth_code(error: VerifyJwtError) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidGrant, + error_description: error.to_string().into_boxed_str(), + } + } + + fn no_authorization() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: Box::from( + "Client credentials must be provided in the HTTP Authorization header", + ), + } + } + + fn client_not_found(alias: &str) -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: format!("No client with the client id: {alias} was found") + .into_boxed_str(), + } + } + + fn incorrect_client_secret() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: "The client secret is incorrect".into(), + } + } + + fn client_not_confidential(alias: &str) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::UnauthorizedClient, + error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.") + .into_boxed_str(), + } + } + + fn no_scope() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidScope, + error_description: Box::from( + "No scope was provided, and the client doesn't have a default scope", + ), + } + } + + fn excessive_scope() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidScope, + error_description: Box::from( + "The given scope exceeds what the client is allowed to have", + ), + } + } +} + +impl ResponseError for TokenError { + fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { + let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); + + let mut builder = HttpResponseBuilder::new(self.status_code); + + if self.status_code.as_u16() == 401 { + builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\"")); + } + + builder + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(self.clone()) + } +} + #[post("/token")] async fn token( db: web::Data<MySqlPool>, @@ -275,7 +402,7 @@ async fn token( let db = db.get_ref(); let request = serde_json::from_slice::<TokenRequest>(&req); let Ok(request) = request else { - todo!("invalid request") + return TokenError::invalid_request().error_response(); }; let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value @@ -290,30 +417,38 @@ async fn token( client_alias, } => { let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { - todo!("client not found") + return TokenError::client_not_found(&client_alias).error_response(); }; - let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else { - todo!("invalid code"); - }; + // validate auth code + let claims = + match jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri) + .await + { + Ok(claims) => claims, + Err(err) => { + let err = err.unwrap(); + return TokenError::bad_auth_code(err).error_response(); + } + }; // verify client, if the client has credentials if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { let Some(authorization) = authorization else { - todo!("no client credentials") + return TokenError::no_authorization().error_response(); }; if authorization.username() != client_alias.deref() { todo!("bad username") } if !hash.check_password(authorization.password()).unwrap() { - todo!("bad password") + return TokenError::incorrect_client_secret().error_response(); } } let access_token = jwt::Claims::access_token( db, - claims.id(), + Some(claims.id()), self_id, client_id, duration, @@ -327,7 +462,7 @@ async fn token( let scope = access_token.scopes().into(); let access_token = access_token.to_jwt().unwrap(); - let refresh_token = refresh_token.to_jwt().unwrap(); + let refresh_token = Some(refresh_token.to_jwt().unwrap()); let response = TokenResponse { access_token, @@ -346,7 +481,69 @@ async fn token( password, scope, } => todo!(), - GrantType::ClientCredentials { scope } => todo!(), + GrantType::ClientCredentials { scope } => { + let Some(authorization) = authorization else { + return TokenError::no_authorization().error_response(); + }; + let client_alias = authorization.username(); + let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { + return TokenError::client_not_found(client_alias).error_response(); + }; + + let ty = db::get_client_type(db, client_id).await.unwrap().unwrap(); + if ty != ClientType::Confidential { + return TokenError::client_not_confidential(client_alias).error_response(); + } + + // verify client + let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); + if !hash.check_password(authorization.password()).unwrap() { + return TokenError::incorrect_client_secret().error_response(); + } + + // verify scope + let allowed_scopes = db::get_client_allowed_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let scope = if let Some(scope) = &scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + return TokenError::no_scope().error_response(); + }; + scope + }; + if !scopes::is_subset_of(&scope, &allowed_scopes) { + return TokenError::excessive_scope().error_response(); + } + + let access_token = + jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope) + .await + .unwrap(); + + let expires_in = access_token.expires_in(); + let scope = access_token.scopes().into(); + let access_token = access_token.to_jwt().unwrap(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token: None, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + _ => TokenError::unsupported_grant_type().error_response(), } } |
