diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/api/oauth.rs | 304 | ||||
| -rw-r--r-- | src/main.rs | 2 | ||||
| -rw-r--r-- | src/services/db/client.rs | 44 | ||||
| -rw-r--r-- | src/services/jwt.rs | 54 |
4 files changed, 368 insertions, 36 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs index d77695e..920f488 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -1,29 +1,35 @@ +use std::ops::Deref; use std::str::FromStr; -use actix_web::{get, post, web, HttpResponse, Scope}; +use actix_web::http::header; +use actix_web::{get, post, web, HttpRequest, HttpResponse, ResponseError, Scope}; +use chrono::Duration; use serde::{Deserialize, Serialize}; use sqlx::MySqlPool; use tera::Tera; +use thiserror::Error; use unic_langid::subtags::Language; use url::Url; -use uuid::Uuid; use crate::resources::{languages, templates}; -use crate::services::{authorization, db}; +use crate::scopes; +use crate::services::{authorization, db, jwt}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] enum ResponseType { Code, Token, + #[serde(other)] + Unsupported, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthorizationParameters { response_type: ResponseType, - client_id: Uuid, + client_id: Box<str>, redirect_uri: Option<Url>, - scope: String, // TODO lol no + scope: Option<Box<str>>, state: Option<Box<str>>, } @@ -33,14 +39,127 @@ struct AuthorizeCredentials { password: Box<str>, } +#[derive(Clone, Serialize)] +struct CodeResponse { + code: Box<str>, + state: Option<Box<str>>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] +#[serde(rename_all = "camelCase")] +enum AuthorizeErrorType { + InvalidRequest, + UnauthorizedClient, + AccessDenied, + UnsupportedResponseType, + InvalidScope, + ServerError, + TemporarilyUnavailable, +} + +#[derive(Debug, Clone, Error)] +#[error("{error_description}")] +struct AuthorizeError { + error: AuthorizeErrorType, + error_description: Box<str>, + // TODO error uri + state: Option<Box<str>>, + redirect_uri: Url, +} + +impl AuthorizeError { + fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self { + Self { + error: AuthorizeErrorType::InvalidScope, + error_description: Box::from( + "No scope was provided, and the client does not have a default scope", + ), + state, + redirect_uri, + } + } +} + +impl ResponseError for AuthorizeError { + fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { + let error = serde_variant::to_variant_name(&self.error).unwrap_or_default(); + let mut url = self.redirect_uri.clone(); + url.query_pairs_mut() + .append_pair("error", error) + .append_pair("error_description", &self.error_description); + + if let Some(state) = &self.state { + url.query_pairs_mut().append_pair("state", &state); + } + + HttpResponse::Found() + .insert_header((header::LOCATION, url.as_str())) + .finish() + } +} + #[post("/authorize")] async fn authorize( db: web::Data<MySqlPool>, - query: web::Query<AuthorizationParameters>, - credentials: web::Form<AuthorizeCredentials>, + req: web::Query<AuthorizationParameters>, + credentials: web::Json<AuthorizeCredentials>, ) -> HttpResponse { - // TODO check that the URI is valid - todo!() + // TODO use sessions to verify that the request was previously validated + let db = db.get_ref(); + let Some(client_id) = db::get_client_id_by_alias(db, &req.client_id).await.unwrap() else { + todo!("client not found") + }; + let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let state = req.state.clone(); + + // get redirect uri + let redirect_uri = if let Some(redirect_uri) = &req.redirect_uri { + redirect_uri.clone() + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); + if redirect_uris.len() != 1 { + todo!("no redirect uri"); + } + + redirect_uris[0].clone() + }; + + // authenticate user + let Some(user) = db::get_user_by_username(db, &credentials.username).await.unwrap() else { + todo!("bad username") + }; + if !user.check_password(&credentials.password).unwrap() { + todo!("bad password") + } + + // get scope + let scope = if let Some(scope) = &req.scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + return AuthorizeError::no_scope(redirect_uri, state).error_response() + }; + scope + }; + + match req.response_type { + ResponseType::Code => { + // create auth code + let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri) + .await + .unwrap(); + let code = code.to_jwt().unwrap(); + let response = CodeResponse { code, state }; + + HttpResponse::Ok().json(response) + } + ResponseType::Token => todo!(), + _ => todo!("unsupported response type"), + } } #[get("/authorize")] @@ -48,36 +167,187 @@ async fn authorize_page( db: web::Data<MySqlPool>, tera: web::Data<Tera>, translations: web::Data<languages::Translations>, - query: web::Query<AuthorizationParameters>, + request: HttpRequest, ) -> HttpResponse { + let params = request.query_string(); + let params = serde_urlencoded::from_str::<AuthorizationParameters>(params); + let Ok(params) = params else { + todo!("invalid request") + }; + + let db = db.get_ref(); + let Some(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await.unwrap() else { + todo!("client not found") + }; + + // verify scope + let Some(allowed_scopes) = db::get_client_allowed_scopes(db, client_id).await.unwrap() else { + todo!("client not found") + }; + + let scope = if let Some(scope) = ¶ms.scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + todo!("invalid request") + }; + scope + }; + + if !scopes::is_subset_of(&scope, &allowed_scopes) { + todo!("access_denied") + } + + // verify redirect uri + if let Some(redirect_uri) = ¶ms.redirect_uri { + if !db::client_has_redirect_uri(db, client_id, redirect_uri) + .await + .unwrap() + { + todo!("access denied") + } + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id).await.unwrap(); + if redirect_uris.len() != 1 { + todo!("must have redirect uri") + } + } + + // verify response type + if params.response_type == ResponseType::Unsupported { + todo!("unsupported response type") + } + // TODO find a better way of doing languages - // TODO check that the URI is valid let language = Language::from_str("en").unwrap(); let page = - templates::login_page(&tera, &query, language, translations.get_ref().clone()).unwrap(); + templates::login_page(&tera, ¶ms, language, translations.get_ref().clone()).unwrap(); HttpResponse::Ok().content_type("text/html").body(page) } #[derive(Clone, Deserialize)] #[serde(tag = "grant_type")] -enum GrantType {} +#[serde(rename_all = "snake_case")] +enum GrantType { + AuthorizationCode { + code: Box<str>, + redirect_uri: Url, + #[serde(rename = "client_id")] + client_alias: Box<str>, + }, + Password { + username: Box<str>, + password: Box<str>, + scope: Option<Box<str>>, + }, + ClientCredentials { + scope: Option<Box<str>>, + }, +} #[derive(Clone, Deserialize)] struct TokenRequest { #[serde(flatten)] grant_type: GrantType, - scope: String, // TODO lol no - // TODO support optional client credentials in here + // TODO support optional client credentials in here +} + +#[derive(Clone, Serialize)] +struct TokenResponse { + access_token: Box<str>, + token_type: Box<str>, + expires_in: i64, + refresh_token: Box<str>, + scope: Box<str>, } #[post("/token")] async fn token( db: web::Data<MySqlPool>, - req: web::Form<TokenRequest>, + req: web::Bytes, authorization: Option<web::Header<authorization::BasicAuthorization>>, ) -> HttpResponse { // TODO protect against brute force attacks - todo!() + let db = db.get_ref(); + let request = serde_json::from_slice::<TokenRequest>(&req); + let Ok(request) = request else { + todo!("invalid request") + }; + + let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value + let duration = Duration::hours(1); + let token_type = Box::from("bearer"); + let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); + + match request.grant_type { + GrantType::AuthorizationCode { + code, + redirect_uri, + client_alias, + } => { + let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { + todo!("client not found") + }; + + let Ok(claims) = jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri).await else { + todo!("invalid code"); + }; + + // verify client, if the client has credentials + if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { + let Some(authorization) = authorization else { + todo!("no client credentials") + }; + + if authorization.username() != client_alias.deref() { + todo!("bad username") + } + if !hash.check_password(authorization.password()).unwrap() { + todo!("bad password") + } + } + + let access_token = jwt::Claims::access_token( + db, + claims.id(), + self_id, + client_id, + duration, + claims.scopes(), + ) + .await + .unwrap(); + + let expires_in = access_token.expires_in(); + let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); + let scope = access_token.scopes().into(); + + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = refresh_token.to_jwt().unwrap(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + GrantType::Password { + username, + password, + scope, + } => todo!(), + GrantType::ClientCredentials { scope } => todo!(), + } } pub fn service() -> Scope { diff --git a/src/main.rs b/src/main.rs index 1106dc0..da740be 100644 --- a/src/main.rs +++ b/src/main.rs @@ -29,7 +29,7 @@ fn error_content_language<B>( async fn delete_expired_tokens(db: MySqlPool) { let db = db.clone(); - let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 10)); + let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20)); loop { interval.tick().await; if let Err(e) = db::delete_expired_auth_codes(&db).await { diff --git a/src/services/db/client.rs b/src/services/db/client.rs index c25ad0d..70701d7 100644 --- a/src/services/db/client.rs +++ b/src/services/db/client.rs @@ -21,6 +21,13 @@ pub struct ClientRow { pub default_scopes: Option<String>, } +#[derive(Clone, FromRow)] +struct HashRow { + secret_hash: Option<Vec<u8>>, + secret_salt: Option<Vec<u8>>, + secret_version: Option<u32>, +} + pub async fn client_id_exists<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -47,6 +54,19 @@ pub async fn client_alias_exists<'c>( .unexpect() } +pub async fn get_client_id_by_alias<'c>( + executor: impl Executor<'c, Database = MySql>, + alias: &str, +) -> Result<Option<Uuid>, RawUnexpected> { + query_scalar!( + "SELECT id as `id: Uuid` FROM clients WHERE alias = ?", + alias + ) + .fetch_optional(executor) + .await + .unexpect() +} + pub async fn get_client_response<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -116,6 +136,28 @@ pub async fn get_client_default_scopes<'c>( Ok(scopes.map(|s| s.map(Box::from))) } +pub async fn get_client_secret<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Option<PasswordHash>, RawUnexpected> { + let hash = query_as!( + HashRow, + r"SELECT secret_hash, secret_salt, secret_version + FROM clients WHERE id = ?", + id + ) + .fetch_optional(executor) + .await?; + + let Some(hash) = hash else { return Ok(None) }; + let Some(version) = hash.secret_version else { return Ok(None) }; + let Some(salt) = hash.secret_hash else { return Ok(None) }; + let Some(hash) = hash.secret_salt else { return Ok(None) }; + + let hash = PasswordHash::from_fields(&hash, &salt, version as u8); + Ok(Some(hash)) +} + pub async fn get_client_redirect_uris<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, @@ -136,7 +178,7 @@ pub async fn get_client_redirect_uris<'c>( pub async fn client_has_redirect_uri<'c>( executor: impl Executor<'c, Database = MySql>, id: Uuid, - url: Url, + url: &Url, ) -> Result<bool, RawUnexpected> { query_scalar!( r"SELECT EXISTS( diff --git a/src/services/jwt.rs b/src/services/jwt.rs index 7841afb..822101f 100644 --- a/src/services/jwt.rs +++ b/src/services/jwt.rs @@ -32,6 +32,7 @@ pub struct Claims { client_id: Uuid, auth_code_id: Uuid, token_type: TokenType, + redirect_uri: Option<Url>, } #[derive(Debug, Clone, Copy, sqlx::Type)] @@ -43,18 +44,19 @@ pub enum RevokedRefreshTokenReason { impl Claims { pub async fn auth_code<'c>( - db: MySqlPool, + db: &MySqlPool, self_id: Url, client_id: Uuid, scopes: &str, + redirect_uri: &Url, ) -> Result<Self, RawUnexpected> { let five_minutes = Duration::minutes(5); - let id = new_id(&db, db::auth_code_exists).await?; + let id = new_id(db, db::auth_code_exists).await?; let time = Utc::now(); let exp = time + five_minutes; - db::create_auth_code(&db, id, exp).await?; + db::create_auth_code(db, id, exp).await?; Ok(Self { iss: self_id, @@ -67,22 +69,23 @@ impl Claims { client_id, auth_code_id: id, token_type: TokenType::Authorization, + redirect_uri: Some(redirect_uri.clone()), }) } pub async fn access_token<'c>( - db: MySqlPool, + db: &MySqlPool, auth_code_id: Uuid, self_id: Url, client_id: Uuid, duration: Duration, scopes: &str, ) -> Result<Self, RawUnexpected> { - let id = new_id(&db, db::access_token_exists).await?; + let id = new_id(db, db::access_token_exists).await?; let time = Utc::now(); let exp = time + duration; - db::create_access_token(&db, id, auth_code_id, exp) + db::create_access_token(db, id, auth_code_id, exp) .await .unexpect()?; @@ -97,19 +100,23 @@ impl Claims { client_id, auth_code_id, token_type: TokenType::Access, + redirect_uri: None, }) } - pub async fn refresh_token(db: MySqlPool, other_token: Claims) -> Result<Self, RawUnexpected> { + pub async fn refresh_token( + db: &MySqlPool, + other_token: &Claims, + ) -> Result<Self, RawUnexpected> { let one_day = Duration::days(1); - let id = new_id(&db, db::refresh_token_exists).await?; + let id = new_id(db, db::refresh_token_exists).await?; let time = Utc::now(); let exp = other_token.exp + one_day; - db::create_refresh_token(&db, id, other_token.auth_code_id, exp).await?; + db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?; - let mut claims = other_token; + let mut claims = other_token.clone(); claims.exp = exp; claims.iat = Some(time); claims.jti = id; @@ -119,15 +126,15 @@ impl Claims { } pub async fn refreshed_access_token( - db: MySqlPool, + db: &MySqlPool, refresh_token: Claims, exp_time: Duration, ) -> Result<Self, RawUnexpected> { - let id = new_id(&db, db::access_token_exists).await?; + let id = new_id(db, db::access_token_exists).await?; let time = Utc::now(); let exp = time + exp_time; - db::create_access_token(&db, id, refresh_token.auth_code_id, exp).await?; + db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?; let mut claims = refresh_token; claims.exp = exp; @@ -142,6 +149,10 @@ impl Claims { self.jti } + pub fn expires_in(&self) -> i64 { + (self.exp - Utc::now()).num_seconds() + } + pub fn scopes(&self) -> &str { &self.scope } @@ -163,6 +174,8 @@ pub enum VerifyJwtError { WrongClient, #[error("The given audience parameter does not contain this issuer")] BadAudience, + #[error("The redirect URI doesn't match what's in the token")] + IncorrectRedirectUri, #[error("The token is expired")] ExpiredToken, #[error("The token cannot be used yet")] @@ -211,16 +224,23 @@ fn verify_jwt( } pub async fn verify_auth_code<'c>( - db: MySqlPool, + db: &MySqlPool, token: &str, self_id: Url, client_id: Uuid, + redirect_uri: Url, ) -> Result<Claims, Expect<VerifyJwtError>> { let claims = verify_jwt(token, self_id, client_id)?; - if db::delete_auth_code(&db, claims.jti).await? { - db::delete_access_tokens_with_auth_code(&db, claims.jti).await?; - db::revoke_refresh_tokens_with_auth_code(&db, claims.jti).await?; + if let Some(claimed_uri) = &claims.redirect_uri { + if claimed_uri.clone() != redirect_uri { + yeet!(VerifyJwtError::IncorrectRedirectUri.into()); + } + } + + if db::delete_auth_code(db, claims.jti).await? { + db::delete_access_tokens_with_auth_code(db, claims.jti).await?; + db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?; yeet!(VerifyJwtError::JwtRevoked.into()); } |
