From cfc6b9f35f49636a50ef9d170d01740439dfdbe4 Mon Sep 17 00:00:00 2001 From: mrw1593 Date: Sun, 18 Jun 2023 18:24:58 -0400 Subject: Implement refresh token grant --- src/api/oauth.rs | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++ src/services/jwt.rs | 18 +++++++------- 2 files changed, 79 insertions(+), 8 deletions(-) (limited to 'src') diff --git a/src/api/oauth.rs b/src/api/oauth.rs index bc9f5a2..de98e80 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -12,6 +12,7 @@ use tera::Tera; use thiserror::Error; use unic_langid::subtags::Language; use url::Url; +use uuid::Uuid; use crate::models::client::ClientType; use crate::resources::{languages, templates}; @@ -312,6 +313,10 @@ enum GrantType { ClientCredentials { scope: Option>, }, + RefreshToken { + refresh_token: Box, + scope: Option>, + }, #[serde(other)] Unsupported, } @@ -434,6 +439,14 @@ impl TokenError { ), } } + + fn bad_refresh_token(err: VerifyJwtError) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidGrant, + error_description: err.to_string().into_boxed_str(), + } + } } impl ResponseError for TokenError { @@ -604,6 +617,62 @@ async fn token( .insert_header((header::PRAGMA, "no-cache")) .json(response) } + GrantType::RefreshToken { + refresh_token, + scope, + } => { + let client_id: Option; + if let Some(authorization) = authorization { + let client_alias = authorization.username(); + let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { + return TokenError::client_not_found(client_alias).error_response(); + }; + client_id = Some(id); + } else { + client_id = None; + } + + let claims = + match jwt::verify_refresh_token(db, &refresh_token, self_id, client_id).await { + Ok(claims) => claims, + Err(e) => { + let e = e.unwrap(); + return TokenError::bad_refresh_token(e).error_response(); + } + }; + + let scope = if let Some(scope) = scope { + if !scopes::is_subset_of(&scope, claims.scopes()) { + return TokenError::excessive_scope().error_response(); + } + + scope + } else { + claims.scopes().into() + }; + + let exp_time = Duration::hours(1); + let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time) + .await + .unwrap(); + let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap(); + + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = Some(refresh_token.to_jwt().unwrap()); + let expires_in = exp_time.num_seconds(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } _ => TokenError::unsupported_grant_type().error_response(), } } diff --git a/src/services/jwt.rs b/src/services/jwt.rs index c86fb01..488e0ac 100644 --- a/src/services/jwt.rs +++ b/src/services/jwt.rs @@ -127,7 +127,7 @@ impl Claims { pub async fn refreshed_access_token( db: &MySqlPool, - refresh_token: Claims, + refresh_token: &Claims, exp_time: Duration, ) -> Result { let id = new_id(db, db::access_token_exists).await?; @@ -136,7 +136,7 @@ impl Claims { db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?; - let mut claims = refresh_token; + let mut claims = refresh_token.clone(); claims.exp = exp; claims.iat = Some(time); claims.jti = id; @@ -187,7 +187,7 @@ pub enum VerifyJwtError { fn verify_jwt( token: &str, self_id: Url, - client_id: Uuid, + client_id: Option, ) -> Result> { let key = secrets::signing_key()?; let claims: Claims = token @@ -198,8 +198,10 @@ fn verify_jwt( yeet!(VerifyJwtError::IncorrectIssuer.into()) } - if claims.client_id != client_id { - yeet!(VerifyJwtError::WrongClient.into()) + if let Some(client_id) = client_id { + if claims.client_id != client_id { + yeet!(VerifyJwtError::WrongClient.into()) + } } if let Some(aud) = claims.aud.clone() { @@ -230,7 +232,7 @@ pub async fn verify_auth_code<'c>( client_id: Uuid, redirect_uri: Url, ) -> Result> { - let claims = verify_jwt(token, self_id, client_id)?; + let claims = verify_jwt(token, self_id, Some(client_id))?; if let Some(claimed_uri) = &claims.redirect_uri { if claimed_uri.clone() != redirect_uri { @@ -253,7 +255,7 @@ pub async fn verify_access_token<'c>( self_id: Url, client_id: Uuid, ) -> Result> { - let claims = verify_jwt(token, self_id, client_id)?; + let claims = verify_jwt(token, self_id, Some(client_id))?; if !db::access_token_exists(db, claims.jti).await? { yeet!(VerifyJwtError::JwtRevoked.into()) @@ -266,7 +268,7 @@ pub async fn verify_refresh_token<'c>( db: impl Executor<'c, Database = MySql>, token: &str, self_id: Url, - client_id: Uuid, + client_id: Option, ) -> Result> { let claims = verify_jwt(token, self_id, client_id)?; -- cgit v1.2.3