diff options
| -rw-r--r-- | Cargo.lock | 11 | ||||
| -rw-r--r-- | Cargo.toml | 2 | ||||
| -rw-r--r-- | sqlx-data.json | 270 | ||||
| -rw-r--r-- | src/api/oauth.rs | 304 | ||||
| -rw-r--r-- | src/main.rs | 2 | ||||
| -rw-r--r-- | src/services/db/client.rs | 44 | ||||
| -rw-r--r-- | src/services/jwt.rs | 54 |
7 files changed, 602 insertions, 85 deletions
@@ -1686,7 +1686,9 @@ dependencies = [ "rust-argon2", "rust-ini", "serde", + "serde_json", "serde_urlencoded", + "serde_variant", "sha2", "sqlx", "tera", @@ -1807,6 +1809,15 @@ dependencies = [ ] [[package]] +name = "serde_variant" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47a8ec0b2fd0506290348d9699c0e3eb2e3e8c0498b5a9a6158b3bd4d6970076" +dependencies = [ + "serde", +] + +[[package]] name = "sha1" version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -16,6 +16,7 @@ path-clean = "1" uuid = { version = "1", features = [ "v4", "fast-rng", "serde" ] } url = { version = "2", features = ["serde"] } raise = "2" +serde_json = "1" exun = "0.1" base64 = "0.21" rust-ini = "0.18" @@ -32,3 +33,4 @@ sqlx = { version = "0.6", features = [ "runtime-actix-rustls", "mysql", "uuid", log = "0.4" chrono = { version = "0.4", features = ["serde"] } hex = "0.4" +serde_variant = "0.1" diff --git a/sqlx-data.json b/sqlx-data.json index 145dccb..cce3a53 100644 --- a/sqlx-data.json +++ b/sqlx-data.json @@ -262,66 +262,75 @@ }, "query": "DELETE FROM refresh_tokens WHERE exp < ?" }, - "4e98a6a157a30d9da7621af79845d653ab29eabed1346cd2be60258d8841929d": { + "4faa455ac38672dd2f3f29287125d772aae6956d7a3c0e67d31597e09778e1ee": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Right": 1 + } + }, + "query": "DELETE FROM auth_codes WHERE exp < ?" + }, + "5ae6b0a1174e5735cb3ea5b073f4d1877f7552ac0a6df54c978fcad9e87d5f9b": { "describe": { "columns": [ { - "name": "id: Uuid", + "name": "allowed_scopes", "ordinal": 0, "type_info": { - "char_set": 63, - "flags": { - "bits": 4231 - }, - "max_size": 16, - "type": "String" - } - }, - { - "name": "alias", - "ordinal": 1, - "type_info": { - "char_set": 224, - "flags": { - "bits": 4101 - }, - "max_size": 1020, - "type": "VarString" - } - }, - { - "name": "client_type: ClientType", - "ordinal": 2, - "type_info": { "char_set": 224, "flags": { - "bits": 4097 + "bits": 4113 }, - "max_size": 180, - "type": "VarString" + "max_size": 67108860, + "type": "Blob" } } ], "nullable": [ - false, - false, false ], "parameters": { "Right": 1 } }, - "query": "SELECT id as `id: Uuid`,\n\t\t alias,\n\t\t\t\t type as `client_type: ClientType`\n\t\t FROM clients WHERE id = ?" + "query": "SELECT allowed_scopes FROM clients WHERE id = ?" }, - "4faa455ac38672dd2f3f29287125d772aae6956d7a3c0e67d31597e09778e1ee": { + "5c1a88c154b6e69bb53aee7d0beafbfe7519592f51579d7880117fa52b7be315": { "describe": { "columns": [], "nullable": [], "parameters": { + "Right": 8 + } + }, + "query": "INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version, allowed_scopes, default_scopes)\n\t\t\t\t\t VALUES ( ?, ?, ?, ?, ?, ?, ?, ?)" + }, + "5f3a2ca5d0f61a806ca58195ebbb051758302ed0d376875c671a0aaddb448224": { + "describe": { + "columns": [ + { + "name": "default_scopes", + "ordinal": 0, + "type_info": { + "char_set": 224, + "flags": { + "bits": 16 + }, + "max_size": 67108860, + "type": "Blob" + } + } + ], + "nullable": [ + true + ], + "parameters": { "Right": 1 } }, - "query": "DELETE FROM auth_codes WHERE exp < ?" + "query": "SELECT default_scopes FROM clients WHERE id = ?" }, "64bd64c1c6b272fdd47d12e928be89f2eb69cc0a9f904402d038616b460c8553": { "describe": { @@ -428,6 +437,16 @@ }, "query": "DELETE FROM auth_codes WHERE jti = ?" }, + "7b6de4c923629669f449f91fe17679c8654a6ce9c1238b07dcec2cdb7fcdf18d": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Right": 2 + } + }, + "query": "UPDATE clients SET allowed_scopes = ? WHERE id = ?" + }, "866d1d42c698528f0195a0c2fc7c971ca1a140802dd205bd9918bdcc08fe377b": { "describe": { "columns": [], @@ -493,17 +512,27 @@ }, "query": "SELECT EXISTS(SELECT jti FROM auth_codes WHERE jti = ?) as `e: bool`" }, - "970643c05b6189e1277cfd695492dd3706e0c30615e64812cbd29246ada36bb7": { + "9710cd5915616165c6d27031b21cc7b3cfbd5aae574eb07797dca57064880ef9": { "describe": { "columns": [], "nullable": [], "parameters": { - "Right": 6 + "Right": 2 } }, - "query": "INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version)\n\t\t\t\t\t VALUES ( ?, ?, ?, ?, ?, ?)" + "query": "UPDATE users SET username = ? WHERE id = ?" }, - "9710cd5915616165c6d27031b21cc7b3cfbd5aae574eb07797dca57064880ef9": { + "981d6ca67138bfa4377025ff560f53fd77edcb9bed0d7f0cfb3468357ea5f1fe": { + "describe": { + "columns": [], + "nullable": [], + "parameters": { + "Right": 8 + } + }, + "query": "UPDATE clients SET\n\t\talias = ?,\n\t\ttype = ?,\n\t\tsecret_hash = ?,\n\t\tsecret_salt = ?,\n\t\tsecret_version = ?,\n\t\tallowed_scopes = ?,\n\t\tdefault_scopes = ?\n\t\tWHERE id = ?" + }, + "983348e316c3c8c11f9f5cf0479170d4d7246696010302a472267caeb5d2b62d": { "describe": { "columns": [], "nullable": [], @@ -511,7 +540,7 @@ "Right": 2 } }, - "query": "UPDATE users SET username = ? WHERE id = ?" + "query": "UPDATE clients SET default_scopes = ? WHERE id = ?" }, "a5d7e7e4a36cb1bb0675ccde12dadd013ae2c847648b3274494e206b14cc1370": { "describe": { @@ -625,6 +654,57 @@ }, "query": "SELECT EXISTS(SELECT id FROM users WHERE username = ?) as \"e: bool\"" }, + "b765470e11aa3a02586b0ea0a65f1bb93f104afde56fb2d77b2c72a8742fb9e0": { + "describe": { + "columns": [ + { + "name": "secret_hash", + "ordinal": 0, + "type_info": { + "char_set": 63, + "flags": { + "bits": 144 + }, + "max_size": 255, + "type": "Blob" + } + }, + { + "name": "secret_salt", + "ordinal": 1, + "type_info": { + "char_set": 63, + "flags": { + "bits": 144 + }, + "max_size": 255, + "type": "Blob" + } + }, + { + "name": "secret_version", + "ordinal": 2, + "type_info": { + "char_set": 63, + "flags": { + "bits": 32 + }, + "max_size": 10, + "type": "Long" + } + } + ], + "nullable": [ + true, + true, + true + ], + "parameters": { + "Right": 1 + } + }, + "query": "SELECT secret_hash, secret_salt, secret_version\n\t\tFROM clients WHERE id = ?" + }, "c61516c0c3d51f322a8207581802c2c9723a65beeaeae558d997590dc9e88ef2": { "describe": { "columns": [ @@ -710,6 +790,108 @@ }, "query": "UPDATE users SET\n\t\tpassword_hash = ?,\n\t\tpassword_salt = ?,\n\t\tpassword_version = ?\n\t\tWHERE id = ?" }, + "e757406f5b996a1204700cd4840ac2c5d1e09b82e13aa98d6dc017da81c059e0": { + "describe": { + "columns": [ + { + "name": "id: Uuid", + "ordinal": 0, + "type_info": { + "char_set": 63, + "flags": { + "bits": 4231 + }, + "max_size": 16, + "type": "String" + } + }, + { + "name": "alias", + "ordinal": 1, + "type_info": { + "char_set": 224, + "flags": { + "bits": 4101 + }, + "max_size": 1020, + "type": "VarString" + } + }, + { + "name": "client_type: ClientType", + "ordinal": 2, + "type_info": { + "char_set": 224, + "flags": { + "bits": 4097 + }, + "max_size": 180, + "type": "VarString" + } + }, + { + "name": "allowed_scopes", + "ordinal": 3, + "type_info": { + "char_set": 224, + "flags": { + "bits": 4113 + }, + "max_size": 67108860, + "type": "Blob" + } + }, + { + "name": "default_scopes", + "ordinal": 4, + "type_info": { + "char_set": 224, + "flags": { + "bits": 16 + }, + "max_size": 67108860, + "type": "Blob" + } + } + ], + "nullable": [ + false, + false, + false, + false, + true + ], + "parameters": { + "Right": 1 + } + }, + "query": "SELECT id as `id: Uuid`,\n\t\t alias,\n\t\t\t\t type as `client_type: ClientType`,\n\t\t\t\t allowed_scopes,\n\t\t\t\t default_scopes\n\t\t FROM clients WHERE id = ?" + }, + "f39c1d0c05c8cba9f31aa7365b36eff3c258eb6f554be456600f79b925a808d6": { + "describe": { + "columns": [ + { + "name": "id: Uuid", + "ordinal": 0, + "type_info": { + "char_set": 63, + "flags": { + "bits": 4231 + }, + "max_size": 16, + "type": "String" + } + } + ], + "nullable": [ + false + ], + "parameters": { + "Right": 1 + } + }, + "query": "SELECT id as `id: Uuid` FROM clients WHERE alias = ?" + }, "f488b319d6f387db08fb49920ddb381b2b1496605914275cd1ccd81c9420b23c": { "describe": { "columns": [ @@ -797,16 +979,6 @@ }, "query": "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?" }, - "f88f4fead2c0aeba318dd45546d5321271cdc1cb4f3b39576087b73a1024b78a": { - "describe": { - "columns": [], - "nullable": [], - "parameters": { - "Right": 6 - } - }, - "query": "UPDATE clients SET\n\t\talias = ?,\n\t\ttype = ?,\n\t\tsecret_hash = ?,\n\t\tsecret_salt = ?,\n\t\tsecret_version = ?\n\t\tWHERE id = ?" - }, "f9d2c85bdcc3b7d0d1fca4e2f0bb37df6dee23bc50af97d8e4112baacd6eb7c9": { "describe": { "columns": [], diff --git a/src/api/oauth.rs b/src/api/oauth.rs index d77695e..920f488 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -1,29 +1,35 @@ +use std::ops::Deref; use std::str::FromStr; -use actix_web::{get, post, web, HttpResponse, Scope}; +use actix_web::http::header; +use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope}; +use chrono::Duration; use serde::{Deserialize, Serialize}; use sqlx::MySqlPool; use tera::Tera; +use thiserror::Error; use unic_langid::subtags::Language; use url::Url; -use uuid::Uuid; use crate::resources::{languages, templates}; -use crate::services::{authorization, db}; +use crate::scopes; +use crate::services::{authorization, db, jwt}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] enum ResponseType { Code, Token, + #[serde(other)] + Unsupported, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthorizationParameters { response_type: ResponseType, - client_id: Uuid, + client_id: Box<str>, redirect_uri: Option<Url>, - scope: String, // TODO lol no + scope: Option<Box<str>>, state: Option<Box<str>>, } @@ -33,14 +39,127 @@ struct AuthorizeCredentials { password: Box<str>, } +#[derive(Clone, Serialize)] +struct CodeResponse { + code: Box<str>, + state: Option<Box<str>>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] +#[serde(rename_all = "camelCase")] +enum AuthorizeErrorType { + InvalidRequest, + UnauthorizedClient, + AccessDenied, + UnsupportedResponseType, + InvalidScope, + ServerError, + TemporarilyUnavailable, +} + +#[derive(Debug, Clone, Error)] +#[error("{error_description}")] +struct AuthorizeError { + error: AuthorizeErrorType, + error_description: Box<str>, + // TODO error uri + state: Option<Box<str>>, + redirect_uri: Url, +} + +impl AuthorizeError { + fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self { + Self { + error: AuthorizeErrorType::InvalidScope, + error_description: Box::from( + "No scope was provided, and the client does not have a default scope", + ), + state, + redirect_uri, + } + } +} + +impl ResponseError for AuthorizeError { + fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { + let error = serde_variant::to_variant_name(&self.error).unwrap_or_default(); + let mut url = self.redirect_uri.clone(); + url.query_pairs_mut() + .append_pair("error", error) + .append_pair("error_description", &self.error_description); + + if let Some(state) = &self.state { + url.query_pairs_mut().append_pair("state", &state); + } + + HttpResponse::Found() + .insert_header((header::LOCATION, url.as_str())) + .finish() + } +} + #[post("/authorize")] async fn authorize( db: web::Data<MySqlPool>, - query: web::Query<AuthorizationParameters>, - credentials: web::Form<AuthorizeCredentials>, + req: web::Query<AuthorizationParameters>, + credentials: web::Json<AuthorizeCredentials>, ) -> HttpResponse { - // TODO check that the URI is valid - todo!() + // TODO use sessions to verify that the request was previously validated + let db = db.get_ref(); + let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else { + todo!("client not found") + }; + let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let state = req.state.clone(); + + // get redirect uri + let redirect_uri = if let Some(redirect_uri) = &req.redirect_uri { + redirect_uri.clone() + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); + if redirect_uris.len() != 1 { + todo!("no redirect uri"); + } + + redirect_uris[0].clone() + }; + + // authenticate user + let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else { + todo!("bad username") + }; + if !user.check_password(&credentials.password).unwrap() { + todo!("bad password") + } + + // get scope + let scope = if let Some(scope) = &req.scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + return AuthorizeError::no_scope(redirect_uri, state).error_response() + }; + scope + }; + + match req.response_type { + ResponseType::Code => { + // create auth code + let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri) + .await + .unwrap(); + let code = code.to_jwt().unwrap(); + let response = CodeResponse { code, state }; + + HttpResponse::Ok().json(response) + } + ResponseType::Token => todo!(), + _ => todo!("unsupported response type"), + } } #[get("/authorize")] @@ -48,36 +167,187 @@ async fn authorize_page( db: web::Data<MySqlPool>, tera: web::Data<Tera>, translations: web::Data<languages::Translations>, - query: web::Query<AuthorizationParameters>, + request: HttpRequest, ) -> HttpResponse { + let params = request.query_string(); + let params = serde_urlencoded::from_str::<AuthorizationParameters>(params); + let Ok(params) = params else { + todo!("invalid request") + }; + + let db = db.get_ref(); + let Some(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await.unwrap() else { + todo!("client not found") + }; + + // verify scope + let Some(allowed_scopes) = db::get_client_allowed_scopes(db, client_id).await.unwrap() else { + todo!("client not found") + }; + + let scope = if let Some(scope) = ¶ms.scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + todo!("invalid request") + }; + scope + }; + + if !scopes::is_subset_of(&scope, &allowed_scopes) { + todo!("access_denied") + } + + // verify redirect uri + if let Some(redirect_uri) = ¶ms.redirect_uri { + if !db::client_has_redirect_uri(db, client_id, redirect_uri) + .await + .unwrap() + { + todo!("access denied") + } + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); + if redirect_uris.len() != 1 { + todo!("must have redirect uri") + } + } + + // verify response type + if params.response_type == ResponseType::Unsupported { + todo!("unsupported response type") + } + // TODO find a better way of doing languages - // TODO check that the URI is valid let language = Language::from_str("en").unwrap(); let page = - templates::login_page(&tera, &query, language, translations.get_ref().clone()).unwrap(); + templates::login_page(&tera, ¶ms, language, translations.get_ref().clone()).unwrap(); HttpResponse::Ok().content_type("text/html").body(page) } #[derive(Clone, Deserialize)] #[serde(tag = "grant_type")] -enum GrantType {} +#[serde(rename_all = "snake_case")] +enum GrantType { + AuthorizationCode { + code: Box<str>, + redirect_uri: Url, + #[serde(rename = "client_id")] + client_alias: Box<str>, + }, + Password { + username: Box<str>, + password: Box<str>, + scope: Option<Box<str>>, + }, + ClientCredentials { + scope: Option<Box<str>>, + }, +} #[derive(Clone, Deserialize)] struct TokenRequest { #[serde(flatten)] grant_type: GrantType, - scope: String, // TODO lol no - // TODO support optional client credentials in here + // TODO support optional client credentials in here +} + +#[derive(Clone, Serialize)] +struct TokenResponse { + access_token: Box<str>, + token_type: Box<str>, + expires_in: i64, + refresh_token: Box<str>, + scope: Box<str>, } #[post("/token")] async fn token( db: web::Data<MySqlPool>, - req: web::Form<TokenRequest>, + req: web::Bytes, authorization: Option<web::Header<authorization::BasicAuthorization>>, ) -> HttpResponse { // TODO protect against brute force attacks - todo!() + let db = db.get_ref(); + let request = serde_json::from_slice::<TokenRequest>(&req); + let Ok(request) = request else { + todo!("invalid request") + }; + + let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let duration = Duration::hours(1); + let token_type = Box::from("bearer"); + let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); + + match request.grant_type { + GrantType::AuthorizationCode { + code, + redirect_uri, + client_alias, + } => { + let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { + todo!("client not found") + }; + + let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else { + todo!("invalid code"); + }; + + // verify client, if the client has credentials + if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { + let Some(authorization) = authorization else { + todo!("no client credentials") + }; + + if authorization.username() != client_alias.deref() { + todo!("bad username") + } + if !hash.check_password(authorization.password()).unwrap() { + todo!("bad password") + } + } + + let access_token = jwt::Claims::access_token( + db, + claims.id(), + self_id, + client_id, + duration, + claims.scopes(), + ) + .await + .unwrap(); + + let expires_in = access_token.expires_in(); + let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); + let scope = access_token.scopes().into(); + + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = refresh_token.to_jwt().unwrap(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + GrantType::Password { + username, + password, + scope, + } => todo!(), + GrantType::ClientCredentials { scope } => todo!(), + } } pub fn service() -> Scope { diff --git a/src/main.rs b/src/main.rs index 1106dc0..da740be 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,7 @@ fn error_content_language<B>( async fn delete_expired_tokens(db: MySqlPool) { let db = db.clone(); - let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 10)); + let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20)); loop { interval.tick().await; if let Err(e) = db::delete_expired_auth_codes(&db).await { diff --git a/src/services/db/client.rs b/src/services/db/client.rs index c25ad0d..70701d7 100644 --- a/src/services/db/client.rs +++ b/src/services/db/client.rs @@ -21,6 +21,13 @@ pub struct ClientRow { pub default_scopes: Option<String>, } +#[derive(Clone, FromRow)] +struct HashRow { + secret_hash: Option<Vec<u8>>, + secret_salt: Option<Vec<u8>>, + secret_version: Option<u32>, +} + pub async fn client_id_exists<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -47,6 +54,19 @@ pub async fn client_alias_exists<'c>( .unexpect() } +pub async fn get_client_id_by_alias<'c>( + executor: impl Executor<'c, Database = MySql>, + alias: &str, +) -> Result<Option<Uuid>, RawUnexpected> { + query_scalar!( + "SELECT id as `id: Uuid` FROM clients WHERE alias = ?", + alias + ) + .fetch_optional(executor) + .await + .unexpect() +} + pub async fn get_client_response<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -116,6 +136,28 @@ pub async fn get_client_default_scopes<'c>( Ok(scopes.map(|s| s.map(Box::from))) } +pub async fn get_client_secret<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Option<PasswordHash>, RawUnexpected> { + let hash = query_as!( + HashRow, + r"SELECT secret_hash, secret_salt, secret_version + FROM clients WHERE id = ?", + id + ) + .fetch_optional(executor) + .await?; + + let Some(hash) = hash else { return Ok(None) }; + let Some(version) = hash.secret_version else { return Ok(None) }; + let Some(salt) = hash.secret_hash else { return Ok(None) }; + let Some(hash) = hash.secret_salt else { return Ok(None) }; + + let hash = PasswordHash::from_fields(&hash, &salt, version as u8); + Ok(Some(hash)) +} + pub async fn get_client_redirect_uris<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -136,7 +178,7 @@ pub async fn get_client_redirect_uris<'c>( pub async fn client_has_redirect_uri<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, - url: Url, + url: &Url, ) -> Result<bool, RawUnexpected> { query_scalar!( r"SELECT EXISTS( diff --git a/src/services/jwt.rs b/src/services/jwt.rs index 7841afb..822101f 100644 --- a/src/services/jwt.rs +++ b/src/services/jwt.rs @@ -32,6 +32,7 @@ pub struct Claims { client_id: Uuid, auth_code_id: Uuid, token_type: TokenType, + redirect_uri: Option<Url>, } #[derive(Debug, Clone, Copy, sqlx::Type)] @@ -43,18 +44,19 @@ pub enum RevokedRefreshTokenReason { impl Claims { pub async fn auth_code<'c>( - db: MySqlPool, + db: &MySqlPool, self_id: Url, client_id: Uuid, scopes: &str, + redirect_uri: &Url, ) -> Result<Self, RawUnexpected> { let five_minutes = Duration::minutes(5); - let id = new_id(&db, db::auth_code_exists).await?; + let id = new_id(db, db::auth_code_exists).await?; let time = Utc::now(); let exp = time + five_minutes; - db::create_auth_code(&db, id, exp).await?; + db::create_auth_code(db, id, exp).await?; Ok(Self { iss: self_id, @@ -67,22 +69,23 @@ impl Claims { client_id, auth_code_id: id, token_type: TokenType::Authorization, + redirect_uri: Some(redirect_uri.clone()), }) } pub async fn access_token<'c>( - db: MySqlPool, + db: &MySqlPool, auth_code_id: Uuid, self_id: Url, client_id: Uuid, duration: Duration, scopes: &str, ) -> Result<Self, RawUnexpected> { - let id = new_id(&db, db::access_token_exists).await?; + let id = new_id(db, db::access_token_exists).await?; let time = Utc::now(); let exp = time + duration; - db::create_access_token(&db, id, auth_code_id, exp) + db::create_access_token(db, id, auth_code_id, exp) .await .unexpect()?; @@ -97,19 +100,23 @@ impl Claims { client_id, auth_code_id, token_type: TokenType::Access, + redirect_uri: None, }) } - pub async fn refresh_token(db: MySqlPool, other_token: Claims) -> Result<Self, RawUnexpected> { + pub async fn refresh_token( + db: &MySqlPool, + other_token: &Claims, + ) -> Result<Self, RawUnexpected> { let one_day = Duration::days(1); - let id = new_id(&db, db::refresh_token_exists).await?; + let id = new_id(db, db::refresh_token_exists).await?; let time = Utc::now(); let exp = other_token.exp + one_day; - db::create_refresh_token(&db, id, other_token.auth_code_id, exp).await?; + db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?; - let mut claims = other_token; + let mut claims = other_token.clone(); claims.exp = exp; claims.iat = Some(time); claims.jti = id; @@ -119,15 +126,15 @@ impl Claims { } pub async fn refreshed_access_token( - db: MySqlPool, + db: &MySqlPool, refresh_token: Claims, exp_time: Duration, ) -> Result<Self, RawUnexpected> { - let id = new_id(&db, db::access_token_exists).await?; + let id = new_id(db, db::access_token_exists).await?; let time = Utc::now(); let exp = time + exp_time; - db::create_access_token(&db, id, refresh_token.auth_code_id, exp).await?; + db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?; let mut claims = refresh_token; claims.exp = exp; @@ -142,6 +149,10 @@ impl Claims { self.jti } + pub fn expires_in(&self) -> i64 { + (self.exp - Utc::now()).num_seconds() + } + pub fn scopes(&self) -> &str { &self.scope } @@ -163,6 +174,8 @@ pub enum VerifyJwtError { WrongClient, #[error("The given audience parameter does not contain this issuer")] BadAudience, + #[error("The redirect URI doesn't match what's in the token")] + IncorrectRedirectUri, #[error("The token is expired")] ExpiredToken, #[error("The token cannot be used yet")] @@ -211,16 +224,23 @@ fn verify_jwt( } pub async fn verify_auth_code<'c>( - db: MySqlPool, + db: &MySqlPool, token: &str, self_id: Url, client_id: Uuid, + redirect_uri: Url, ) -> Result<Claims, Expect<VerifyJwtError>> { let claims = verify_jwt(token, self_id, client_id)?; - if db::delete_auth_code(&db, claims.jti).await? { - db::delete_access_tokens_with_auth_code(&db, claims.jti).await?; - db::revoke_refresh_tokens_with_auth_code(&db, claims.jti).await?; + if let Some(claimed_uri) = &claims.redirect_uri { + if claimed_uri.clone() != redirect_uri { + yeet!(VerifyJwtError::IncorrectRedirectUri.into()); + } + } + + if db::delete_auth_code(db, claims.jti).await? { + db::delete_access_tokens_with_auth_code(db, claims.jti).await?; + db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?; yeet!(VerifyJwtError::JwtRevoked.into()); } |
