From 608ce1d9910cd68ce825838ea313e02c598f908e Mon Sep 17 00:00:00 2001 From: Mica White Date: Mon, 8 Dec 2025 20:08:21 -0500 Subject: Stuff --- src/api/clients.rs | 966 +++++++++++++-------------- src/api/liveops.rs | 22 +- src/api/mod.rs | 26 +- src/api/oauth.rs | 1852 ++++++++++++++++++++++++++-------------------------- src/api/ops.rs | 140 ++-- src/api/users.rs | 544 +++++++-------- 6 files changed, 1775 insertions(+), 1775 deletions(-) (limited to 'src/api') diff --git a/src/api/clients.rs b/src/api/clients.rs index 3f906bb..ded8b81 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -1,483 +1,483 @@ -use actix_web::http::{header, StatusCode}; -use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use thiserror::Error; -use url::Url; -use uuid::Uuid; - -use crate::models::client::{Client, ClientType, CreateClientError}; -use crate::services::crypto::PasswordHash; -use crate::services::db::ClientRow; -use crate::services::{db, id}; - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientResponse { - client_id: Uuid, - alias: Box, - client_type: ClientType, - allowed_scopes: Box<[Box]>, - default_scopes: Option]>>, - is_trusted: bool, -} - -impl From for ClientResponse { - fn from(value: ClientRow) -> Self { - Self { - client_id: value.id, - alias: value.alias.into_boxed_str(), - client_type: value.client_type, - allowed_scopes: value - .allowed_scopes - .split_whitespace() - .map(Box::from) - .collect(), - default_scopes: value - .default_scopes - .map(|s| s.split_whitespace().map(Box::from).collect()), - is_trusted: value.is_trusted, - } - } -} - -#[derive(Debug, Clone, Copy, Error)] -#[error("No client with the given client ID was found")] -struct ClientNotFound { - id: Uuid, -} - -impl ResponseError for ClientNotFound { - fn status_code(&self) -> StatusCode { - StatusCode::NOT_FOUND - } -} - -impl ClientNotFound { - fn new(id: Uuid) -> Self { - Self { id } - } -} - -#[get("/{client_id}")] -async fn get_client( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(client) = db::get_client_response(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let redirect_uris_link = format!("; rel=\"redirect-uris\""); - let response: ClientResponse = client.into(); - let response = HttpResponse::Ok() - .append_header((header::LINK, redirect_uris_link)) - .json(response); - Ok(response) -} - -#[get("/{client_id}/alias")] -async fn get_client_alias( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(alias)) -} - -#[get("/{client_id}/type")] -async fn get_client_type( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(client_type)) -} - -#[get("/{client_id}/redirect-uris")] -async fn get_client_redirect_uris( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id)) - }; - - let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap(); - - Ok(HttpResponse::Ok().json(redirect_uris)) -} - -#[get("/{client_id}/allowed-scopes")] -async fn get_client_allowed_scopes( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let allowed_scopes = allowed_scopes.split_whitespace().collect::>(); - - Ok(HttpResponse::Ok().json(allowed_scopes)) -} - -#[get("/{client_id}/default-scopes")] -async fn get_client_default_scopes( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let default_scopes = default_scopes.map(|scopes| { - scopes - .split_whitespace() - .map(Box::from) - .collect::]>>() - }); - - Ok(HttpResponse::Ok().json(default_scopes)) -} - -#[get("/{client_id}/is-trusted")] -async fn get_client_is_trusted( - client_id: web::Path, - db: web::Data, -) -> Result { - let db = db.as_ref(); - let id = *client_id; - - let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(is_trusted)) -} - -#[derive(Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct ClientRequest { - alias: Box, - ty: ClientType, - redirect_uris: Box<[Url]>, - secret: Option>, - allowed_scopes: Box<[Box]>, - default_scopes: Option]>>, - trusted: bool, -} - -#[derive(Debug, Clone, Error)] -#[error("The given client alias is already taken")] -struct AliasTakenError { - alias: Box, -} - -impl ResponseError for AliasTakenError { - fn status_code(&self) -> StatusCode { - StatusCode::CONFLICT - } -} - -impl AliasTakenError { - fn new(alias: &str) -> Self { - Self { - alias: Box::from(alias), - } - } -} - -#[post("")] -async fn create_client( - body: web::Json, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let alias = &body.alias; - - if db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - let id = id::new_id(db, db::client_id_exists).await.unwrap(); - let client = Client::new( - id, - &alias, - body.ty, - body.secret.as_deref(), - body.allowed_scopes.clone(), - body.default_scopes.clone(), - &body.redirect_uris, - body.trusted, - ) - .map_err(|e| e.unwrap())?; - - let transaction = db.begin().await.unwrap(); - db::create_client(transaction, &client).await.unwrap(); - - let response = HttpResponse::Created() - .insert_header((header::LOCATION, format!("clients/{id}"))) - .finish(); - Ok(response) -} - -#[derive(Debug, Clone, Error)] -enum UpdateClientError { - #[error(transparent)] - NotFound(#[from] ClientNotFound), - #[error(transparent)] - ClientError(#[from] CreateClientError), - #[error(transparent)] - AliasTaken(#[from] AliasTakenError), -} - -impl ResponseError for UpdateClientError { - fn status_code(&self) -> StatusCode { - match self { - Self::NotFound(e) => e.status_code(), - Self::ClientError(e) => e.status_code(), - Self::AliasTaken(e) => e.status_code(), - } - } -} - -#[put("/{id}")] -async fn update_client( - id: web::Path, - body: web::Json, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let alias = &body.alias; - - let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - let client = Client::new( - id, - &alias, - body.ty, - body.secret.as_deref(), - body.allowed_scopes.clone(), - body.default_scopes.clone(), - &body.redirect_uris, - body.trusted, - ) - .map_err(|e| e.unwrap())?; - - let transaction = db.begin().await.unwrap(); - db::update_client(transaction, &client).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{id}/alias")] -async fn update_client_alias( - id: web::Path, - body: web::Json>, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let alias = body.0; - - let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - if old_alias == alias { - return Ok(HttpResponse::NoContent().finish()); - } - if db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - db::update_client_alias(db, id, &alias).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{id}/type")] -async fn update_client_type( - id: web::Path, - body: web::Json, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let ty = body.0; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_type(db, id, ty).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/allowed-scopes")] -async fn update_client_allowed_scopes( - id: web::Path, - body: web::Json]>>, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let allowed_scopes = body.0.join(" "); - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_allowed_scopes(db, id, &allowed_scopes) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/default-scopes")] -async fn update_client_default_scopes( - id: web::Path, - body: web::Json]>>>, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let default_scopes = body.0.map(|s| s.join(" ")); - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_default_scopes(db, id, default_scopes) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/is-trusted")] -async fn update_client_is_trusted( - id: web::Path, - body: web::Json, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - let is_trusted = *body; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_trusted(db, id, is_trusted).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/redirect-uris")] -async fn update_client_redirect_uris( - id: web::Path, - body: web::Json>, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - - for uri in body.0.iter() { - if uri.scheme() != "https" { - yeet!(CreateClientError::NonHttpsUri.into()); - } - - if uri.fragment().is_some() { - yeet!(CreateClientError::UriFragment.into()) - } - } - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - let transaction = db.begin().await.unwrap(); - db::update_client_redirect_uris(transaction, id, &body.0) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("{id}/secret")] -async fn update_client_secret( - id: web::Path, - body: web::Json>>, - db: web::Data, -) -> Result { - let db = db.get_ref(); - let id = *id; - - let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - - if client_type == ClientType::Confidential && body.is_none() { - yeet!(CreateClientError::NoSecret.into()) - } - - let secret = body.0.map(|s| PasswordHash::new(&s).unwrap()); - db::update_client_secret(db, id, secret).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -pub fn service() -> Scope { - web::scope("/clients") - .service(get_client) - .service(get_client_alias) - .service(get_client_type) - .service(get_client_allowed_scopes) - .service(get_client_default_scopes) - .service(get_client_redirect_uris) - .service(get_client_is_trusted) - .service(create_client) - .service(update_client) - .service(update_client_alias) - .service(update_client_type) - .service(update_client_allowed_scopes) - .service(update_client_default_scopes) - .service(update_client_redirect_uris) - .service(update_client_secret) -} +use actix_web::http::{header, StatusCode}; +use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; +use raise::yeet; +use serde::{Deserialize, Serialize}; +use sqlx::MySqlPool; +use thiserror::Error; +use url::Url; +use uuid::Uuid; + +use crate::models::client::{Client, ClientType, CreateClientError}; +use crate::services::crypto::PasswordHash; +use crate::services::db::ClientRow; +use crate::services::{db, id}; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct ClientResponse { + client_id: Uuid, + alias: Box, + client_type: ClientType, + allowed_scopes: Box<[Box]>, + default_scopes: Option]>>, + is_trusted: bool, +} + +impl From for ClientResponse { + fn from(value: ClientRow) -> Self { + Self { + client_id: value.id, + alias: value.alias.into_boxed_str(), + client_type: value.client_type, + allowed_scopes: value + .allowed_scopes + .split_whitespace() + .map(Box::from) + .collect(), + default_scopes: value + .default_scopes + .map(|s| s.split_whitespace().map(Box::from).collect()), + is_trusted: value.is_trusted, + } + } +} + +#[derive(Debug, Clone, Copy, Error)] +#[error("No client with the given client ID was found")] +struct ClientNotFound { + id: Uuid, +} + +impl ResponseError for ClientNotFound { + fn status_code(&self) -> StatusCode { + StatusCode::NOT_FOUND + } +} + +impl ClientNotFound { + fn new(id: Uuid) -> Self { + Self { id } + } +} + +#[get("/{client_id}")] +async fn get_client( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(client) = db::get_client_response(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + let redirect_uris_link = format!("; rel=\"redirect-uris\""); + let response: ClientResponse = client.into(); + let response = HttpResponse::Ok() + .append_header((header::LINK, redirect_uris_link)) + .json(response); + Ok(response) +} + +#[get("/{client_id}/alias")] +async fn get_client_alias( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + Ok(HttpResponse::Ok().json(alias)) +} + +#[get("/{client_id}/type")] +async fn get_client_type( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + Ok(HttpResponse::Ok().json(client_type)) +} + +#[get("/{client_id}/redirect-uris")] +async fn get_client_redirect_uris( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id)) + }; + + let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap(); + + Ok(HttpResponse::Ok().json(redirect_uris)) +} + +#[get("/{client_id}/allowed-scopes")] +async fn get_client_allowed_scopes( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + let allowed_scopes = allowed_scopes.split_whitespace().collect::>(); + + Ok(HttpResponse::Ok().json(allowed_scopes)) +} + +#[get("/{client_id}/default-scopes")] +async fn get_client_default_scopes( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + let default_scopes = default_scopes.map(|scopes| { + scopes + .split_whitespace() + .map(Box::from) + .collect::]>>() + }); + + Ok(HttpResponse::Ok().json(default_scopes)) +} + +#[get("/{client_id}/is-trusted")] +async fn get_client_is_trusted( + client_id: web::Path, + db: web::Data, +) -> Result { + let db = db.as_ref(); + let id = *client_id; + + let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + Ok(HttpResponse::Ok().json(is_trusted)) +} + +#[derive(Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ClientRequest { + alias: Box, + ty: ClientType, + redirect_uris: Box<[Url]>, + secret: Option>, + allowed_scopes: Box<[Box]>, + default_scopes: Option]>>, + trusted: bool, +} + +#[derive(Debug, Clone, Error)] +#[error("The given client alias is already taken")] +struct AliasTakenError { + alias: Box, +} + +impl ResponseError for AliasTakenError { + fn status_code(&self) -> StatusCode { + StatusCode::CONFLICT + } +} + +impl AliasTakenError { + fn new(alias: &str) -> Self { + Self { + alias: Box::from(alias), + } + } +} + +#[post("")] +async fn create_client( + body: web::Json, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let alias = &body.alias; + + if db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + let id = id::new_id(db, db::client_id_exists).await.unwrap(); + let client = Client::new( + id, + &alias, + body.ty, + body.secret.as_deref(), + body.allowed_scopes.clone(), + body.default_scopes.clone(), + &body.redirect_uris, + body.trusted, + ) + .map_err(|e| e.unwrap())?; + + let transaction = db.begin().await.unwrap(); + db::create_client(transaction, &client).await.unwrap(); + + let response = HttpResponse::Created() + .insert_header((header::LOCATION, format!("clients/{id}"))) + .finish(); + Ok(response) +} + +#[derive(Debug, Clone, Error)] +enum UpdateClientError { + #[error(transparent)] + NotFound(#[from] ClientNotFound), + #[error(transparent)] + ClientError(#[from] CreateClientError), + #[error(transparent)] + AliasTaken(#[from] AliasTakenError), +} + +impl ResponseError for UpdateClientError { + fn status_code(&self) -> StatusCode { + match self { + Self::NotFound(e) => e.status_code(), + Self::ClientError(e) => e.status_code(), + Self::AliasTaken(e) => e.status_code(), + } + } +} + +#[put("/{id}")] +async fn update_client( + id: web::Path, + body: web::Json, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let alias = &body.alias; + + let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + let client = Client::new( + id, + &alias, + body.ty, + body.secret.as_deref(), + body.allowed_scopes.clone(), + body.default_scopes.clone(), + &body.redirect_uris, + body.trusted, + ) + .map_err(|e| e.unwrap())?; + + let transaction = db.begin().await.unwrap(); + db::update_client(transaction, &client).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{id}/alias")] +async fn update_client_alias( + id: web::Path, + body: web::Json>, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let alias = body.0; + + let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + if old_alias == alias { + return Ok(HttpResponse::NoContent().finish()); + } + if db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + db::update_client_alias(db, id, &alias).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{id}/type")] +async fn update_client_type( + id: web::Path, + body: web::Json, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let ty = body.0; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + db::update_client_type(db, id, ty).await.unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("/{id}/allowed-scopes")] +async fn update_client_allowed_scopes( + id: web::Path, + body: web::Json]>>, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let allowed_scopes = body.0.join(" "); + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + db::update_client_allowed_scopes(db, id, &allowed_scopes) + .await + .unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("/{id}/default-scopes")] +async fn update_client_default_scopes( + id: web::Path, + body: web::Json]>>>, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let default_scopes = body.0.map(|s| s.join(" ")); + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + db::update_client_default_scopes(db, id, default_scopes) + .await + .unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("/{id}/is-trusted")] +async fn update_client_is_trusted( + id: web::Path, + body: web::Json, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + let is_trusted = *body; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + db::update_client_trusted(db, id, is_trusted).await.unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("/{id}/redirect-uris")] +async fn update_client_redirect_uris( + id: web::Path, + body: web::Json>, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + + for uri in body.0.iter() { + if uri.scheme() != "https" { + yeet!(CreateClientError::NonHttpsUri.into()); + } + + if uri.fragment().is_some() { + yeet!(CreateClientError::UriFragment.into()) + } + } + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + let transaction = db.begin().await.unwrap(); + db::update_client_redirect_uris(transaction, id, &body.0) + .await + .unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("{id}/secret")] +async fn update_client_secret( + id: web::Path, + body: web::Json>>, + db: web::Data, +) -> Result { + let db = db.get_ref(); + let id = *id; + + let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + + if client_type == ClientType::Confidential && body.is_none() { + yeet!(CreateClientError::NoSecret.into()) + } + + let secret = body.0.map(|s| PasswordHash::new(&s).unwrap()); + db::update_client_secret(db, id, secret).await.unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +pub fn service() -> Scope { + web::scope("/clients") + .service(get_client) + .service(get_client_alias) + .service(get_client_type) + .service(get_client_allowed_scopes) + .service(get_client_default_scopes) + .service(get_client_redirect_uris) + .service(get_client_is_trusted) + .service(create_client) + .service(update_client) + .service(update_client_alias) + .service(update_client_type) + .service(update_client_allowed_scopes) + .service(update_client_default_scopes) + .service(update_client_redirect_uris) + .service(update_client_secret) +} diff --git a/src/api/liveops.rs b/src/api/liveops.rs index d4bf129..2caf6e3 100644 --- a/src/api/liveops.rs +++ b/src/api/liveops.rs @@ -1,11 +1,11 @@ -use actix_web::{get, web, HttpResponse, Scope}; - -/// Simple ping -#[get("/ping")] -async fn ping() -> HttpResponse { - HttpResponse::Ok().finish() -} - -pub fn service() -> Scope { - web::scope("/liveops").service(ping) -} +use actix_web::{get, web, HttpResponse, Scope}; + +/// Simple ping +#[get("/ping")] +async fn ping() -> HttpResponse { + HttpResponse::Ok().finish() +} + +pub fn service() -> Scope { + web::scope("/liveops").service(ping) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 0ab4037..9059e71 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,13 +1,13 @@ -mod clients; -mod liveops; -mod oauth; -mod ops; -mod users; - -pub use clients::service as clients; -pub use liveops::service as liveops; -pub use oauth::service as oauth; -pub use ops::service as ops; -pub use users::service as users; - -pub use oauth::AuthorizationParameters; +mod clients; +mod liveops; +mod oauth; +mod ops; +mod users; + +pub use clients::service as clients; +pub use liveops::service as liveops; +pub use oauth::service as oauth; +pub use ops::service as ops; +pub use users::service as users; + +pub use oauth::AuthorizationParameters; diff --git a/src/api/oauth.rs b/src/api/oauth.rs index f1aa012..3422d2f 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -1,926 +1,926 @@ -use std::ops::Deref; -use std::str::FromStr; - -use actix_web::http::{header, StatusCode}; -use actix_web::{ - get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope, -}; -use chrono::Duration; -use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use tera::Tera; -use thiserror::Error; -use unic_langid::subtags::Language; -use url::Url; -use uuid::Uuid; - -use crate::models::client::ClientType; -use crate::resources::{languages, templates}; -use crate::scopes; -use crate::services::jwt::VerifyJwtError; -use crate::services::{authorization, config, db, jwt}; - -const REALLY_BAD_ERROR_PAGE: &str = "Internal Server ErrorInternal Server Error"; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ResponseType { - Code, - Token, - #[serde(other)] - Unsupported, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthorizationParameters { - response_type: ResponseType, - client_id: Box, - redirect_uri: Option, - scope: Option>, - state: Option>, -} - -#[derive(Clone, Deserialize)] -struct AuthorizeCredentials { - username: Box, - password: Box, -} - -#[derive(Clone, Serialize)] -struct AuthCodeResponse { - code: Box, - state: Option>, -} - -#[derive(Clone, Serialize)] -struct AuthTokenResponse { - access_token: Box, - token_type: &'static str, - expires_in: i64, - scope: Box, - state: Option>, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] -#[serde(rename_all = "camelCase")] -enum AuthorizeErrorType { - InvalidRequest, - UnauthorizedClient, - AccessDenied, - UnsupportedResponseType, - InvalidScope, - ServerError, - TemporarilyUnavailable, -} - -#[derive(Debug, Clone, Error, Serialize)] -#[error("{error_description}")] -struct AuthorizeError { - error: AuthorizeErrorType, - error_description: Box, - // TODO error uri - state: Option>, - #[serde(skip)] - redirect_uri: Url, -} - -impl AuthorizeError { - fn no_scope(redirect_uri: Url, state: Option>) -> Self { - Self { - error: AuthorizeErrorType::InvalidScope, - error_description: Box::from( - "No scope was provided, and the client does not have a default scope", - ), - state, - redirect_uri, - } - } - - fn unsupported_response_type(redirect_uri: Url, state: Option>) -> Self { - Self { - error: AuthorizeErrorType::UnsupportedResponseType, - error_description: Box::from("The given response type is not supported"), - state, - redirect_uri, - } - } - - fn invalid_scope(redirect_uri: Url, state: Option>) -> Self { - Self { - error: AuthorizeErrorType::InvalidScope, - error_description: Box::from("The given scope exceeds what the client is allowed"), - state, - redirect_uri, - } - } - - fn internal_server_error(redirect_uri: Url, state: Option>) -> Self { - Self { - error: AuthorizeErrorType::ServerError, - error_description: "An unexpected error occurred".into(), - state, - redirect_uri, - } - } -} - -impl ResponseError for AuthorizeError { - fn error_response(&self) -> HttpResponse { - let query = Some(serde_urlencoded::to_string(self).unwrap()); - let query = query.as_deref(); - let mut url = self.redirect_uri.clone(); - url.set_query(query); - - HttpResponse::Found() - .insert_header((header::LOCATION, url.as_str())) - .finish() - } -} - -fn error_page( - tera: &Tera, - translations: &languages::Translations, - error: templates::ErrorPage, -) -> Result { - // TODO find a better way of doing languages - let language = Language::from_str("en").unwrap(); - let translations = translations.clone(); - let page = templates::error_page(&tera, language, translations, error)?; - Ok(page) -} - -async fn get_redirect_uri( - redirect_uri: &Option, - db: &MySqlPool, - client_id: Uuid, -) -> Result> { - if let Some(uri) = &redirect_uri { - let redirect_uri = uri.clone(); - if !db::client_has_redirect_uri(db, client_id, &redirect_uri) - .await - .map_err(|e| UnexpectedError::from(e)) - .unexpect()? - { - yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri)); - } - - Ok(redirect_uri) - } else { - let redirect_uris = db::get_client_redirect_uris(db, client_id) - .await - .map_err(|e| UnexpectedError::from(e)) - .unexpect()?; - if redirect_uris.len() != 1 { - yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri)); - } - - Ok(redirect_uris.get(0).unwrap().clone()) - } -} - -async fn get_scope( - scope: &Option>, - db: &MySqlPool, - client_id: Uuid, - redirect_uri: &Url, - state: &Option>, -) -> Result, Expect> { - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into()) - }; - scope - }; - - // verify scope is valid - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - if !scopes::is_subset_of(&scope, &allowed_scopes) { - yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into()); - } - - Ok(scope) -} - -async fn authenticate_user( - db: &MySqlPool, - username: &str, - password: &str, -) -> Result, RawUnexpected> { - let Some(user) = db::get_user_by_username(db, username).await? else { - return Ok(None); - }; - - if user.check_password(password)? { - Ok(Some(user.id)) - } else { - Ok(None) - } -} - -#[post("/authorize")] -async fn authorize( - db: web::Data, - req: web::Query, - credentials: web::Json, - tera: web::Data, - translations: web::Data, -) -> Result { - // TODO protect against brute force attacks - let db = db.get_ref(); - let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else { - let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); - }; - let Some(client_id) = client_id else { - let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::NotFound().content_type("text/html").body(page)); - }; - let Ok(config) = config::get_config() else { - let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); - }; - - let self_id = config.url; - let state = req.state.clone(); - - // get redirect uri - let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await { - Ok(uri) => uri, - Err(e) => { - let e = e - .expected() - .unwrap_or(templates::ErrorPage::InternalServerError); - let page = error_page(&tera, &translations, e) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - } - }; - - // authenticate user - let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password) - .await - .unwrap() else - { - let language = Language::from_str("en").unwrap(); - let translations = translations.get_ref().clone(); - let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::Ok().content_type("text/html").body(page)); - }; - - let internal_server_error = - AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); - - // get scope - let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await { - Ok(scope) => scope, - Err(e) => { - let e = e.expected().unwrap_or(internal_server_error); - return Err(e); - } - }; - - match req.response_type { - ResponseType::Code => { - // create auth code - let code = - jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri) - .await - .map_err(|_| internal_server_error.clone())?; - let code = code.to_jwt().map_err(|_| internal_server_error.clone())?; - - let response = AuthCodeResponse { code, state }; - let query = - Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?); - let query = query.as_deref(); - redirect_uri.set_query(query); - - Ok(HttpResponse::Found() - .append_header((header::LOCATION, redirect_uri.as_str())) - .finish()) - } - ResponseType::Token => { - // create access token - let duration = Duration::hours(1); - let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) - .await - .map_err(|_| internal_server_error.clone())?; - - let access_token = access_token - .to_jwt() - .map_err(|_| internal_server_error.clone())?; - let expires_in = duration.num_seconds(); - let token_type = "bearer"; - let response = AuthTokenResponse { - access_token, - expires_in, - token_type, - scope, - state, - }; - - let fragment = Some( - serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?, - ); - let fragment = fragment.as_deref(); - redirect_uri.set_fragment(fragment); - - Ok(HttpResponse::Found() - .append_header((header::LOCATION, redirect_uri.as_str())) - .finish()) - } - _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)), - } -} - -#[get("/authorize")] -async fn authorize_page( - db: web::Data, - tera: web::Data, - translations: web::Data, - request: HttpRequest, -) -> Result { - let Ok(language) = Language::from_str("en") else { - let page = String::from(REALLY_BAD_ERROR_PAGE); - return Ok(HttpResponse::InternalServerError() - .content_type("text/html") - .body(page)); - }; - let translations = translations.get_ref().clone(); - - let params = request.query_string(); - let params = serde_urlencoded::from_str::(params); - let Ok(params) = params else { - let page = error_page( - &tera, - &translations, - templates::ErrorPage::InvalidRequest, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - }; - - let db = db.get_ref(); - let Ok(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await else { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::InternalServerError, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError() - .content_type("text/html") - .body(page)); - }; - let Some(client_id) = client_id else { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::ClientNotFound, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::NotFound() - .content_type("text/html") - .body(page)); - }; - - // verify redirect uri - let redirect_uri = match get_redirect_uri(¶ms.redirect_uri, db, client_id).await { - Ok(uri) => uri, - Err(e) => { - let e = e - .expected() - .unwrap_or(templates::ErrorPage::InternalServerError); - let page = error_page(&tera, &translations, e) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - } - }; - - let state = ¶ms.state; - let internal_server_error = - AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); - - // verify scope - let _ = match get_scope(¶ms.scope, db, client_id, &redirect_uri, ¶ms.state).await { - Ok(scope) => scope, - Err(e) => { - let e = e.expected().unwrap_or(internal_server_error); - return Err(e); - } - }; - - // verify response type - if params.response_type == ResponseType::Unsupported { - return Err(AuthorizeError::unsupported_response_type( - redirect_uri, - params.state, - )); - } - - // TODO find a better way of doing languages - let language = Language::from_str("en").unwrap(); - let page = templates::login_page(&tera, ¶ms, language, translations).unwrap(); - Ok(HttpResponse::Ok().content_type("text/html").body(page)) -} - -#[derive(Clone, Deserialize)] -#[serde(tag = "grant_type")] -#[serde(rename_all = "snake_case")] -enum GrantType { - AuthorizationCode { - code: Box, - redirect_uri: Url, - #[serde(rename = "client_id")] - client_alias: Box, - }, - Password { - username: Box, - password: Box, - scope: Option>, - }, - ClientCredentials { - scope: Option>, - }, - RefreshToken { - refresh_token: Box, - scope: Option>, - }, - #[serde(other)] - Unsupported, -} - -#[derive(Clone, Deserialize)] -struct TokenRequest { - #[serde(flatten)] - grant_type: GrantType, - // TODO support optional client credentials in here -} - -#[derive(Clone, Serialize)] -struct TokenResponse { - access_token: Box, - token_type: Box, - expires_in: i64, - refresh_token: Option>, - scope: Box, -} - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "snake_case")] -enum TokenErrorType { - InvalidRequest, - InvalidClient, - InvalidGrant, - UnauthorizedClient, - UnsupportedGrantType, - InvalidScope, -} - -#[derive(Debug, Clone, Error, Serialize)] -#[error("{error_description}")] -struct TokenError { - #[serde(skip)] - status_code: StatusCode, - error: TokenErrorType, - error_description: Box, - // TODO error uri -} - -impl TokenError { - fn invalid_request() -> Self { - // TODO make this description better, and all the other ones while you're at it - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidRequest, - error_description: "Invalid request".into(), - } - } - - fn unsupported_grant_type() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::UnsupportedGrantType, - error_description: "The given grant type is not supported".into(), - } - } - - fn bad_auth_code(error: VerifyJwtError) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidGrant, - error_description: error.to_string().into_boxed_str(), - } - } - - fn no_authorization() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: Box::from( - "Client credentials must be provided in the HTTP Authorization header", - ), - } - } - - fn client_not_found(alias: &str) -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: format!("No client with the client id: {alias} was found") - .into_boxed_str(), - } - } - - fn mismatch_client_id() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"), - } - } - - fn incorrect_client_secret() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: "The client secret is incorrect".into(), - } - } - - fn client_not_confidential(alias: &str) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::UnauthorizedClient, - error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.") - .into_boxed_str(), - } - } - - fn no_scope() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidScope, - error_description: Box::from( - "No scope was provided, and the client doesn't have a default scope", - ), - } - } - - fn excessive_scope() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidScope, - error_description: Box::from( - "The given scope exceeds what the client is allowed to have", - ), - } - } - - fn bad_refresh_token(err: VerifyJwtError) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidGrant, - error_description: err.to_string().into_boxed_str(), - } - } - - fn untrusted_client() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: "Only trusted clients may use this grant".into(), - } - } - - fn incorrect_user_credentials() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidRequest, - error_description: "The given credentials are incorrect".into(), - } - } -} - -impl ResponseError for TokenError { - fn error_response(&self) -> HttpResponse { - let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); - - let mut builder = HttpResponseBuilder::new(self.status_code); - - if self.status_code.as_u16() == 401 { - builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\"")); - } - - builder - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(self.clone()) - } -} - -#[post("/token")] -async fn token( - db: web::Data, - req: web::Bytes, - authorization: Option>, -) -> HttpResponse { - // TODO protect against brute force attacks - let db = db.get_ref(); - let request = serde_json::from_slice::(&req); - let Ok(request) = request else { - return TokenError::invalid_request().error_response(); - }; - let config = config::get_config().unwrap(); - - let self_id = config.url; - let duration = Duration::hours(1); - let token_type = Box::from("bearer"); - let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); - - match request.grant_type { - GrantType::AuthorizationCode { - code, - redirect_uri, - client_alias, - } => { - let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { - return TokenError::client_not_found(&client_alias).error_response(); - }; - - // validate auth code - let claims = - match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await { - Ok(claims) => claims, - Err(err) => { - let err = err.unwrap(); - return TokenError::bad_auth_code(err).error_response(); - } - }; - - // verify client, if the client has credentials - if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - - if authorization.username() != client_alias.deref() { - return TokenError::mismatch_client_id().error_response(); - } - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - } - - let access_token = jwt::Claims::access_token( - db, - Some(claims.id()), - self_id, - client_id, - claims.subject(), - duration, - claims.scopes(), - ) - .await - .unwrap(); - - let expires_in = access_token.expires_in(); - let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); - let scope = access_token.scopes().into(); - - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::Password { - username, - password, - scope, - } => { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - let client_alias = authorization.username(); - let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - - let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap(); - if !trusted { - return TokenError::untrusted_client().error_response(); - } - - // verify client - let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - - // authenticate user - let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else { - return TokenError::incorrect_user_credentials().error_response(); - }; - - // verify scope - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return TokenError::no_scope().error_response(); - }; - scope - }; - if !scopes::is_subset_of(&scope, &allowed_scopes) { - return TokenError::excessive_scope().error_response(); - } - - let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) - .await - .unwrap(); - let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); - - let expires_in = access_token.expires_in(); - let scope = access_token.scopes().into(); - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::ClientCredentials { scope } => { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - let client_alias = authorization.username(); - let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - - let ty = db::get_client_type(db, client_id).await.unwrap().unwrap(); - if ty != ClientType::Confidential { - return TokenError::client_not_confidential(client_alias).error_response(); - } - - // verify client - let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - - // verify scope - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return TokenError::no_scope().error_response(); - }; - scope - }; - if !scopes::is_subset_of(&scope, &allowed_scopes) { - return TokenError::excessive_scope().error_response(); - } - - let access_token = jwt::Claims::access_token( - db, None, self_id, client_id, client_id, duration, &scope, - ) - .await - .unwrap(); - - let expires_in = access_token.expires_in(); - let scope = access_token.scopes().into(); - let access_token = access_token.to_jwt().unwrap(); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token: None, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::RefreshToken { - refresh_token, - scope, - } => { - let client_id: Option; - if let Some(authorization) = authorization { - let client_alias = authorization.username(); - let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - client_id = Some(id); - } else { - client_id = None; - } - - let claims = - match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await { - Ok(claims) => claims, - Err(e) => { - let e = e.unwrap(); - return TokenError::bad_refresh_token(e).error_response(); - } - }; - - let scope = if let Some(scope) = scope { - if !scopes::is_subset_of(&scope, claims.scopes()) { - return TokenError::excessive_scope().error_response(); - } - - scope - } else { - claims.scopes().into() - }; - - let exp_time = Duration::hours(1); - let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time) - .await - .unwrap(); - let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap(); - - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - let expires_in = exp_time.num_seconds(); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - _ => TokenError::unsupported_grant_type().error_response(), - } -} - -pub fn service() -> Scope { - web::scope("/oauth") - .service(authorize_page) - .service(authorize) - .service(token) -} +use std::ops::Deref; +use std::str::FromStr; + +use actix_web::http::{header, StatusCode}; +use actix_web::{ + get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope, +}; +use chrono::Duration; +use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError}; +use raise::yeet; +use serde::{Deserialize, Serialize}; +use sqlx::MySqlPool; +use tera::Tera; +use thiserror::Error; +use unic_langid::subtags::Language; +use url::Url; +use uuid::Uuid; + +use crate::models::client::ClientType; +use crate::resources::{languages, templates}; +use crate::scopes; +use crate::services::jwt::VerifyJwtError; +use crate::services::{authorization, config, db, jwt}; + +const REALLY_BAD_ERROR_PAGE: &str = "Internal Server ErrorInternal Server Error"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ResponseType { + Code, + Token, + #[serde(other)] + Unsupported, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthorizationParameters { + response_type: ResponseType, + client_id: Box, + redirect_uri: Option, + scope: Option>, + state: Option>, +} + +#[derive(Clone, Deserialize)] +struct AuthorizeCredentials { + username: Box, + password: Box, +} + +#[derive(Clone, Serialize)] +struct AuthCodeResponse { + code: Box, + state: Option>, +} + +#[derive(Clone, Serialize)] +struct AuthTokenResponse { + access_token: Box, + token_type: &'static str, + expires_in: i64, + scope: Box, + state: Option>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] +#[serde(rename_all = "camelCase")] +enum AuthorizeErrorType { + InvalidRequest, + UnauthorizedClient, + AccessDenied, + UnsupportedResponseType, + InvalidScope, + ServerError, + TemporarilyUnavailable, +} + +#[derive(Debug, Clone, Error, Serialize)] +#[error("{error_description}")] +struct AuthorizeError { + error: AuthorizeErrorType, + error_description: Box, + // TODO error uri + state: Option>, + #[serde(skip)] + redirect_uri: Url, +} + +impl AuthorizeError { + fn no_scope(redirect_uri: Url, state: Option>) -> Self { + Self { + error: AuthorizeErrorType::InvalidScope, + error_description: Box::from( + "No scope was provided, and the client does not have a default scope", + ), + state, + redirect_uri, + } + } + + fn unsupported_response_type(redirect_uri: Url, state: Option>) -> Self { + Self { + error: AuthorizeErrorType::UnsupportedResponseType, + error_description: Box::from("The given response type is not supported"), + state, + redirect_uri, + } + } + + fn invalid_scope(redirect_uri: Url, state: Option>) -> Self { + Self { + error: AuthorizeErrorType::InvalidScope, + error_description: Box::from("The given scope exceeds what the client is allowed"), + state, + redirect_uri, + } + } + + fn internal_server_error(redirect_uri: Url, state: Option>) -> Self { + Self { + error: AuthorizeErrorType::ServerError, + error_description: "An unexpected error occurred".into(), + state, + redirect_uri, + } + } +} + +impl ResponseError for AuthorizeError { + fn error_response(&self) -> HttpResponse { + let query = Some(serde_urlencoded::to_string(self).unwrap()); + let query = query.as_deref(); + let mut url = self.redirect_uri.clone(); + url.set_query(query); + + HttpResponse::Found() + .insert_header((header::LOCATION, url.as_str())) + .finish() + } +} + +fn error_page( + tera: &Tera, + translations: &languages::Translations, + error: templates::ErrorPage, +) -> Result { + // TODO find a better way of doing languages + let language = Language::from_str("en").unwrap(); + let translations = translations.clone(); + let page = templates::error_page(&tera, language, translations, error)?; + Ok(page) +} + +async fn get_redirect_uri( + redirect_uri: &Option, + db: &MySqlPool, + client_id: Uuid, +) -> Result> { + if let Some(uri) = &redirect_uri { + let redirect_uri = uri.clone(); + if !db::client_has_redirect_uri(db, client_id, &redirect_uri) + .await + .map_err(|e| UnexpectedError::from(e)) + .unexpect()? + { + yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri)); + } + + Ok(redirect_uri) + } else { + let redirect_uris = db::get_client_redirect_uris(db, client_id) + .await + .map_err(|e| UnexpectedError::from(e)) + .unexpect()?; + if redirect_uris.len() != 1 { + yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri)); + } + + Ok(redirect_uris.get(0).unwrap().clone()) + } +} + +async fn get_scope( + scope: &Option>, + db: &MySqlPool, + client_id: Uuid, + redirect_uri: &Url, + state: &Option>, +) -> Result, Expect> { + let scope = if let Some(scope) = &scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into()) + }; + scope + }; + + // verify scope is valid + let allowed_scopes = db::get_client_allowed_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + if !scopes::is_subset_of(&scope, &allowed_scopes) { + yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into()); + } + + Ok(scope) +} + +async fn authenticate_user( + db: &MySqlPool, + username: &str, + password: &str, +) -> Result, RawUnexpected> { + let Some(user) = db::get_user_by_username(db, username).await? else { + return Ok(None); + }; + + if user.check_password(password)? { + Ok(Some(user.id)) + } else { + Ok(None) + } +} + +#[post("/authorize")] +async fn authorize( + db: web::Data, + req: web::Query, + credentials: web::Json, + tera: web::Data, + translations: web::Data, +) -> Result { + // TODO protect against brute force attacks + let db = db.get_ref(); + let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else { + let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); + }; + let Some(client_id) = client_id else { + let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::NotFound().content_type("text/html").body(page)); + }; + let Ok(config) = config::get_config() else { + let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); + }; + + let self_id = config.url; + let state = req.state.clone(); + + // get redirect uri + let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await { + Ok(uri) => uri, + Err(e) => { + let e = e + .expected() + .unwrap_or(templates::ErrorPage::InternalServerError); + let page = error_page(&tera, &translations, e) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::BadRequest() + .content_type("text/html") + .body(page)); + } + }; + + // authenticate user + let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password) + .await + .unwrap() else + { + let language = Language::from_str("en").unwrap(); + let translations = translations.get_ref().clone(); + let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::Ok().content_type("text/html").body(page)); + }; + + let internal_server_error = + AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); + + // get scope + let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await { + Ok(scope) => scope, + Err(e) => { + let e = e.expected().unwrap_or(internal_server_error); + return Err(e); + } + }; + + match req.response_type { + ResponseType::Code => { + // create auth code + let code = + jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri) + .await + .map_err(|_| internal_server_error.clone())?; + let code = code.to_jwt().map_err(|_| internal_server_error.clone())?; + + let response = AuthCodeResponse { code, state }; + let query = + Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?); + let query = query.as_deref(); + redirect_uri.set_query(query); + + Ok(HttpResponse::Found() + .append_header((header::LOCATION, redirect_uri.as_str())) + .finish()) + } + ResponseType::Token => { + // create access token + let duration = Duration::hours(1); + let access_token = + jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) + .await + .map_err(|_| internal_server_error.clone())?; + + let access_token = access_token + .to_jwt() + .map_err(|_| internal_server_error.clone())?; + let expires_in = duration.num_seconds(); + let token_type = "bearer"; + let response = AuthTokenResponse { + access_token, + expires_in, + token_type, + scope, + state, + }; + + let fragment = Some( + serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?, + ); + let fragment = fragment.as_deref(); + redirect_uri.set_fragment(fragment); + + Ok(HttpResponse::Found() + .append_header((header::LOCATION, redirect_uri.as_str())) + .finish()) + } + _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)), + } +} + +#[get("/authorize")] +async fn authorize_page( + db: web::Data, + tera: web::Data, + translations: web::Data, + request: HttpRequest, +) -> Result { + let Ok(language) = Language::from_str("en") else { + let page = String::from(REALLY_BAD_ERROR_PAGE); + return Ok(HttpResponse::InternalServerError() + .content_type("text/html") + .body(page)); + }; + let translations = translations.get_ref().clone(); + + let params = request.query_string(); + let params = serde_urlencoded::from_str::(params); + let Ok(params) = params else { + let page = error_page( + &tera, + &translations, + templates::ErrorPage::InvalidRequest, + ) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::BadRequest() + .content_type("text/html") + .body(page)); + }; + + let db = db.get_ref(); + let Ok(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await else { + let page = templates::error_page( + &tera, + language, + translations, + templates::ErrorPage::InternalServerError, + ) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::InternalServerError() + .content_type("text/html") + .body(page)); + }; + let Some(client_id) = client_id else { + let page = templates::error_page( + &tera, + language, + translations, + templates::ErrorPage::ClientNotFound, + ) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::NotFound() + .content_type("text/html") + .body(page)); + }; + + // verify redirect uri + let redirect_uri = match get_redirect_uri(¶ms.redirect_uri, db, client_id).await { + Ok(uri) => uri, + Err(e) => { + let e = e + .expected() + .unwrap_or(templates::ErrorPage::InternalServerError); + let page = error_page(&tera, &translations, e) + .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); + return Ok(HttpResponse::BadRequest() + .content_type("text/html") + .body(page)); + } + }; + + let state = ¶ms.state; + let internal_server_error = + AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); + + // verify scope + let _ = match get_scope(¶ms.scope, db, client_id, &redirect_uri, ¶ms.state).await { + Ok(scope) => scope, + Err(e) => { + let e = e.expected().unwrap_or(internal_server_error); + return Err(e); + } + }; + + // verify response type + if params.response_type == ResponseType::Unsupported { + return Err(AuthorizeError::unsupported_response_type( + redirect_uri, + params.state, + )); + } + + // TODO find a better way of doing languages + let language = Language::from_str("en").unwrap(); + let page = templates::login_page(&tera, ¶ms, language, translations).unwrap(); + Ok(HttpResponse::Ok().content_type("text/html").body(page)) +} + +#[derive(Clone, Deserialize)] +#[serde(tag = "grant_type")] +#[serde(rename_all = "snake_case")] +enum GrantType { + AuthorizationCode { + code: Box, + redirect_uri: Url, + #[serde(rename = "client_id")] + client_alias: Box, + }, + Password { + username: Box, + password: Box, + scope: Option>, + }, + ClientCredentials { + scope: Option>, + }, + RefreshToken { + refresh_token: Box, + scope: Option>, + }, + #[serde(other)] + Unsupported, +} + +#[derive(Clone, Deserialize)] +struct TokenRequest { + #[serde(flatten)] + grant_type: GrantType, + // TODO support optional client credentials in here +} + +#[derive(Clone, Serialize)] +struct TokenResponse { + access_token: Box, + token_type: Box, + expires_in: i64, + refresh_token: Option>, + scope: Box, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +enum TokenErrorType { + InvalidRequest, + InvalidClient, + InvalidGrant, + UnauthorizedClient, + UnsupportedGrantType, + InvalidScope, +} + +#[derive(Debug, Clone, Error, Serialize)] +#[error("{error_description}")] +struct TokenError { + #[serde(skip)] + status_code: StatusCode, + error: TokenErrorType, + error_description: Box, + // TODO error uri +} + +impl TokenError { + fn invalid_request() -> Self { + // TODO make this description better, and all the other ones while you're at it + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidRequest, + error_description: "Invalid request".into(), + } + } + + fn unsupported_grant_type() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::UnsupportedGrantType, + error_description: "The given grant type is not supported".into(), + } + } + + fn bad_auth_code(error: VerifyJwtError) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidGrant, + error_description: error.to_string().into_boxed_str(), + } + } + + fn no_authorization() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: Box::from( + "Client credentials must be provided in the HTTP Authorization header", + ), + } + } + + fn client_not_found(alias: &str) -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: format!("No client with the client id: {alias} was found") + .into_boxed_str(), + } + } + + fn mismatch_client_id() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"), + } + } + + fn incorrect_client_secret() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: "The client secret is incorrect".into(), + } + } + + fn client_not_confidential(alias: &str) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::UnauthorizedClient, + error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.") + .into_boxed_str(), + } + } + + fn no_scope() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidScope, + error_description: Box::from( + "No scope was provided, and the client doesn't have a default scope", + ), + } + } + + fn excessive_scope() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidScope, + error_description: Box::from( + "The given scope exceeds what the client is allowed to have", + ), + } + } + + fn bad_refresh_token(err: VerifyJwtError) -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidGrant, + error_description: err.to_string().into_boxed_str(), + } + } + + fn untrusted_client() -> Self { + Self { + status_code: StatusCode::UNAUTHORIZED, + error: TokenErrorType::InvalidClient, + error_description: "Only trusted clients may use this grant".into(), + } + } + + fn incorrect_user_credentials() -> Self { + Self { + status_code: StatusCode::BAD_REQUEST, + error: TokenErrorType::InvalidRequest, + error_description: "The given credentials are incorrect".into(), + } + } +} + +impl ResponseError for TokenError { + fn error_response(&self) -> HttpResponse { + let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); + + let mut builder = HttpResponseBuilder::new(self.status_code); + + if self.status_code.as_u16() == 401 { + builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\"")); + } + + builder + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(self.clone()) + } +} + +#[post("/token")] +async fn token( + db: web::Data, + req: web::Bytes, + authorization: Option>, +) -> HttpResponse { + // TODO protect against brute force attacks + let db = db.get_ref(); + let request = serde_json::from_slice::(&req); + let Ok(request) = request else { + return TokenError::invalid_request().error_response(); + }; + let config = config::get_config().unwrap(); + + let self_id = config.url; + let duration = Duration::hours(1); + let token_type = Box::from("bearer"); + let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); + + match request.grant_type { + GrantType::AuthorizationCode { + code, + redirect_uri, + client_alias, + } => { + let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { + return TokenError::client_not_found(&client_alias).error_response(); + }; + + // validate auth code + let claims = + match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await { + Ok(claims) => claims, + Err(err) => { + let err = err.unwrap(); + return TokenError::bad_auth_code(err).error_response(); + } + }; + + // verify client, if the client has credentials + if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { + let Some(authorization) = authorization else { + return TokenError::no_authorization().error_response(); + }; + + if authorization.username() != client_alias.deref() { + return TokenError::mismatch_client_id().error_response(); + } + if !hash.check_password(authorization.password()).unwrap() { + return TokenError::incorrect_client_secret().error_response(); + } + } + + let access_token = jwt::Claims::access_token( + db, + Some(claims.id()), + self_id, + client_id, + claims.subject(), + duration, + claims.scopes(), + ) + .await + .unwrap(); + + let expires_in = access_token.expires_in(); + let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); + let scope = access_token.scopes().into(); + + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = Some(refresh_token.to_jwt().unwrap()); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + GrantType::Password { + username, + password, + scope, + } => { + let Some(authorization) = authorization else { + return TokenError::no_authorization().error_response(); + }; + let client_alias = authorization.username(); + let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { + return TokenError::client_not_found(client_alias).error_response(); + }; + + let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap(); + if !trusted { + return TokenError::untrusted_client().error_response(); + } + + // verify client + let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); + if !hash.check_password(authorization.password()).unwrap() { + return TokenError::incorrect_client_secret().error_response(); + } + + // authenticate user + let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else { + return TokenError::incorrect_user_credentials().error_response(); + }; + + // verify scope + let allowed_scopes = db::get_client_allowed_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let scope = if let Some(scope) = &scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + return TokenError::no_scope().error_response(); + }; + scope + }; + if !scopes::is_subset_of(&scope, &allowed_scopes) { + return TokenError::excessive_scope().error_response(); + } + + let access_token = + jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) + .await + .unwrap(); + let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); + + let expires_in = access_token.expires_in(); + let scope = access_token.scopes().into(); + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = Some(refresh_token.to_jwt().unwrap()); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + GrantType::ClientCredentials { scope } => { + let Some(authorization) = authorization else { + return TokenError::no_authorization().error_response(); + }; + let client_alias = authorization.username(); + let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { + return TokenError::client_not_found(client_alias).error_response(); + }; + + let ty = db::get_client_type(db, client_id).await.unwrap().unwrap(); + if ty != ClientType::Confidential { + return TokenError::client_not_confidential(client_alias).error_response(); + } + + // verify client + let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); + if !hash.check_password(authorization.password()).unwrap() { + return TokenError::incorrect_client_secret().error_response(); + } + + // verify scope + let allowed_scopes = db::get_client_allowed_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let scope = if let Some(scope) = &scope { + scope.clone() + } else { + let default_scopes = db::get_client_default_scopes(db, client_id) + .await + .unwrap() + .unwrap(); + let Some(scope) = default_scopes else { + return TokenError::no_scope().error_response(); + }; + scope + }; + if !scopes::is_subset_of(&scope, &allowed_scopes) { + return TokenError::excessive_scope().error_response(); + } + + let access_token = jwt::Claims::access_token( + db, None, self_id, client_id, client_id, duration, &scope, + ) + .await + .unwrap(); + + let expires_in = access_token.expires_in(); + let scope = access_token.scopes().into(); + let access_token = access_token.to_jwt().unwrap(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token: None, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + GrantType::RefreshToken { + refresh_token, + scope, + } => { + let client_id: Option; + if let Some(authorization) = authorization { + let client_alias = authorization.username(); + let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { + return TokenError::client_not_found(client_alias).error_response(); + }; + client_id = Some(id); + } else { + client_id = None; + } + + let claims = + match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await { + Ok(claims) => claims, + Err(e) => { + let e = e.unwrap(); + return TokenError::bad_refresh_token(e).error_response(); + } + }; + + let scope = if let Some(scope) = scope { + if !scopes::is_subset_of(&scope, claims.scopes()) { + return TokenError::excessive_scope().error_response(); + } + + scope + } else { + claims.scopes().into() + }; + + let exp_time = Duration::hours(1); + let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time) + .await + .unwrap(); + let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap(); + + let access_token = access_token.to_jwt().unwrap(); + let refresh_token = Some(refresh_token.to_jwt().unwrap()); + let expires_in = exp_time.num_seconds(); + + let response = TokenResponse { + access_token, + token_type, + expires_in, + refresh_token, + scope, + }; + HttpResponse::Ok() + .insert_header(cache_control) + .insert_header((header::PRAGMA, "no-cache")) + .json(response) + } + _ => TokenError::unsupported_grant_type().error_response(), + } +} + +pub fn service() -> Scope { + web::scope("/oauth") + .service(authorize_page) + .service(authorize) + .service(token) +} diff --git a/src/api/ops.rs b/src/api/ops.rs index 555bb1b..2164f1f 100644 --- a/src/api/ops.rs +++ b/src/api/ops.rs @@ -1,70 +1,70 @@ -use std::str::FromStr; - -use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::Deserialize; -use sqlx::MySqlPool; -use tera::Tera; -use thiserror::Error; -use unic_langid::subtags::Language; - -use crate::resources::{languages, templates}; -use crate::services::db; - -/// A request to login -#[derive(Debug, Clone, Deserialize)] -struct LoginRequest { - username: Box, - password: Box, -} - -/// An error occurred when authenticating, because either the username or -/// password was invalid. -#[derive(Debug, Clone, Error)] -enum LoginFailure { - #[error("No user found with the given username")] - UserNotFound { username: Box }, - #[error("The given password is incorrect")] - IncorrectPassword { username: Box }, -} - -impl ResponseError for LoginFailure { - fn status_code(&self) -> actix_web::http::StatusCode { - match self { - Self::UserNotFound { .. } => StatusCode::NOT_FOUND, - Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED, - } - } -} - -/// Returns `200` if login was successful. -/// Returns `404` if the username is invalid. -/// Returns `401` if the password was invalid. -#[post("/login")] -async fn login( - body: web::Json, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user = db::get_user_by_username(conn, &body.username) - .await - .unwrap(); - let Some(user) = user else { - yeet!(LoginFailure::UserNotFound{ username: body.username.clone() }); - }; - - let good_password = user.check_password(&body.password).unwrap(); - let response = if good_password { - HttpResponse::Ok().finish() - } else { - yeet!(LoginFailure::IncorrectPassword { - username: body.username.clone() - }); - }; - Ok(response) -} - -pub fn service() -> Scope { - web::scope("").service(login) -} +use std::str::FromStr; + +use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope}; +use raise::yeet; +use serde::Deserialize; +use sqlx::MySqlPool; +use tera::Tera; +use thiserror::Error; +use unic_langid::subtags::Language; + +use crate::resources::{languages, templates}; +use crate::services::db; + +/// A request to login +#[derive(Debug, Clone, Deserialize)] +struct LoginRequest { + username: Box, + password: Box, +} + +/// An error occurred when authenticating, because either the username or +/// password was invalid. +#[derive(Debug, Clone, Error)] +enum LoginFailure { + #[error("No user found with the given username")] + UserNotFound { username: Box }, + #[error("The given password is incorrect")] + IncorrectPassword { username: Box }, +} + +impl ResponseError for LoginFailure { + fn status_code(&self) -> actix_web::http::StatusCode { + match self { + Self::UserNotFound { .. } => StatusCode::NOT_FOUND, + Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED, + } + } +} + +/// Returns `200` if login was successful. +/// Returns `404` if the username is invalid. +/// Returns `401` if the password was invalid. +#[post("/login")] +async fn login( + body: web::Json, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user = db::get_user_by_username(conn, &body.username) + .await + .unwrap(); + let Some(user) = user else { + yeet!(LoginFailure::UserNotFound{ username: body.username.clone() }); + }; + + let good_password = user.check_password(&body.password).unwrap(); + let response = if good_password { + HttpResponse::Ok().finish() + } else { + yeet!(LoginFailure::IncorrectPassword { + username: body.username.clone() + }); + }; + Ok(response) +} + +pub fn service() -> Scope { + web::scope("").service(login) +} diff --git a/src/api/users.rs b/src/api/users.rs index 391a059..da2a0d0 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,272 +1,272 @@ -use actix_web::http::{header, StatusCode}; -use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use thiserror::Error; -use uuid::Uuid; - -use crate::models::user::User; -use crate::services::crypto::PasswordHash; -use crate::services::{db, id}; - -/// Just a username. No password hash, because that'd be tempting fate. -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct UserResponse { - id: Uuid, - username: Box, -} - -impl From for UserResponse { - fn from(user: User) -> Self { - Self { - id: user.id, - username: user.username, - } - } -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct SearchUsers { - username: Option>, - limit: Option, - offset: Option, -} - -#[get("")] -async fn search_users(params: web::Query, conn: web::Data) -> HttpResponse { - let conn = conn.get_ref(); - - let username = params.username.clone().unwrap_or_default(); - let offset = params.offset.unwrap_or_default(); - - let results: Box<[UserResponse]> = if let Some(limit) = params.limit { - db::search_users_limit(conn, &username, offset, limit) - .await - .unwrap() - .iter() - .cloned() - .map(|u| u.into()) - .collect() - } else { - db::search_users(conn, &username) - .await - .unwrap() - .into_iter() - .skip(offset as usize) - .cloned() - .map(|u| u.into()) - .collect() - }; - - let response = HttpResponse::Ok().json(results); - response -} - -#[derive(Debug, Clone, Error)] -#[error("No user with the given ID exists")] -struct UserNotFoundError { - user_id: Uuid, -} - -impl ResponseError for UserNotFoundError { - fn status_code(&self) -> StatusCode { - StatusCode::NOT_FOUND - } -} - -#[get("/{user_id}")] -async fn get_user( - user_id: web::Path, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let id = user_id.to_owned(); - let username = db::get_username(conn, id).await.unwrap(); - - let Some(username) = username else { - yeet!(UserNotFoundError { user_id: id }); - }; - - let response = UserResponse { id, username }; - let response = HttpResponse::Ok().json(response); - Ok(response) -} - -#[get("/{user_id}/username")] -async fn get_username( - user_id: web::Path, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = db::get_username(conn, user_id).await.unwrap(); - - let Some(username) = username else { - yeet!(UserNotFoundError { user_id }); - }; - - let response = HttpResponse::Ok().json(username); - Ok(response) -} - -/// A request to create or update user information -#[derive(Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct UserRequest { - username: Box, - password: Box, -} - -#[derive(Debug, Clone, Error)] -#[error("An account with the given username already exists.")] -struct UsernameTakenError { - username: Box, -} - -impl ResponseError for UsernameTakenError { - fn status_code(&self) -> StatusCode { - StatusCode::CONFLICT - } -} - -#[post("")] -async fn create_user( - body: web::Json, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user_id = id::new_id(conn, db::user_id_exists).await.unwrap(); - let username = body.username.clone(); - let password = PasswordHash::new(&body.password).unwrap(); - - if db::username_is_used(conn, &body.username).await.unwrap() { - yeet!(UsernameTakenError { username }); - } - - let user = User { - id: user_id, - username, - password, - }; - - db::create_user(conn, &user).await.unwrap(); - - let response = HttpResponse::Created() - .insert_header((header::LOCATION, format!("users/{user_id}"))) - .finish(); - Ok(response) -} - -#[derive(Debug, Clone, Error)] -enum UpdateUserError { - #[error(transparent)] - UsernameTaken(#[from] UsernameTakenError), - #[error(transparent)] - NotFound(#[from] UserNotFoundError), -} - -impl ResponseError for UpdateUserError { - fn status_code(&self) -> StatusCode { - match self { - Self::UsernameTaken(e) => e.status_code(), - Self::NotFound(e) => e.status_code(), - } - } -} - -#[put("/{user_id}")] -async fn update_user( - user_id: web::Path, - body: web::Json, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = body.username.clone(); - let password = PasswordHash::new(&body.password).unwrap(); - - let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); - if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() { - yeet!(UsernameTakenError { username }.into()) - } - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }.into()) - } - - let user = User { - id: user_id, - username, - password, - }; - - db::update_user(conn, &user).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{user_id}/username")] -async fn update_username( - user_id: web::Path, - body: web::Json>, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = body.clone(); - - let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); - if username != old_username && db::username_is_used(conn, &body).await.unwrap() { - yeet!(UsernameTakenError { username }.into()) - } - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }.into()) - } - - db::update_username(conn, user_id, &body).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{user_id}/password")] -async fn update_password( - user_id: web::Path, - body: web::Json>, - conn: web::Data, -) -> Result { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let password = PasswordHash::new(&body).unwrap(); - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }) - } - - db::update_password(conn, user_id, &password).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -pub fn service() -> Scope { - web::scope("/users") - .service(search_users) - .service(get_user) - .service(get_username) - .service(create_user) - .service(update_user) - .service(update_username) - .service(update_password) -} +use actix_web::http::{header, StatusCode}; +use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; +use raise::yeet; +use serde::{Deserialize, Serialize}; +use sqlx::MySqlPool; +use thiserror::Error; +use uuid::Uuid; + +use crate::models::user::User; +use crate::services::crypto::PasswordHash; +use crate::services::{db, id}; + +/// Just a username. No password hash, because that'd be tempting fate. +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +struct UserResponse { + id: Uuid, + username: Box, +} + +impl From for UserResponse { + fn from(user: User) -> Self { + Self { + id: user.id, + username: user.username, + } + } +} + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct SearchUsers { + username: Option>, + limit: Option, + offset: Option, +} + +#[get("")] +async fn search_users(params: web::Query, conn: web::Data) -> HttpResponse { + let conn = conn.get_ref(); + + let username = params.username.clone().unwrap_or_default(); + let offset = params.offset.unwrap_or_default(); + + let results: Box<[UserResponse]> = if let Some(limit) = params.limit { + db::search_users_limit(conn, &username, offset, limit) + .await + .unwrap() + .iter() + .cloned() + .map(|u| u.into()) + .collect() + } else { + db::search_users(conn, &username) + .await + .unwrap() + .into_iter() + .skip(offset as usize) + .cloned() + .map(|u| u.into()) + .collect() + }; + + let response = HttpResponse::Ok().json(results); + response +} + +#[derive(Debug, Clone, Error)] +#[error("No user with the given ID exists")] +struct UserNotFoundError { + user_id: Uuid, +} + +impl ResponseError for UserNotFoundError { + fn status_code(&self) -> StatusCode { + StatusCode::NOT_FOUND + } +} + +#[get("/{user_id}")] +async fn get_user( + user_id: web::Path, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let id = user_id.to_owned(); + let username = db::get_username(conn, id).await.unwrap(); + + let Some(username) = username else { + yeet!(UserNotFoundError { user_id: id }); + }; + + let response = UserResponse { id, username }; + let response = HttpResponse::Ok().json(response); + Ok(response) +} + +#[get("/{user_id}/username")] +async fn get_username( + user_id: web::Path, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user_id = user_id.to_owned(); + let username = db::get_username(conn, user_id).await.unwrap(); + + let Some(username) = username else { + yeet!(UserNotFoundError { user_id }); + }; + + let response = HttpResponse::Ok().json(username); + Ok(response) +} + +/// A request to create or update user information +#[derive(Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +struct UserRequest { + username: Box, + password: Box, +} + +#[derive(Debug, Clone, Error)] +#[error("An account with the given username already exists.")] +struct UsernameTakenError { + username: Box, +} + +impl ResponseError for UsernameTakenError { + fn status_code(&self) -> StatusCode { + StatusCode::CONFLICT + } +} + +#[post("")] +async fn create_user( + body: web::Json, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user_id = id::new_id(conn, db::user_id_exists).await.unwrap(); + let username = body.username.clone(); + let password = PasswordHash::new(&body.password).unwrap(); + + if db::username_is_used(conn, &body.username).await.unwrap() { + yeet!(UsernameTakenError { username }); + } + + let user = User { + id: user_id, + username, + password, + }; + + db::create_user(conn, &user).await.unwrap(); + + let response = HttpResponse::Created() + .insert_header((header::LOCATION, format!("users/{user_id}"))) + .finish(); + Ok(response) +} + +#[derive(Debug, Clone, Error)] +enum UpdateUserError { + #[error(transparent)] + UsernameTaken(#[from] UsernameTakenError), + #[error(transparent)] + NotFound(#[from] UserNotFoundError), +} + +impl ResponseError for UpdateUserError { + fn status_code(&self) -> StatusCode { + match self { + Self::UsernameTaken(e) => e.status_code(), + Self::NotFound(e) => e.status_code(), + } + } +} + +#[put("/{user_id}")] +async fn update_user( + user_id: web::Path, + body: web::Json, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user_id = user_id.to_owned(); + let username = body.username.clone(); + let password = PasswordHash::new(&body.password).unwrap(); + + let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); + if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() { + yeet!(UsernameTakenError { username }.into()) + } + + if !db::user_id_exists(conn, user_id).await.unwrap() { + yeet!(UserNotFoundError { user_id }.into()) + } + + let user = User { + id: user_id, + username, + password, + }; + + db::update_user(conn, &user).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{user_id}/username")] +async fn update_username( + user_id: web::Path, + body: web::Json>, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user_id = user_id.to_owned(); + let username = body.clone(); + + let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); + if username != old_username && db::username_is_used(conn, &body).await.unwrap() { + yeet!(UsernameTakenError { username }.into()) + } + + if !db::user_id_exists(conn, user_id).await.unwrap() { + yeet!(UserNotFoundError { user_id }.into()) + } + + db::update_username(conn, user_id, &body).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{user_id}/password")] +async fn update_password( + user_id: web::Path, + body: web::Json>, + conn: web::Data, +) -> Result { + let conn = conn.get_ref(); + + let user_id = user_id.to_owned(); + let password = PasswordHash::new(&body).unwrap(); + + if !db::user_id_exists(conn, user_id).await.unwrap() { + yeet!(UserNotFoundError { user_id }) + } + + db::update_password(conn, user_id, &password).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +pub fn service() -> Scope { + web::scope("/users") + .service(search_users) + .service(get_user) + .service(get_username) + .service(create_user) + .service(update_user) + .service(update_username) + .service(update_password) +} -- cgit v1.2.3