diff options
| author | Mica White <botahamec@outlook.com> | 2025-12-08 20:08:21 -0500 |
|---|---|---|
| committer | Mica White <botahamec@outlook.com> | 2025-12-08 20:08:21 -0500 |
| commit | 608ce1d9910cd68ce825838ea313e02c598f908e (patch) | |
| tree | 0bd4ad26f86e5c873f97308983112b0ffe593df3 /src | |
| parent | 93fd2e82e8fdc5ee62739053385f8ccffc660f02 (diff) | |
Diffstat (limited to 'src')
| -rw-r--r-- | src/api/clients.rs | 966 | ||||
| -rw-r--r-- | src/api/liveops.rs | 22 | ||||
| -rw-r--r-- | src/api/mod.rs | 26 | ||||
| -rw-r--r-- | src/api/oauth.rs | 1852 | ||||
| -rw-r--r-- | src/api/ops.rs | 140 | ||||
| -rw-r--r-- | src/api/users.rs | 544 | ||||
| -rw-r--r-- | src/main.rs | 216 | ||||
| -rw-r--r-- | src/models/client.rs | 330 | ||||
| -rw-r--r-- | src/models/mod.rs | 4 | ||||
| -rw-r--r-- | src/models/user.rs | 98 | ||||
| -rw-r--r-- | src/resources/languages.rs | 134 | ||||
| -rw-r--r-- | src/resources/mod.rs | 8 | ||||
| -rw-r--r-- | src/resources/scripts.rs | 76 | ||||
| -rw-r--r-- | src/resources/style.rs | 108 | ||||
| -rw-r--r-- | src/resources/templates.rs | 202 | ||||
| -rw-r--r-- | src/scopes/admin.rs | 56 | ||||
| -rw-r--r-- | src/scopes/mod.rs | 256 | ||||
| -rw-r--r-- | src/services/authorization.rs | 164 | ||||
| -rw-r--r-- | src/services/config.rs | 148 | ||||
| -rw-r--r-- | src/services/crypto.rs | 194 | ||||
| -rw-r--r-- | src/services/db.rs | 30 | ||||
| -rw-r--r-- | src/services/db/client.rs | 784 | ||||
| -rw-r--r-- | src/services/db/jwt.rs | 398 | ||||
| -rw-r--r-- | src/services/db/user.rs | 472 | ||||
| -rw-r--r-- | src/services/id.rs | 54 | ||||
| -rw-r--r-- | src/services/jwt.rs | 582 | ||||
| -rw-r--r-- | src/services/mod.rs | 14 | ||||
| -rw-r--r-- | src/services/secrets.rs | 48 |
28 files changed, 3963 insertions, 3963 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs index 3f906bb..ded8b81 100644 --- a/src/api/clients.rs +++ b/src/api/clients.rs @@ -1,483 +1,483 @@ -use actix_web::http::{header, StatusCode}; -use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use thiserror::Error; -use url::Url; -use uuid::Uuid; - -use crate::models::client::{Client, ClientType, CreateClientError}; -use crate::services::crypto::PasswordHash; -use crate::services::db::ClientRow; -use crate::services::{db, id}; - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct ClientResponse { - client_id: Uuid, - alias: Box<str>, - client_type: ClientType, - allowed_scopes: Box<[Box<str>]>, - default_scopes: Option<Box<[Box<str>]>>, - is_trusted: bool, -} - -impl From<ClientRow> for ClientResponse { - fn from(value: ClientRow) -> Self { - Self { - client_id: value.id, - alias: value.alias.into_boxed_str(), - client_type: value.client_type, - allowed_scopes: value - .allowed_scopes - .split_whitespace() - .map(Box::from) - .collect(), - default_scopes: value - .default_scopes - .map(|s| s.split_whitespace().map(Box::from).collect()), - is_trusted: value.is_trusted, - } - } -} - -#[derive(Debug, Clone, Copy, Error)] -#[error("No client with the given client ID was found")] -struct ClientNotFound { - id: Uuid, -} - -impl ResponseError for ClientNotFound { - fn status_code(&self) -> StatusCode { - StatusCode::NOT_FOUND - } -} - -impl ClientNotFound { - fn new(id: Uuid) -> Self { - Self { id } - } -} - -#[get("/{client_id}")] -async fn get_client( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(client) = db::get_client_response(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\""); - let response: ClientResponse = client.into(); - let response = HttpResponse::Ok() - .append_header((header::LINK, redirect_uris_link)) - .json(response); - Ok(response) -} - -#[get("/{client_id}/alias")] -async fn get_client_alias( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(alias)) -} - -#[get("/{client_id}/type")] -async fn get_client_type( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(client_type)) -} - -#[get("/{client_id}/redirect-uris")] -async fn get_client_redirect_uris( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id)) - }; - - let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap(); - - Ok(HttpResponse::Ok().json(redirect_uris)) -} - -#[get("/{client_id}/allowed-scopes")] -async fn get_client_allowed_scopes( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>(); - - Ok(HttpResponse::Ok().json(allowed_scopes)) -} - -#[get("/{client_id}/default-scopes")] -async fn get_client_default_scopes( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - let default_scopes = default_scopes.map(|scopes| { - scopes - .split_whitespace() - .map(Box::from) - .collect::<Box<[Box<str>]>>() - }); - - Ok(HttpResponse::Ok().json(default_scopes)) -} - -#[get("/{client_id}/is-trusted")] -async fn get_client_is_trusted( - client_id: web::Path<Uuid>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, ClientNotFound> { - let db = db.as_ref(); - let id = *client_id; - - let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id)) - }; - - Ok(HttpResponse::Ok().json(is_trusted)) -} - -#[derive(Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct ClientRequest { - alias: Box<str>, - ty: ClientType, - redirect_uris: Box<[Url]>, - secret: Option<Box<str>>, - allowed_scopes: Box<[Box<str>]>, - default_scopes: Option<Box<[Box<str>]>>, - trusted: bool, -} - -#[derive(Debug, Clone, Error)] -#[error("The given client alias is already taken")] -struct AliasTakenError { - alias: Box<str>, -} - -impl ResponseError for AliasTakenError { - fn status_code(&self) -> StatusCode { - StatusCode::CONFLICT - } -} - -impl AliasTakenError { - fn new(alias: &str) -> Self { - Self { - alias: Box::from(alias), - } - } -} - -#[post("")] -async fn create_client( - body: web::Json<ClientRequest>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let alias = &body.alias; - - if db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - let id = id::new_id(db, db::client_id_exists).await.unwrap(); - let client = Client::new( - id, - &alias, - body.ty, - body.secret.as_deref(), - body.allowed_scopes.clone(), - body.default_scopes.clone(), - &body.redirect_uris, - body.trusted, - ) - .map_err(|e| e.unwrap())?; - - let transaction = db.begin().await.unwrap(); - db::create_client(transaction, &client).await.unwrap(); - - let response = HttpResponse::Created() - .insert_header((header::LOCATION, format!("clients/{id}"))) - .finish(); - Ok(response) -} - -#[derive(Debug, Clone, Error)] -enum UpdateClientError { - #[error(transparent)] - NotFound(#[from] ClientNotFound), - #[error(transparent)] - ClientError(#[from] CreateClientError), - #[error(transparent)] - AliasTaken(#[from] AliasTakenError), -} - -impl ResponseError for UpdateClientError { - fn status_code(&self) -> StatusCode { - match self { - Self::NotFound(e) => e.status_code(), - Self::ClientError(e) => e.status_code(), - Self::AliasTaken(e) => e.status_code(), - } - } -} - -#[put("/{id}")] -async fn update_client( - id: web::Path<Uuid>, - body: web::Json<ClientRequest>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let alias = &body.alias; - - let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - let client = Client::new( - id, - &alias, - body.ty, - body.secret.as_deref(), - body.allowed_scopes.clone(), - body.default_scopes.clone(), - &body.redirect_uris, - body.trusted, - ) - .map_err(|e| e.unwrap())?; - - let transaction = db.begin().await.unwrap(); - db::update_client(transaction, &client).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{id}/alias")] -async fn update_client_alias( - id: web::Path<Uuid>, - body: web::Json<Box<str>>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let alias = body.0; - - let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - if old_alias == alias { - return Ok(HttpResponse::NoContent().finish()); - } - if db::client_alias_exists(db, &alias).await.unwrap() { - yeet!(AliasTakenError::new(&alias).into()); - } - - db::update_client_alias(db, id, &alias).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{id}/type")] -async fn update_client_type( - id: web::Path<Uuid>, - body: web::Json<ClientType>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let ty = body.0; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_type(db, id, ty).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/allowed-scopes")] -async fn update_client_allowed_scopes( - id: web::Path<Uuid>, - body: web::Json<Box<[Box<str>]>>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let allowed_scopes = body.0.join(" "); - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_allowed_scopes(db, id, &allowed_scopes) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/default-scopes")] -async fn update_client_default_scopes( - id: web::Path<Uuid>, - body: web::Json<Option<Box<[Box<str>]>>>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let default_scopes = body.0.map(|s| s.join(" ")); - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_default_scopes(db, id, default_scopes) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/is-trusted")] -async fn update_client_is_trusted( - id: web::Path<Uuid>, - body: web::Json<bool>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - let is_trusted = *body; - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - db::update_client_trusted(db, id, is_trusted).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("/{id}/redirect-uris")] -async fn update_client_redirect_uris( - id: web::Path<Uuid>, - body: web::Json<Box<[Url]>>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - - for uri in body.0.iter() { - if uri.scheme() != "https" { - yeet!(CreateClientError::NonHttpsUri.into()); - } - - if uri.fragment().is_some() { - yeet!(CreateClientError::UriFragment.into()) - } - } - - if !db::client_id_exists(db, id).await.unwrap() { - yeet!(ClientNotFound::new(id).into()); - } - - let transaction = db.begin().await.unwrap(); - db::update_client_redirect_uris(transaction, id, &body.0) - .await - .unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -#[put("{id}/secret")] -async fn update_client_secret( - id: web::Path<Uuid>, - body: web::Json<Option<Box<str>>>, - db: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateClientError> { - let db = db.get_ref(); - let id = *id; - - let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { - yeet!(ClientNotFound::new(id).into()) - }; - - if client_type == ClientType::Confidential && body.is_none() { - yeet!(CreateClientError::NoSecret.into()) - } - - let secret = body.0.map(|s| PasswordHash::new(&s).unwrap()); - db::update_client_secret(db, id, secret).await.unwrap(); - - Ok(HttpResponse::NoContent().finish()) -} - -pub fn service() -> Scope { - web::scope("/clients") - .service(get_client) - .service(get_client_alias) - .service(get_client_type) - .service(get_client_allowed_scopes) - .service(get_client_default_scopes) - .service(get_client_redirect_uris) - .service(get_client_is_trusted) - .service(create_client) - .service(update_client) - .service(update_client_alias) - .service(update_client_type) - .service(update_client_allowed_scopes) - .service(update_client_default_scopes) - .service(update_client_redirect_uris) - .service(update_client_secret) -} +use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::{Client, ClientType, CreateClientError};
+use crate::services::crypto::PasswordHash;
+use crate::services::db::ClientRow;
+use crate::services::{db, id};
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientResponse {
+ client_id: Uuid,
+ alias: Box<str>,
+ client_type: ClientType,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ is_trusted: bool,
+}
+
+impl From<ClientRow> for ClientResponse {
+ fn from(value: ClientRow) -> Self {
+ Self {
+ client_id: value.id,
+ alias: value.alias.into_boxed_str(),
+ client_type: value.client_type,
+ allowed_scopes: value
+ .allowed_scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect(),
+ default_scopes: value
+ .default_scopes
+ .map(|s| s.split_whitespace().map(Box::from).collect()),
+ is_trusted: value.is_trusted,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("No client with the given client ID was found")]
+struct ClientNotFound {
+ id: Uuid,
+}
+
+impl ResponseError for ClientNotFound {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+impl ClientNotFound {
+ fn new(id: Uuid) -> Self {
+ Self { id }
+ }
+}
+
+#[get("/{client_id}")]
+async fn get_client(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client) = db::get_client_response(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\"");
+ let response: ClientResponse = client.into();
+ let response = HttpResponse::Ok()
+ .append_header((header::LINK, redirect_uris_link))
+ .json(response);
+ Ok(response)
+}
+
+#[get("/{client_id}/alias")]
+async fn get_client_alias(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(alias))
+}
+
+#[get("/{client_id}/type")]
+async fn get_client_type(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(client_type))
+}
+
+#[get("/{client_id}/redirect-uris")]
+async fn get_client_redirect_uris(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap();
+
+ Ok(HttpResponse::Ok().json(redirect_uris))
+}
+
+#[get("/{client_id}/allowed-scopes")]
+async fn get_client_allowed_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(allowed_scopes) = db::get_client_allowed_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let allowed_scopes = allowed_scopes.split_whitespace().collect::<Box<[&str]>>();
+
+ Ok(HttpResponse::Ok().json(allowed_scopes))
+}
+
+#[get("/{client_id}/default-scopes")]
+async fn get_client_default_scopes(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(default_scopes) = db::get_client_default_scopes(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ let default_scopes = default_scopes.map(|scopes| {
+ scopes
+ .split_whitespace()
+ .map(Box::from)
+ .collect::<Box<[Box<str>]>>()
+ });
+
+ Ok(HttpResponse::Ok().json(default_scopes))
+}
+
+#[get("/{client_id}/is-trusted")]
+async fn get_client_is_trusted(
+ client_id: web::Path<Uuid>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, ClientNotFound> {
+ let db = db.as_ref();
+ let id = *client_id;
+
+ let Some(is_trusted) = db::is_client_trusted(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id))
+ };
+
+ Ok(HttpResponse::Ok().json(is_trusted))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct ClientRequest {
+ alias: Box<str>,
+ ty: ClientType,
+ redirect_uris: Box<[Url]>,
+ secret: Option<Box<str>>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ trusted: bool,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("The given client alias is already taken")]
+struct AliasTakenError {
+ alias: Box<str>,
+}
+
+impl ResponseError for AliasTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+impl AliasTakenError {
+ fn new(alias: &str) -> Self {
+ Self {
+ alias: Box::from(alias),
+ }
+ }
+}
+
+#[post("")]
+async fn create_client(
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let alias = &body.alias;
+
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let id = id::new_id(db, db::client_id_exists).await.unwrap();
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::create_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("clients/{id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateClientError {
+ #[error(transparent)]
+ NotFound(#[from] ClientNotFound),
+ #[error(transparent)]
+ ClientError(#[from] CreateClientError),
+ #[error(transparent)]
+ AliasTaken(#[from] AliasTakenError),
+}
+
+impl ResponseError for UpdateClientError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::NotFound(e) => e.status_code(),
+ Self::ClientError(e) => e.status_code(),
+ Self::AliasTaken(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{id}")]
+async fn update_client(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientRequest>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = &body.alias;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ let client = Client::new(
+ id,
+ &alias,
+ body.ty,
+ body.secret.as_deref(),
+ body.allowed_scopes.clone(),
+ body.default_scopes.clone(),
+ &body.redirect_uris,
+ body.trusted,
+ )
+ .map_err(|e| e.unwrap())?;
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client(transaction, &client).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/alias")]
+async fn update_client_alias(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let alias = body.0;
+
+ let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+ if old_alias == alias {
+ return Ok(HttpResponse::NoContent().finish());
+ }
+ if db::client_alias_exists(db, &alias).await.unwrap() {
+ yeet!(AliasTakenError::new(&alias).into());
+ }
+
+ db::update_client_alias(db, id, &alias).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{id}/type")]
+async fn update_client_type(
+ id: web::Path<Uuid>,
+ body: web::Json<ClientType>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let ty = body.0;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_type(db, id, ty).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/allowed-scopes")]
+async fn update_client_allowed_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Box<str>]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let allowed_scopes = body.0.join(" ");
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_allowed_scopes(db, id, &allowed_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/default-scopes")]
+async fn update_client_default_scopes(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<[Box<str>]>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let default_scopes = body.0.map(|s| s.join(" "));
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_default_scopes(db, id, default_scopes)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/is-trusted")]
+async fn update_client_is_trusted(
+ id: web::Path<Uuid>,
+ body: web::Json<bool>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+ let is_trusted = *body;
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ db::update_client_trusted(db, id, is_trusted).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("/{id}/redirect-uris")]
+async fn update_client_redirect_uris(
+ id: web::Path<Uuid>,
+ body: web::Json<Box<[Url]>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ for uri in body.0.iter() {
+ if uri.scheme() != "https" {
+ yeet!(CreateClientError::NonHttpsUri.into());
+ }
+
+ if uri.fragment().is_some() {
+ yeet!(CreateClientError::UriFragment.into())
+ }
+ }
+
+ if !db::client_id_exists(db, id).await.unwrap() {
+ yeet!(ClientNotFound::new(id).into());
+ }
+
+ let transaction = db.begin().await.unwrap();
+ db::update_client_redirect_uris(transaction, id, &body.0)
+ .await
+ .unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+#[put("{id}/secret")]
+async fn update_client_secret(
+ id: web::Path<Uuid>,
+ body: web::Json<Option<Box<str>>>,
+ db: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateClientError> {
+ let db = db.get_ref();
+ let id = *id;
+
+ let Some(client_type) = db::get_client_type(db, id).await.unwrap() else {
+ yeet!(ClientNotFound::new(id).into())
+ };
+
+ if client_type == ClientType::Confidential && body.is_none() {
+ yeet!(CreateClientError::NoSecret.into())
+ }
+
+ let secret = body.0.map(|s| PasswordHash::new(&s).unwrap());
+ db::update_client_secret(db, id, secret).await.unwrap();
+
+ Ok(HttpResponse::NoContent().finish())
+}
+
+pub fn service() -> Scope {
+ web::scope("/clients")
+ .service(get_client)
+ .service(get_client_alias)
+ .service(get_client_type)
+ .service(get_client_allowed_scopes)
+ .service(get_client_default_scopes)
+ .service(get_client_redirect_uris)
+ .service(get_client_is_trusted)
+ .service(create_client)
+ .service(update_client)
+ .service(update_client_alias)
+ .service(update_client_type)
+ .service(update_client_allowed_scopes)
+ .service(update_client_default_scopes)
+ .service(update_client_redirect_uris)
+ .service(update_client_secret)
+}
diff --git a/src/api/liveops.rs b/src/api/liveops.rs index d4bf129..2caf6e3 100644 --- a/src/api/liveops.rs +++ b/src/api/liveops.rs @@ -1,11 +1,11 @@ -use actix_web::{get, web, HttpResponse, Scope}; - -/// Simple ping -#[get("/ping")] -async fn ping() -> HttpResponse { - HttpResponse::Ok().finish() -} - -pub fn service() -> Scope { - web::scope("/liveops").service(ping) -} +use actix_web::{get, web, HttpResponse, Scope};
+
+/// Simple ping
+#[get("/ping")]
+async fn ping() -> HttpResponse {
+ HttpResponse::Ok().finish()
+}
+
+pub fn service() -> Scope {
+ web::scope("/liveops").service(ping)
+}
diff --git a/src/api/mod.rs b/src/api/mod.rs index 0ab4037..9059e71 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,13 +1,13 @@ -mod clients; -mod liveops; -mod oauth; -mod ops; -mod users; - -pub use clients::service as clients; -pub use liveops::service as liveops; -pub use oauth::service as oauth; -pub use ops::service as ops; -pub use users::service as users; - -pub use oauth::AuthorizationParameters; +mod clients;
+mod liveops;
+mod oauth;
+mod ops;
+mod users;
+
+pub use clients::service as clients;
+pub use liveops::service as liveops;
+pub use oauth::service as oauth;
+pub use ops::service as ops;
+pub use users::service as users;
+
+pub use oauth::AuthorizationParameters;
diff --git a/src/api/oauth.rs b/src/api/oauth.rs index f1aa012..3422d2f 100644 --- a/src/api/oauth.rs +++ b/src/api/oauth.rs @@ -1,926 +1,926 @@ -use std::ops::Deref; -use std::str::FromStr; - -use actix_web::http::{header, StatusCode}; -use actix_web::{ - get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope, -}; -use chrono::Duration; -use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use tera::Tera; -use thiserror::Error; -use unic_langid::subtags::Language; -use url::Url; -use uuid::Uuid; - -use crate::models::client::ClientType; -use crate::resources::{languages, templates}; -use crate::scopes; -use crate::services::jwt::VerifyJwtError; -use crate::services::{authorization, config, db, jwt}; - -const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>"; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] -#[serde(rename_all = "snake_case")] -enum ResponseType { - Code, - Token, - #[serde(other)] - Unsupported, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AuthorizationParameters { - response_type: ResponseType, - client_id: Box<str>, - redirect_uri: Option<Url>, - scope: Option<Box<str>>, - state: Option<Box<str>>, -} - -#[derive(Clone, Deserialize)] -struct AuthorizeCredentials { - username: Box<str>, - password: Box<str>, -} - -#[derive(Clone, Serialize)] -struct AuthCodeResponse { - code: Box<str>, - state: Option<Box<str>>, -} - -#[derive(Clone, Serialize)] -struct AuthTokenResponse { - access_token: Box<str>, - token_type: &'static str, - expires_in: i64, - scope: Box<str>, - state: Option<Box<str>>, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] -#[serde(rename_all = "camelCase")] -enum AuthorizeErrorType { - InvalidRequest, - UnauthorizedClient, - AccessDenied, - UnsupportedResponseType, - InvalidScope, - ServerError, - TemporarilyUnavailable, -} - -#[derive(Debug, Clone, Error, Serialize)] -#[error("{error_description}")] -struct AuthorizeError { - error: AuthorizeErrorType, - error_description: Box<str>, - // TODO error uri - state: Option<Box<str>>, - #[serde(skip)] - redirect_uri: Url, -} - -impl AuthorizeError { - fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self { - Self { - error: AuthorizeErrorType::InvalidScope, - error_description: Box::from( - "No scope was provided, and the client does not have a default scope", - ), - state, - redirect_uri, - } - } - - fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self { - Self { - error: AuthorizeErrorType::UnsupportedResponseType, - error_description: Box::from("The given response type is not supported"), - state, - redirect_uri, - } - } - - fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self { - Self { - error: AuthorizeErrorType::InvalidScope, - error_description: Box::from("The given scope exceeds what the client is allowed"), - state, - redirect_uri, - } - } - - fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self { - Self { - error: AuthorizeErrorType::ServerError, - error_description: "An unexpected error occurred".into(), - state, - redirect_uri, - } - } -} - -impl ResponseError for AuthorizeError { - fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { - let query = Some(serde_urlencoded::to_string(self).unwrap()); - let query = query.as_deref(); - let mut url = self.redirect_uri.clone(); - url.set_query(query); - - HttpResponse::Found() - .insert_header((header::LOCATION, url.as_str())) - .finish() - } -} - -fn error_page( - tera: &Tera, - translations: &languages::Translations, - error: templates::ErrorPage, -) -> Result<String, RawUnexpected> { - // TODO find a better way of doing languages - let language = Language::from_str("en").unwrap(); - let translations = translations.clone(); - let page = templates::error_page(&tera, language, translations, error)?; - Ok(page) -} - -async fn get_redirect_uri( - redirect_uri: &Option<Url>, - db: &MySqlPool, - client_id: Uuid, -) -> Result<Url, Expect<templates::ErrorPage>> { - if let Some(uri) = &redirect_uri { - let redirect_uri = uri.clone(); - if !db::client_has_redirect_uri(db, client_id, &redirect_uri) - .await - .map_err(|e| UnexpectedError::from(e)) - .unexpect()? - { - yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri)); - } - - Ok(redirect_uri) - } else { - let redirect_uris = db::get_client_redirect_uris(db, client_id) - .await - .map_err(|e| UnexpectedError::from(e)) - .unexpect()?; - if redirect_uris.len() != 1 { - yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri)); - } - - Ok(redirect_uris.get(0).unwrap().clone()) - } -} - -async fn get_scope( - scope: &Option<Box<str>>, - db: &MySqlPool, - client_id: Uuid, - redirect_uri: &Url, - state: &Option<Box<str>>, -) -> Result<Box<str>, Expect<AuthorizeError>> { - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into()) - }; - scope - }; - - // verify scope is valid - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - if !scopes::is_subset_of(&scope, &allowed_scopes) { - yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into()); - } - - Ok(scope) -} - -async fn authenticate_user( - db: &MySqlPool, - username: &str, - password: &str, -) -> Result<Option<Uuid>, RawUnexpected> { - let Some(user) = db::get_user_by_username(db, username).await? else { - return Ok(None); - }; - - if user.check_password(password)? { - Ok(Some(user.id)) - } else { - Ok(None) - } -} - -#[post("/authorize")] -async fn authorize( - db: web::Data<MySqlPool>, - req: web::Query<AuthorizationParameters>, - credentials: web::Json<AuthorizeCredentials>, - tera: web::Data<Tera>, - translations: web::Data<languages::Translations>, -) -> Result<HttpResponse, AuthorizeError> { - // TODO protect against brute force attacks - let db = db.get_ref(); - let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else { - let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); - }; - let Some(client_id) = client_id else { - let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::NotFound().content_type("text/html").body(page)); - }; - let Ok(config) = config::get_config() else { - let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page)); - }; - - let self_id = config.url; - let state = req.state.clone(); - - // get redirect uri - let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await { - Ok(uri) => uri, - Err(e) => { - let e = e - .expected() - .unwrap_or(templates::ErrorPage::InternalServerError); - let page = error_page(&tera, &translations, e) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - } - }; - - // authenticate user - let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password) - .await - .unwrap() else - { - let language = Language::from_str("en").unwrap(); - let translations = translations.get_ref().clone(); - let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::Ok().content_type("text/html").body(page)); - }; - - let internal_server_error = - AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); - - // get scope - let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await { - Ok(scope) => scope, - Err(e) => { - let e = e.expected().unwrap_or(internal_server_error); - return Err(e); - } - }; - - match req.response_type { - ResponseType::Code => { - // create auth code - let code = - jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri) - .await - .map_err(|_| internal_server_error.clone())?; - let code = code.to_jwt().map_err(|_| internal_server_error.clone())?; - - let response = AuthCodeResponse { code, state }; - let query = - Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?); - let query = query.as_deref(); - redirect_uri.set_query(query); - - Ok(HttpResponse::Found() - .append_header((header::LOCATION, redirect_uri.as_str())) - .finish()) - } - ResponseType::Token => { - // create access token - let duration = Duration::hours(1); - let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) - .await - .map_err(|_| internal_server_error.clone())?; - - let access_token = access_token - .to_jwt() - .map_err(|_| internal_server_error.clone())?; - let expires_in = duration.num_seconds(); - let token_type = "bearer"; - let response = AuthTokenResponse { - access_token, - expires_in, - token_type, - scope, - state, - }; - - let fragment = Some( - serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?, - ); - let fragment = fragment.as_deref(); - redirect_uri.set_fragment(fragment); - - Ok(HttpResponse::Found() - .append_header((header::LOCATION, redirect_uri.as_str())) - .finish()) - } - _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)), - } -} - -#[get("/authorize")] -async fn authorize_page( - db: web::Data<MySqlPool>, - tera: web::Data<Tera>, - translations: web::Data<languages::Translations>, - request: HttpRequest, -) -> Result<HttpResponse, AuthorizeError> { - let Ok(language) = Language::from_str("en") else { - let page = String::from(REALLY_BAD_ERROR_PAGE); - return Ok(HttpResponse::InternalServerError() - .content_type("text/html") - .body(page)); - }; - let translations = translations.get_ref().clone(); - - let params = request.query_string(); - let params = serde_urlencoded::from_str::<AuthorizationParameters>(params); - let Ok(params) = params else { - let page = error_page( - &tera, - &translations, - templates::ErrorPage::InvalidRequest, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - }; - - let db = db.get_ref(); - let Ok(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await else { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::InternalServerError, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::InternalServerError() - .content_type("text/html") - .body(page)); - }; - let Some(client_id) = client_id else { - let page = templates::error_page( - &tera, - language, - translations, - templates::ErrorPage::ClientNotFound, - ) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::NotFound() - .content_type("text/html") - .body(page)); - }; - - // verify redirect uri - let redirect_uri = match get_redirect_uri(¶ms.redirect_uri, db, client_id).await { - Ok(uri) => uri, - Err(e) => { - let e = e - .expected() - .unwrap_or(templates::ErrorPage::InternalServerError); - let page = error_page(&tera, &translations, e) - .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE)); - return Ok(HttpResponse::BadRequest() - .content_type("text/html") - .body(page)); - } - }; - - let state = ¶ms.state; - let internal_server_error = - AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone()); - - // verify scope - let _ = match get_scope(¶ms.scope, db, client_id, &redirect_uri, ¶ms.state).await { - Ok(scope) => scope, - Err(e) => { - let e = e.expected().unwrap_or(internal_server_error); - return Err(e); - } - }; - - // verify response type - if params.response_type == ResponseType::Unsupported { - return Err(AuthorizeError::unsupported_response_type( - redirect_uri, - params.state, - )); - } - - // TODO find a better way of doing languages - let language = Language::from_str("en").unwrap(); - let page = templates::login_page(&tera, ¶ms, language, translations).unwrap(); - Ok(HttpResponse::Ok().content_type("text/html").body(page)) -} - -#[derive(Clone, Deserialize)] -#[serde(tag = "grant_type")] -#[serde(rename_all = "snake_case")] -enum GrantType { - AuthorizationCode { - code: Box<str>, - redirect_uri: Url, - #[serde(rename = "client_id")] - client_alias: Box<str>, - }, - Password { - username: Box<str>, - password: Box<str>, - scope: Option<Box<str>>, - }, - ClientCredentials { - scope: Option<Box<str>>, - }, - RefreshToken { - refresh_token: Box<str>, - scope: Option<Box<str>>, - }, - #[serde(other)] - Unsupported, -} - -#[derive(Clone, Deserialize)] -struct TokenRequest { - #[serde(flatten)] - grant_type: GrantType, - // TODO support optional client credentials in here -} - -#[derive(Clone, Serialize)] -struct TokenResponse { - access_token: Box<str>, - token_type: Box<str>, - expires_in: i64, - refresh_token: Option<Box<str>>, - scope: Box<str>, -} - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "snake_case")] -enum TokenErrorType { - InvalidRequest, - InvalidClient, - InvalidGrant, - UnauthorizedClient, - UnsupportedGrantType, - InvalidScope, -} - -#[derive(Debug, Clone, Error, Serialize)] -#[error("{error_description}")] -struct TokenError { - #[serde(skip)] - status_code: StatusCode, - error: TokenErrorType, - error_description: Box<str>, - // TODO error uri -} - -impl TokenError { - fn invalid_request() -> Self { - // TODO make this description better, and all the other ones while you're at it - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidRequest, - error_description: "Invalid request".into(), - } - } - - fn unsupported_grant_type() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::UnsupportedGrantType, - error_description: "The given grant type is not supported".into(), - } - } - - fn bad_auth_code(error: VerifyJwtError) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidGrant, - error_description: error.to_string().into_boxed_str(), - } - } - - fn no_authorization() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: Box::from( - "Client credentials must be provided in the HTTP Authorization header", - ), - } - } - - fn client_not_found(alias: &str) -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: format!("No client with the client id: {alias} was found") - .into_boxed_str(), - } - } - - fn mismatch_client_id() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"), - } - } - - fn incorrect_client_secret() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: "The client secret is incorrect".into(), - } - } - - fn client_not_confidential(alias: &str) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::UnauthorizedClient, - error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.") - .into_boxed_str(), - } - } - - fn no_scope() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidScope, - error_description: Box::from( - "No scope was provided, and the client doesn't have a default scope", - ), - } - } - - fn excessive_scope() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidScope, - error_description: Box::from( - "The given scope exceeds what the client is allowed to have", - ), - } - } - - fn bad_refresh_token(err: VerifyJwtError) -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidGrant, - error_description: err.to_string().into_boxed_str(), - } - } - - fn untrusted_client() -> Self { - Self { - status_code: StatusCode::UNAUTHORIZED, - error: TokenErrorType::InvalidClient, - error_description: "Only trusted clients may use this grant".into(), - } - } - - fn incorrect_user_credentials() -> Self { - Self { - status_code: StatusCode::BAD_REQUEST, - error: TokenErrorType::InvalidRequest, - error_description: "The given credentials are incorrect".into(), - } - } -} - -impl ResponseError for TokenError { - fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> { - let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); - - let mut builder = HttpResponseBuilder::new(self.status_code); - - if self.status_code.as_u16() == 401 { - builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\"")); - } - - builder - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(self.clone()) - } -} - -#[post("/token")] -async fn token( - db: web::Data<MySqlPool>, - req: web::Bytes, - authorization: Option<web::Header<authorization::BasicAuthorization>>, -) -> HttpResponse { - // TODO protect against brute force attacks - let db = db.get_ref(); - let request = serde_json::from_slice::<TokenRequest>(&req); - let Ok(request) = request else { - return TokenError::invalid_request().error_response(); - }; - let config = config::get_config().unwrap(); - - let self_id = config.url; - let duration = Duration::hours(1); - let token_type = Box::from("bearer"); - let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]); - - match request.grant_type { - GrantType::AuthorizationCode { - code, - redirect_uri, - client_alias, - } => { - let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else { - return TokenError::client_not_found(&client_alias).error_response(); - }; - - // validate auth code - let claims = - match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await { - Ok(claims) => claims, - Err(err) => { - let err = err.unwrap(); - return TokenError::bad_auth_code(err).error_response(); - } - }; - - // verify client, if the client has credentials - if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - - if authorization.username() != client_alias.deref() { - return TokenError::mismatch_client_id().error_response(); - } - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - } - - let access_token = jwt::Claims::access_token( - db, - Some(claims.id()), - self_id, - client_id, - claims.subject(), - duration, - claims.scopes(), - ) - .await - .unwrap(); - - let expires_in = access_token.expires_in(); - let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); - let scope = access_token.scopes().into(); - - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::Password { - username, - password, - scope, - } => { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - let client_alias = authorization.username(); - let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - - let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap(); - if !trusted { - return TokenError::untrusted_client().error_response(); - } - - // verify client - let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - - // authenticate user - let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else { - return TokenError::incorrect_user_credentials().error_response(); - }; - - // verify scope - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return TokenError::no_scope().error_response(); - }; - scope - }; - if !scopes::is_subset_of(&scope, &allowed_scopes) { - return TokenError::excessive_scope().error_response(); - } - - let access_token = - jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope) - .await - .unwrap(); - let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap(); - - let expires_in = access_token.expires_in(); - let scope = access_token.scopes().into(); - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::ClientCredentials { scope } => { - let Some(authorization) = authorization else { - return TokenError::no_authorization().error_response(); - }; - let client_alias = authorization.username(); - let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - - let ty = db::get_client_type(db, client_id).await.unwrap().unwrap(); - if ty != ClientType::Confidential { - return TokenError::client_not_confidential(client_alias).error_response(); - } - - // verify client - let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap(); - if !hash.check_password(authorization.password()).unwrap() { - return TokenError::incorrect_client_secret().error_response(); - } - - // verify scope - let allowed_scopes = db::get_client_allowed_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let scope = if let Some(scope) = &scope { - scope.clone() - } else { - let default_scopes = db::get_client_default_scopes(db, client_id) - .await - .unwrap() - .unwrap(); - let Some(scope) = default_scopes else { - return TokenError::no_scope().error_response(); - }; - scope - }; - if !scopes::is_subset_of(&scope, &allowed_scopes) { - return TokenError::excessive_scope().error_response(); - } - - let access_token = jwt::Claims::access_token( - db, None, self_id, client_id, client_id, duration, &scope, - ) - .await - .unwrap(); - - let expires_in = access_token.expires_in(); - let scope = access_token.scopes().into(); - let access_token = access_token.to_jwt().unwrap(); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token: None, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - GrantType::RefreshToken { - refresh_token, - scope, - } => { - let client_id: Option<Uuid>; - if let Some(authorization) = authorization { - let client_alias = authorization.username(); - let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else { - return TokenError::client_not_found(client_alias).error_response(); - }; - client_id = Some(id); - } else { - client_id = None; - } - - let claims = - match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await { - Ok(claims) => claims, - Err(e) => { - let e = e.unwrap(); - return TokenError::bad_refresh_token(e).error_response(); - } - }; - - let scope = if let Some(scope) = scope { - if !scopes::is_subset_of(&scope, claims.scopes()) { - return TokenError::excessive_scope().error_response(); - } - - scope - } else { - claims.scopes().into() - }; - - let exp_time = Duration::hours(1); - let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time) - .await - .unwrap(); - let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap(); - - let access_token = access_token.to_jwt().unwrap(); - let refresh_token = Some(refresh_token.to_jwt().unwrap()); - let expires_in = exp_time.num_seconds(); - - let response = TokenResponse { - access_token, - token_type, - expires_in, - refresh_token, - scope, - }; - HttpResponse::Ok() - .insert_header(cache_control) - .insert_header((header::PRAGMA, "no-cache")) - .json(response) - } - _ => TokenError::unsupported_grant_type().error_response(), - } -} - -pub fn service() -> Scope { - web::scope("/oauth") - .service(authorize_page) - .service(authorize) - .service(token) -} +use std::ops::Deref;
+use std::str::FromStr;
+
+use actix_web::http::{header, StatusCode};
+use actix_web::{
+ get, post, web, HttpRequest, HttpResponse, HttpResponseBuilder, ResponseError, Scope,
+};
+use chrono::Duration;
+use exun::{Expect, RawUnexpected, ResultErrorExt, UnexpectedError};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+use url::Url;
+use uuid::Uuid;
+
+use crate::models::client::ClientType;
+use crate::resources::{languages, templates};
+use crate::scopes;
+use crate::services::jwt::VerifyJwtError;
+use crate::services::{authorization, config, db, jwt};
+
+const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+enum ResponseType {
+ Code,
+ Token,
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct AuthorizationParameters {
+ response_type: ResponseType,
+ client_id: Box<str>,
+ redirect_uri: Option<Url>,
+ scope: Option<Box<str>>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Deserialize)]
+struct AuthorizeCredentials {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthCodeResponse {
+ code: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Clone, Serialize)]
+struct AuthTokenResponse {
+ access_token: Box<str>,
+ token_type: &'static str,
+ expires_in: i64,
+ scope: Box<str>,
+ state: Option<Box<str>>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+enum AuthorizeErrorType {
+ InvalidRequest,
+ UnauthorizedClient,
+ AccessDenied,
+ UnsupportedResponseType,
+ InvalidScope,
+ ServerError,
+ TemporarilyUnavailable,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct AuthorizeError {
+ error: AuthorizeErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+ state: Option<Box<str>>,
+ #[serde(skip)]
+ redirect_uri: Url,
+}
+
+impl AuthorizeError {
+ fn no_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client does not have a default scope",
+ ),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn unsupported_response_type(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::UnsupportedResponseType,
+ error_description: Box::from("The given response type is not supported"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn invalid_scope(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::InvalidScope,
+ error_description: Box::from("The given scope exceeds what the client is allowed"),
+ state,
+ redirect_uri,
+ }
+ }
+
+ fn internal_server_error(redirect_uri: Url, state: Option<Box<str>>) -> Self {
+ Self {
+ error: AuthorizeErrorType::ServerError,
+ error_description: "An unexpected error occurred".into(),
+ state,
+ redirect_uri,
+ }
+ }
+}
+
+impl ResponseError for AuthorizeError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let query = Some(serde_urlencoded::to_string(self).unwrap());
+ let query = query.as_deref();
+ let mut url = self.redirect_uri.clone();
+ url.set_query(query);
+
+ HttpResponse::Found()
+ .insert_header((header::LOCATION, url.as_str()))
+ .finish()
+ }
+}
+
+fn error_page(
+ tera: &Tera,
+ translations: &languages::Translations,
+ error: templates::ErrorPage,
+) -> Result<String, RawUnexpected> {
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.clone();
+ let page = templates::error_page(&tera, language, translations, error)?;
+ Ok(page)
+}
+
+async fn get_redirect_uri(
+ redirect_uri: &Option<Url>,
+ db: &MySqlPool,
+ client_id: Uuid,
+) -> Result<Url, Expect<templates::ErrorPage>> {
+ if let Some(uri) = &redirect_uri {
+ let redirect_uri = uri.clone();
+ if !db::client_has_redirect_uri(db, client_id, &redirect_uri)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?
+ {
+ yeet!(Expect::Expected(templates::ErrorPage::InvalidRedirectUri));
+ }
+
+ Ok(redirect_uri)
+ } else {
+ let redirect_uris = db::get_client_redirect_uris(db, client_id)
+ .await
+ .map_err(|e| UnexpectedError::from(e))
+ .unexpect()?;
+ if redirect_uris.len() != 1 {
+ yeet!(Expect::Expected(templates::ErrorPage::MissingRedirectUri));
+ }
+
+ Ok(redirect_uris.get(0).unwrap().clone())
+ }
+}
+
+async fn get_scope(
+ scope: &Option<Box<str>>,
+ db: &MySqlPool,
+ client_id: Uuid,
+ redirect_uri: &Url,
+ state: &Option<Box<str>>,
+) -> Result<Box<str>, Expect<AuthorizeError>> {
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ yeet!(AuthorizeError::no_scope(redirect_uri.clone(), state.clone()).into())
+ };
+ scope
+ };
+
+ // verify scope is valid
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ yeet!(AuthorizeError::invalid_scope(redirect_uri.clone(), state.clone()).into());
+ }
+
+ Ok(scope)
+}
+
+async fn authenticate_user(
+ db: &MySqlPool,
+ username: &str,
+ password: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ let Some(user) = db::get_user_by_username(db, username).await? else {
+ return Ok(None);
+ };
+
+ if user.check_password(password)? {
+ Ok(Some(user.id))
+ } else {
+ Ok(None)
+ }
+}
+
+#[post("/authorize")]
+async fn authorize(
+ db: web::Data<MySqlPool>,
+ req: web::Query<AuthorizationParameters>,
+ credentials: web::Json<AuthorizeCredentials>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+) -> Result<HttpResponse, AuthorizeError> {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
+ };
+ let Ok(config) = config::get_config() else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
+ };
+
+ let self_id = config.url;
+ let state = req.state.clone();
+
+ // get redirect uri
+ let mut redirect_uri = match get_redirect_uri(&req.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &credentials.username, &credentials.password)
+ .await
+ .unwrap() else
+ {
+ let language = Language::from_str("en").unwrap();
+ let translations = translations.get_ref().clone();
+ let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::Ok().content_type("text/html").body(page));
+ };
+
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // get scope
+ let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ match req.response_type {
+ ResponseType::Code => {
+ // create auth code
+ let code =
+ jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+ let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
+
+ let response = AuthCodeResponse { code, state };
+ let query =
+ Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
+ let query = query.as_deref();
+ redirect_uri.set_query(query);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ ResponseType::Token => {
+ // create access token
+ let duration = Duration::hours(1);
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .map_err(|_| internal_server_error.clone())?;
+
+ let access_token = access_token
+ .to_jwt()
+ .map_err(|_| internal_server_error.clone())?;
+ let expires_in = duration.num_seconds();
+ let token_type = "bearer";
+ let response = AuthTokenResponse {
+ access_token,
+ expires_in,
+ token_type,
+ scope,
+ state,
+ };
+
+ let fragment = Some(
+ serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
+ );
+ let fragment = fragment.as_deref();
+ redirect_uri.set_fragment(fragment);
+
+ Ok(HttpResponse::Found()
+ .append_header((header::LOCATION, redirect_uri.as_str()))
+ .finish())
+ }
+ _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
+ }
+}
+
+#[get("/authorize")]
+async fn authorize_page(
+ db: web::Data<MySqlPool>,
+ tera: web::Data<Tera>,
+ translations: web::Data<languages::Translations>,
+ request: HttpRequest,
+) -> Result<HttpResponse, AuthorizeError> {
+ let Ok(language) = Language::from_str("en") else {
+ let page = String::from(REALLY_BAD_ERROR_PAGE);
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let translations = translations.get_ref().clone();
+
+ let params = request.query_string();
+ let params = serde_urlencoded::from_str::<AuthorizationParameters>(params);
+ let Ok(params) = params else {
+ let page = error_page(
+ &tera,
+ &translations,
+ templates::ErrorPage::InvalidRequest,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ let db = db.get_ref();
+ let Ok(client_id) = db::get_client_id_by_alias(db, ¶ms.client_id).await else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::InternalServerError,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::InternalServerError()
+ .content_type("text/html")
+ .body(page));
+ };
+ let Some(client_id) = client_id else {
+ let page = templates::error_page(
+ &tera,
+ language,
+ translations,
+ templates::ErrorPage::ClientNotFound,
+ )
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::NotFound()
+ .content_type("text/html")
+ .body(page));
+ };
+
+ // verify redirect uri
+ let redirect_uri = match get_redirect_uri(¶ms.redirect_uri, db, client_id).await {
+ Ok(uri) => uri,
+ Err(e) => {
+ let e = e
+ .expected()
+ .unwrap_or(templates::ErrorPage::InternalServerError);
+ let page = error_page(&tera, &translations, e)
+ .unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::BadRequest()
+ .content_type("text/html")
+ .body(page));
+ }
+ };
+
+ let state = ¶ms.state;
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
+ // verify scope
+ let _ = match get_scope(¶ms.scope, db, client_id, &redirect_uri, ¶ms.state).await {
+ Ok(scope) => scope,
+ Err(e) => {
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
+ }
+ };
+
+ // verify response type
+ if params.response_type == ResponseType::Unsupported {
+ return Err(AuthorizeError::unsupported_response_type(
+ redirect_uri,
+ params.state,
+ ));
+ }
+
+ // TODO find a better way of doing languages
+ let language = Language::from_str("en").unwrap();
+ let page = templates::login_page(&tera, ¶ms, language, translations).unwrap();
+ Ok(HttpResponse::Ok().content_type("text/html").body(page))
+}
+
+#[derive(Clone, Deserialize)]
+#[serde(tag = "grant_type")]
+#[serde(rename_all = "snake_case")]
+enum GrantType {
+ AuthorizationCode {
+ code: Box<str>,
+ redirect_uri: Url,
+ #[serde(rename = "client_id")]
+ client_alias: Box<str>,
+ },
+ Password {
+ username: Box<str>,
+ password: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ ClientCredentials {
+ scope: Option<Box<str>>,
+ },
+ RefreshToken {
+ refresh_token: Box<str>,
+ scope: Option<Box<str>>,
+ },
+ #[serde(other)]
+ Unsupported,
+}
+
+#[derive(Clone, Deserialize)]
+struct TokenRequest {
+ #[serde(flatten)]
+ grant_type: GrantType,
+ // TODO support optional client credentials in here
+}
+
+#[derive(Clone, Serialize)]
+struct TokenResponse {
+ access_token: Box<str>,
+ token_type: Box<str>,
+ expires_in: i64,
+ refresh_token: Option<Box<str>>,
+ scope: Box<str>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "snake_case")]
+enum TokenErrorType {
+ InvalidRequest,
+ InvalidClient,
+ InvalidGrant,
+ UnauthorizedClient,
+ UnsupportedGrantType,
+ InvalidScope,
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+#[error("{error_description}")]
+struct TokenError {
+ #[serde(skip)]
+ status_code: StatusCode,
+ error: TokenErrorType,
+ error_description: Box<str>,
+ // TODO error uri
+}
+
+impl TokenError {
+ fn invalid_request() -> Self {
+ // TODO make this description better, and all the other ones while you're at it
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "Invalid request".into(),
+ }
+ }
+
+ fn unsupported_grant_type() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnsupportedGrantType,
+ error_description: "The given grant type is not supported".into(),
+ }
+ }
+
+ fn bad_auth_code(error: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: error.to_string().into_boxed_str(),
+ }
+ }
+
+ fn no_authorization() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from(
+ "Client credentials must be provided in the HTTP Authorization header",
+ ),
+ }
+ }
+
+ fn client_not_found(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: format!("No client with the client id: {alias} was found")
+ .into_boxed_str(),
+ }
+ }
+
+ fn mismatch_client_id() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: Box::from("The client ID in the Authorization header is not the same as the client ID in the request body"),
+ }
+ }
+
+ fn incorrect_client_secret() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "The client secret is incorrect".into(),
+ }
+ }
+
+ fn client_not_confidential(alias: &str) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::UnauthorizedClient,
+ error_description: format!("Only a confidential client may be used with this endpoint. The {alias} client is a public client.")
+ .into_boxed_str(),
+ }
+ }
+
+ fn no_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "No scope was provided, and the client doesn't have a default scope",
+ ),
+ }
+ }
+
+ fn excessive_scope() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidScope,
+ error_description: Box::from(
+ "The given scope exceeds what the client is allowed to have",
+ ),
+ }
+ }
+
+ fn bad_refresh_token(err: VerifyJwtError) -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidGrant,
+ error_description: err.to_string().into_boxed_str(),
+ }
+ }
+
+ fn untrusted_client() -> Self {
+ Self {
+ status_code: StatusCode::UNAUTHORIZED,
+ error: TokenErrorType::InvalidClient,
+ error_description: "Only trusted clients may use this grant".into(),
+ }
+ }
+
+ fn incorrect_user_credentials() -> Self {
+ Self {
+ status_code: StatusCode::BAD_REQUEST,
+ error: TokenErrorType::InvalidRequest,
+ error_description: "The given credentials are incorrect".into(),
+ }
+ }
+}
+
+impl ResponseError for TokenError {
+ fn error_response(&self) -> HttpResponse<actix_web::body::BoxBody> {
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ let mut builder = HttpResponseBuilder::new(self.status_code);
+
+ if self.status_code.as_u16() == 401 {
+ builder.insert_header((header::WWW_AUTHENTICATE, "Basic charset=\"UTF-8\""));
+ }
+
+ builder
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(self.clone())
+ }
+}
+
+#[post("/token")]
+async fn token(
+ db: web::Data<MySqlPool>,
+ req: web::Bytes,
+ authorization: Option<web::Header<authorization::BasicAuthorization>>,
+) -> HttpResponse {
+ // TODO protect against brute force attacks
+ let db = db.get_ref();
+ let request = serde_json::from_slice::<TokenRequest>(&req);
+ let Ok(request) = request else {
+ return TokenError::invalid_request().error_response();
+ };
+ let config = config::get_config().unwrap();
+
+ let self_id = config.url;
+ let duration = Duration::hours(1);
+ let token_type = Box::from("bearer");
+ let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
+
+ match request.grant_type {
+ GrantType::AuthorizationCode {
+ code,
+ redirect_uri,
+ client_alias,
+ } => {
+ let Some(client_id) = db::get_client_id_by_alias(db, &client_alias).await.unwrap() else {
+ return TokenError::client_not_found(&client_alias).error_response();
+ };
+
+ // validate auth code
+ let claims =
+ match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
+ Ok(claims) => claims,
+ Err(err) => {
+ let err = err.unwrap();
+ return TokenError::bad_auth_code(err).error_response();
+ }
+ };
+
+ // verify client, if the client has credentials
+ if let Some(hash) = db::get_client_secret(db, client_id).await.unwrap() {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+
+ if authorization.username() != client_alias.deref() {
+ return TokenError::mismatch_client_id().error_response();
+ }
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db,
+ Some(claims.id()),
+ self_id,
+ client_id,
+ claims.subject(),
+ duration,
+ claims.scopes(),
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+ let scope = access_token.scopes().into();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::Password {
+ username,
+ password,
+ scope,
+ } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let trusted = db::is_client_trusted(db, client_id).await.unwrap().unwrap();
+ if !trusted {
+ return TokenError::untrusted_client().error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // authenticate user
+ let Some(user_id) = authenticate_user(db, &username, &password).await.unwrap() else {
+ return TokenError::incorrect_user_credentials().error_response();
+ };
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token =
+ jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::ClientCredentials { scope } => {
+ let Some(authorization) = authorization else {
+ return TokenError::no_authorization().error_response();
+ };
+ let client_alias = authorization.username();
+ let Some(client_id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+
+ let ty = db::get_client_type(db, client_id).await.unwrap().unwrap();
+ if ty != ClientType::Confidential {
+ return TokenError::client_not_confidential(client_alias).error_response();
+ }
+
+ // verify client
+ let hash = db::get_client_secret(db, client_id).await.unwrap().unwrap();
+ if !hash.check_password(authorization.password()).unwrap() {
+ return TokenError::incorrect_client_secret().error_response();
+ }
+
+ // verify scope
+ let allowed_scopes = db::get_client_allowed_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let scope = if let Some(scope) = &scope {
+ scope.clone()
+ } else {
+ let default_scopes = db::get_client_default_scopes(db, client_id)
+ .await
+ .unwrap()
+ .unwrap();
+ let Some(scope) = default_scopes else {
+ return TokenError::no_scope().error_response();
+ };
+ scope
+ };
+ if !scopes::is_subset_of(&scope, &allowed_scopes) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ let access_token = jwt::Claims::access_token(
+ db, None, self_id, client_id, client_id, duration, &scope,
+ )
+ .await
+ .unwrap();
+
+ let expires_in = access_token.expires_in();
+ let scope = access_token.scopes().into();
+ let access_token = access_token.to_jwt().unwrap();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token: None,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ GrantType::RefreshToken {
+ refresh_token,
+ scope,
+ } => {
+ let client_id: Option<Uuid>;
+ if let Some(authorization) = authorization {
+ let client_alias = authorization.username();
+ let Some(id) = db::get_client_id_by_alias(db, client_alias).await.unwrap() else {
+ return TokenError::client_not_found(client_alias).error_response();
+ };
+ client_id = Some(id);
+ } else {
+ client_id = None;
+ }
+
+ let claims =
+ match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
+ Ok(claims) => claims,
+ Err(e) => {
+ let e = e.unwrap();
+ return TokenError::bad_refresh_token(e).error_response();
+ }
+ };
+
+ let scope = if let Some(scope) = scope {
+ if !scopes::is_subset_of(&scope, claims.scopes()) {
+ return TokenError::excessive_scope().error_response();
+ }
+
+ scope
+ } else {
+ claims.scopes().into()
+ };
+
+ let exp_time = Duration::hours(1);
+ let access_token = jwt::Claims::refreshed_access_token(db, &claims, exp_time)
+ .await
+ .unwrap();
+ let refresh_token = jwt::Claims::refresh_token(db, &claims).await.unwrap();
+
+ let access_token = access_token.to_jwt().unwrap();
+ let refresh_token = Some(refresh_token.to_jwt().unwrap());
+ let expires_in = exp_time.num_seconds();
+
+ let response = TokenResponse {
+ access_token,
+ token_type,
+ expires_in,
+ refresh_token,
+ scope,
+ };
+ HttpResponse::Ok()
+ .insert_header(cache_control)
+ .insert_header((header::PRAGMA, "no-cache"))
+ .json(response)
+ }
+ _ => TokenError::unsupported_grant_type().error_response(),
+ }
+}
+
+pub fn service() -> Scope {
+ web::scope("/oauth")
+ .service(authorize_page)
+ .service(authorize)
+ .service(token)
+}
diff --git a/src/api/ops.rs b/src/api/ops.rs index 555bb1b..2164f1f 100644 --- a/src/api/ops.rs +++ b/src/api/ops.rs @@ -1,70 +1,70 @@ -use std::str::FromStr; - -use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::Deserialize; -use sqlx::MySqlPool; -use tera::Tera; -use thiserror::Error; -use unic_langid::subtags::Language; - -use crate::resources::{languages, templates}; -use crate::services::db; - -/// A request to login -#[derive(Debug, Clone, Deserialize)] -struct LoginRequest { - username: Box<str>, - password: Box<str>, -} - -/// An error occurred when authenticating, because either the username or -/// password was invalid. -#[derive(Debug, Clone, Error)] -enum LoginFailure { - #[error("No user found with the given username")] - UserNotFound { username: Box<str> }, - #[error("The given password is incorrect")] - IncorrectPassword { username: Box<str> }, -} - -impl ResponseError for LoginFailure { - fn status_code(&self) -> actix_web::http::StatusCode { - match self { - Self::UserNotFound { .. } => StatusCode::NOT_FOUND, - Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED, - } - } -} - -/// Returns `200` if login was successful. -/// Returns `404` if the username is invalid. -/// Returns `401` if the password was invalid. -#[post("/login")] -async fn login( - body: web::Json<LoginRequest>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, LoginFailure> { - let conn = conn.get_ref(); - - let user = db::get_user_by_username(conn, &body.username) - .await - .unwrap(); - let Some(user) = user else { - yeet!(LoginFailure::UserNotFound{ username: body.username.clone() }); - }; - - let good_password = user.check_password(&body.password).unwrap(); - let response = if good_password { - HttpResponse::Ok().finish() - } else { - yeet!(LoginFailure::IncorrectPassword { - username: body.username.clone() - }); - }; - Ok(response) -} - -pub fn service() -> Scope { - web::scope("").service(login) -} +use std::str::FromStr;
+
+use actix_web::{get, http::StatusCode, post, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::Deserialize;
+use sqlx::MySqlPool;
+use tera::Tera;
+use thiserror::Error;
+use unic_langid::subtags::Language;
+
+use crate::resources::{languages, templates};
+use crate::services::db;
+
+/// A request to login
+#[derive(Debug, Clone, Deserialize)]
+struct LoginRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+/// An error occurred when authenticating, because either the username or
+/// password was invalid.
+#[derive(Debug, Clone, Error)]
+enum LoginFailure {
+ #[error("No user found with the given username")]
+ UserNotFound { username: Box<str> },
+ #[error("The given password is incorrect")]
+ IncorrectPassword { username: Box<str> },
+}
+
+impl ResponseError for LoginFailure {
+ fn status_code(&self) -> actix_web::http::StatusCode {
+ match self {
+ Self::UserNotFound { .. } => StatusCode::NOT_FOUND,
+ Self::IncorrectPassword { .. } => StatusCode::UNAUTHORIZED,
+ }
+ }
+}
+
+/// Returns `200` if login was successful.
+/// Returns `404` if the username is invalid.
+/// Returns `401` if the password was invalid.
+#[post("/login")]
+async fn login(
+ body: web::Json<LoginRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, LoginFailure> {
+ let conn = conn.get_ref();
+
+ let user = db::get_user_by_username(conn, &body.username)
+ .await
+ .unwrap();
+ let Some(user) = user else {
+ yeet!(LoginFailure::UserNotFound{ username: body.username.clone() });
+ };
+
+ let good_password = user.check_password(&body.password).unwrap();
+ let response = if good_password {
+ HttpResponse::Ok().finish()
+ } else {
+ yeet!(LoginFailure::IncorrectPassword {
+ username: body.username.clone()
+ });
+ };
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("").service(login)
+}
diff --git a/src/api/users.rs b/src/api/users.rs index 391a059..da2a0d0 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -1,272 +1,272 @@ -use actix_web::http::{header, StatusCode}; -use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::MySqlPool; -use thiserror::Error; -use uuid::Uuid; - -use crate::models::user::User; -use crate::services::crypto::PasswordHash; -use crate::services::{db, id}; - -/// Just a username. No password hash, because that'd be tempting fate. -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -struct UserResponse { - id: Uuid, - username: Box<str>, -} - -impl From<User> for UserResponse { - fn from(user: User) -> Self { - Self { - id: user.id, - username: user.username, - } - } -} - -#[derive(Debug, Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct SearchUsers { - username: Option<Box<str>>, - limit: Option<u32>, - offset: Option<u32>, -} - -#[get("")] -async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse { - let conn = conn.get_ref(); - - let username = params.username.clone().unwrap_or_default(); - let offset = params.offset.unwrap_or_default(); - - let results: Box<[UserResponse]> = if let Some(limit) = params.limit { - db::search_users_limit(conn, &username, offset, limit) - .await - .unwrap() - .iter() - .cloned() - .map(|u| u.into()) - .collect() - } else { - db::search_users(conn, &username) - .await - .unwrap() - .into_iter() - .skip(offset as usize) - .cloned() - .map(|u| u.into()) - .collect() - }; - - let response = HttpResponse::Ok().json(results); - response -} - -#[derive(Debug, Clone, Error)] -#[error("No user with the given ID exists")] -struct UserNotFoundError { - user_id: Uuid, -} - -impl ResponseError for UserNotFoundError { - fn status_code(&self) -> StatusCode { - StatusCode::NOT_FOUND - } -} - -#[get("/{user_id}")] -async fn get_user( - user_id: web::Path<Uuid>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UserNotFoundError> { - let conn = conn.get_ref(); - - let id = user_id.to_owned(); - let username = db::get_username(conn, id).await.unwrap(); - - let Some(username) = username else { - yeet!(UserNotFoundError { user_id: id }); - }; - - let response = UserResponse { id, username }; - let response = HttpResponse::Ok().json(response); - Ok(response) -} - -#[get("/{user_id}/username")] -async fn get_username( - user_id: web::Path<Uuid>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UserNotFoundError> { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = db::get_username(conn, user_id).await.unwrap(); - - let Some(username) = username else { - yeet!(UserNotFoundError { user_id }); - }; - - let response = HttpResponse::Ok().json(username); - Ok(response) -} - -/// A request to create or update user information -#[derive(Clone, Deserialize)] -#[serde(rename_all = "camelCase")] -struct UserRequest { - username: Box<str>, - password: Box<str>, -} - -#[derive(Debug, Clone, Error)] -#[error("An account with the given username already exists.")] -struct UsernameTakenError { - username: Box<str>, -} - -impl ResponseError for UsernameTakenError { - fn status_code(&self) -> StatusCode { - StatusCode::CONFLICT - } -} - -#[post("")] -async fn create_user( - body: web::Json<UserRequest>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UsernameTakenError> { - let conn = conn.get_ref(); - - let user_id = id::new_id(conn, db::user_id_exists).await.unwrap(); - let username = body.username.clone(); - let password = PasswordHash::new(&body.password).unwrap(); - - if db::username_is_used(conn, &body.username).await.unwrap() { - yeet!(UsernameTakenError { username }); - } - - let user = User { - id: user_id, - username, - password, - }; - - db::create_user(conn, &user).await.unwrap(); - - let response = HttpResponse::Created() - .insert_header((header::LOCATION, format!("users/{user_id}"))) - .finish(); - Ok(response) -} - -#[derive(Debug, Clone, Error)] -enum UpdateUserError { - #[error(transparent)] - UsernameTaken(#[from] UsernameTakenError), - #[error(transparent)] - NotFound(#[from] UserNotFoundError), -} - -impl ResponseError for UpdateUserError { - fn status_code(&self) -> StatusCode { - match self { - Self::UsernameTaken(e) => e.status_code(), - Self::NotFound(e) => e.status_code(), - } - } -} - -#[put("/{user_id}")] -async fn update_user( - user_id: web::Path<Uuid>, - body: web::Json<UserRequest>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateUserError> { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = body.username.clone(); - let password = PasswordHash::new(&body.password).unwrap(); - - let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); - if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() { - yeet!(UsernameTakenError { username }.into()) - } - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }.into()) - } - - let user = User { - id: user_id, - username, - password, - }; - - db::update_user(conn, &user).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{user_id}/username")] -async fn update_username( - user_id: web::Path<Uuid>, - body: web::Json<Box<str>>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UpdateUserError> { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let username = body.clone(); - - let old_username = db::get_username(conn, user_id).await.unwrap().unwrap(); - if username != old_username && db::username_is_used(conn, &body).await.unwrap() { - yeet!(UsernameTakenError { username }.into()) - } - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }.into()) - } - - db::update_username(conn, user_id, &body).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -#[put("/{user_id}/password")] -async fn update_password( - user_id: web::Path<Uuid>, - body: web::Json<Box<str>>, - conn: web::Data<MySqlPool>, -) -> Result<HttpResponse, UserNotFoundError> { - let conn = conn.get_ref(); - - let user_id = user_id.to_owned(); - let password = PasswordHash::new(&body).unwrap(); - - if !db::user_id_exists(conn, user_id).await.unwrap() { - yeet!(UserNotFoundError { user_id }) - } - - db::update_password(conn, user_id, &password).await.unwrap(); - - let response = HttpResponse::NoContent().finish(); - Ok(response) -} - -pub fn service() -> Scope { - web::scope("/users") - .service(search_users) - .service(get_user) - .service(get_username) - .service(create_user) - .service(update_user) - .service(update_username) - .service(update_password) -} +use actix_web::http::{header, StatusCode};
+use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::MySqlPool;
+use thiserror::Error;
+use uuid::Uuid;
+
+use crate::models::user::User;
+use crate::services::crypto::PasswordHash;
+use crate::services::{db, id};
+
+/// Just a username. No password hash, because that'd be tempting fate.
+#[derive(Debug, Clone, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct UserResponse {
+ id: Uuid,
+ username: Box<str>,
+}
+
+impl From<User> for UserResponse {
+ fn from(user: User) -> Self {
+ Self {
+ id: user.id,
+ username: user.username,
+ }
+ }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct SearchUsers {
+ username: Option<Box<str>>,
+ limit: Option<u32>,
+ offset: Option<u32>,
+}
+
+#[get("")]
+async fn search_users(params: web::Query<SearchUsers>, conn: web::Data<MySqlPool>) -> HttpResponse {
+ let conn = conn.get_ref();
+
+ let username = params.username.clone().unwrap_or_default();
+ let offset = params.offset.unwrap_or_default();
+
+ let results: Box<[UserResponse]> = if let Some(limit) = params.limit {
+ db::search_users_limit(conn, &username, offset, limit)
+ .await
+ .unwrap()
+ .iter()
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ } else {
+ db::search_users(conn, &username)
+ .await
+ .unwrap()
+ .into_iter()
+ .skip(offset as usize)
+ .cloned()
+ .map(|u| u.into())
+ .collect()
+ };
+
+ let response = HttpResponse::Ok().json(results);
+ response
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("No user with the given ID exists")]
+struct UserNotFoundError {
+ user_id: Uuid,
+}
+
+impl ResponseError for UserNotFoundError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::NOT_FOUND
+ }
+}
+
+#[get("/{user_id}")]
+async fn get_user(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let id = user_id.to_owned();
+ let username = db::get_username(conn, id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id: id });
+ };
+
+ let response = UserResponse { id, username };
+ let response = HttpResponse::Ok().json(response);
+ Ok(response)
+}
+
+#[get("/{user_id}/username")]
+async fn get_username(
+ user_id: web::Path<Uuid>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = db::get_username(conn, user_id).await.unwrap();
+
+ let Some(username) = username else {
+ yeet!(UserNotFoundError { user_id });
+ };
+
+ let response = HttpResponse::Ok().json(username);
+ Ok(response)
+}
+
+/// A request to create or update user information
+#[derive(Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+struct UserRequest {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("An account with the given username already exists.")]
+struct UsernameTakenError {
+ username: Box<str>,
+}
+
+impl ResponseError for UsernameTakenError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::CONFLICT
+ }
+}
+
+#[post("")]
+async fn create_user(
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UsernameTakenError> {
+ let conn = conn.get_ref();
+
+ let user_id = id::new_id(conn, db::user_id_exists).await.unwrap();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ if db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username });
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::create_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::Created()
+ .insert_header((header::LOCATION, format!("users/{user_id}")))
+ .finish();
+ Ok(response)
+}
+
+#[derive(Debug, Clone, Error)]
+enum UpdateUserError {
+ #[error(transparent)]
+ UsernameTaken(#[from] UsernameTakenError),
+ #[error(transparent)]
+ NotFound(#[from] UserNotFoundError),
+}
+
+impl ResponseError for UpdateUserError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::UsernameTaken(e) => e.status_code(),
+ Self::NotFound(e) => e.status_code(),
+ }
+ }
+}
+
+#[put("/{user_id}")]
+async fn update_user(
+ user_id: web::Path<Uuid>,
+ body: web::Json<UserRequest>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.username.clone();
+ let password = PasswordHash::new(&body.password).unwrap();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body.username).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ let user = User {
+ id: user_id,
+ username,
+ password,
+ };
+
+ db::update_user(conn, &user).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/username")]
+async fn update_username(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UpdateUserError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let username = body.clone();
+
+ let old_username = db::get_username(conn, user_id).await.unwrap().unwrap();
+ if username != old_username && db::username_is_used(conn, &body).await.unwrap() {
+ yeet!(UsernameTakenError { username }.into())
+ }
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id }.into())
+ }
+
+ db::update_username(conn, user_id, &body).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+#[put("/{user_id}/password")]
+async fn update_password(
+ user_id: web::Path<Uuid>,
+ body: web::Json<Box<str>>,
+ conn: web::Data<MySqlPool>,
+) -> Result<HttpResponse, UserNotFoundError> {
+ let conn = conn.get_ref();
+
+ let user_id = user_id.to_owned();
+ let password = PasswordHash::new(&body).unwrap();
+
+ if !db::user_id_exists(conn, user_id).await.unwrap() {
+ yeet!(UserNotFoundError { user_id })
+ }
+
+ db::update_password(conn, user_id, &password).await.unwrap();
+
+ let response = HttpResponse::NoContent().finish();
+ Ok(response)
+}
+
+pub fn service() -> Scope {
+ web::scope("/users")
+ .service(search_users)
+ .service(get_user)
+ .service(get_username)
+ .service(create_user)
+ .service(update_user)
+ .service(update_username)
+ .service(update_password)
+}
diff --git a/src/main.rs b/src/main.rs index e946161..e403798 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,108 +1,108 @@ -use std::time::Duration; - -use actix_web::http::header::{self, HeaderValue}; -use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, NormalizePath}; -use actix_web::web::Data; -use actix_web::{dev, App, HttpServer}; - -use bpaf::Bpaf; -use exun::*; - -mod api; -mod models; -mod resources; -mod scopes; -mod services; - -use resources::*; -use services::*; -use sqlx::MySqlPool; - -fn error_content_language<B>( - mut res: dev::ServiceResponse, -) -> actix_web::Result<ErrorHandlerResponse<B>> { - res.response_mut() - .headers_mut() - .insert(header::CONTENT_LANGUAGE, HeaderValue::from_static("en")); - - Ok(ErrorHandlerResponse::Response(res.map_into_right_body())) -} - -async fn delete_expired_tokens(db: MySqlPool) { - let db = db.clone(); - let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20)); - loop { - interval.tick().await; - if let Err(e) = db::delete_expired_auth_codes(&db).await { - log::error!("{}", e); - } - if let Err(e) = db::delete_expired_access_tokens(&db).await { - log::error!("{}", e); - } - if let Err(e) = db::delete_expired_refresh_tokens(&db).await { - log::error!("{}", e); - } - } -} - -#[derive(Debug, Clone, Bpaf)] -#[bpaf(options, version)] -struct Opts { - /// The environment that the server is running in. Must be one of: local, - /// dev, staging, prod. - #[bpaf( - env("LOCKDAGGER_ENVIRONMENT"), - fallback(config::Environment::Local), - display_fallback - )] - env: config::Environment, -} - -#[actix_web::main] -async fn main() -> Result<(), RawUnexpected> { - // load the environment file, but only in debug mode - #[cfg(debug_assertions)] - dotenv::dotenv()?; - - let args = opts().run(); - config::set_environment(args.env); - - // initialize the database - let db_url = secrets::database_url()?; - let sql_pool = db::initialize(&db_url).await?; - - let tera = templates::initialize()?; - - let translations = languages::initialize()?; - - actix_rt::spawn(delete_expired_tokens(sql_pool.clone())); - - // start the server - HttpServer::new(move || { - App::new() - // middleware - .wrap(ErrorHandlers::new().default_handler(error_content_language)) - .wrap(NormalizePath::trim()) - .wrap(Logger::new("\"%r\" %s %Dms")) - // app shared state - .app_data(Data::new(sql_pool.clone())) - .app_data(Data::new(tera.clone())) - .app_data(Data::new(translations.clone())) - // frontend services - .service(style::get_css) - .service(scripts::get_js) - .service(languages::languages()) - // api services - .service(api::liveops()) - .service(api::users()) - .service(api::clients()) - .service(api::oauth()) - .service(api::ops()) - }) - .shutdown_timeout(1) - .bind(("127.0.0.1", 8080))? - .run() - .await?; - - Ok(()) -} +use std::time::Duration;
+
+use actix_web::http::header::{self, HeaderValue};
+use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, NormalizePath};
+use actix_web::web::Data;
+use actix_web::{dev, App, HttpServer};
+
+use bpaf::Bpaf;
+use exun::*;
+
+mod api;
+mod models;
+mod resources;
+mod scopes;
+mod services;
+
+use resources::*;
+use services::*;
+use sqlx::MySqlPool;
+
+fn error_content_language<B>(
+ mut res: dev::ServiceResponse,
+) -> actix_web::Result<ErrorHandlerResponse<B>> {
+ res.response_mut()
+ .headers_mut()
+ .insert(header::CONTENT_LANGUAGE, HeaderValue::from_static("en"));
+
+ Ok(ErrorHandlerResponse::Response(res.map_into_right_body()))
+}
+
+async fn delete_expired_tokens(db: MySqlPool) {
+ let db = db.clone();
+ let mut interval = actix_rt::time::interval(Duration::from_secs(60 * 20));
+ loop {
+ interval.tick().await;
+ if let Err(e) = db::delete_expired_auth_codes(&db).await {
+ log::error!("{}", e);
+ }
+ if let Err(e) = db::delete_expired_access_tokens(&db).await {
+ log::error!("{}", e);
+ }
+ if let Err(e) = db::delete_expired_refresh_tokens(&db).await {
+ log::error!("{}", e);
+ }
+ }
+}
+
+#[derive(Debug, Clone, Bpaf)]
+#[bpaf(options, version)]
+struct Opts {
+ /// The environment that the server is running in. Must be one of: local,
+ /// dev, staging, prod.
+ #[bpaf(
+ env("LOCKDAGGER_ENVIRONMENT"),
+ fallback(config::Environment::Local),
+ display_fallback
+ )]
+ env: config::Environment,
+}
+
+#[actix_web::main]
+async fn main() -> Result<(), RawUnexpected> {
+ // load the environment file, but only in debug mode
+ #[cfg(debug_assertions)]
+ dotenv::dotenv()?;
+
+ let args = opts().run();
+ config::set_environment(args.env);
+
+ // initialize the database
+ let db_url = secrets::database_url()?;
+ let sql_pool = db::initialize(&db_url).await?;
+
+ let tera = templates::initialize()?;
+
+ let translations = languages::initialize()?;
+
+ actix_rt::spawn(delete_expired_tokens(sql_pool.clone()));
+
+ // start the server
+ HttpServer::new(move || {
+ App::new()
+ // middleware
+ .wrap(ErrorHandlers::new().default_handler(error_content_language))
+ .wrap(NormalizePath::trim())
+ .wrap(Logger::new("\"%r\" %s %Dms"))
+ // app shared state
+ .app_data(Data::new(sql_pool.clone()))
+ .app_data(Data::new(tera.clone()))
+ .app_data(Data::new(translations.clone()))
+ // frontend services
+ .service(style::get_css)
+ .service(scripts::get_js)
+ .service(languages::languages())
+ // api services
+ .service(api::liveops())
+ .service(api::users())
+ .service(api::clients())
+ .service(api::oauth())
+ .service(api::ops())
+ })
+ .shutdown_timeout(1)
+ .bind(("127.0.0.1", 8080))?
+ .run()
+ .await?;
+
+ Ok(())
+}
diff --git a/src/models/client.rs b/src/models/client.rs index 38be37f..6d0c909 100644 --- a/src/models/client.rs +++ b/src/models/client.rs @@ -1,165 +1,165 @@ -use std::{hash::Hash, marker::PhantomData}; - -use actix_web::{http::StatusCode, ResponseError}; -use exun::{Expect, RawUnexpected}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use url::Url; -use uuid::Uuid; - -use crate::services::crypto::PasswordHash; - -/// There are two types of clients, based on their ability to maintain the -/// security of their client credentials. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] -#[sqlx(rename_all = "lowercase")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum ClientType { - /// A client that is capable of maintaining the confidentiality of their - /// credentials, or capable of secure client authentication using other - /// means. An example would be a secure server with restricted access to - /// the client credentials. - Confidential, - /// A client that is incapable of maintaining the confidentiality of their - /// credentials and cannot authenticate securely by any other means, such - /// as an installed application, or a web-browser based application. - Public, -} - -#[derive(Debug, Clone)] -pub struct Client { - id: Uuid, - ty: ClientType, - alias: Box<str>, - secret: Option<PasswordHash>, - allowed_scopes: Box<[Box<str>]>, - default_scopes: Option<Box<[Box<str>]>>, - redirect_uris: Box<[Url]>, - trusted: bool, -} - -impl PartialEq for Client { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} - -impl Eq for Client {} - -impl Hash for Client { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - state.write_u128(self.id.as_u128()) - } -} - -#[derive(Debug, Clone, Copy, Error)] -#[error("Confidential clients must have a secret, but it was not provided")] -pub enum CreateClientError { - #[error("Confidential clients must have a secret, but it was not provided")] - NoSecret, - #[error("Only confidential clients may be trusted")] - TrustedError, - #[error("Redirect URIs must not include a fragment component")] - UriFragment, - #[error("Redirect URIs must use HTTPS")] - NonHttpsUri, -} - -impl ResponseError for CreateClientError { - fn status_code(&self) -> StatusCode { - StatusCode::BAD_REQUEST - } -} - -impl Client { - pub fn new( - id: Uuid, - alias: &str, - ty: ClientType, - secret: Option<&str>, - allowed_scopes: Box<[Box<str>]>, - default_scopes: Option<Box<[Box<str>]>>, - redirect_uris: &[Url], - trusted: bool, - ) -> Result<Self, Expect<CreateClientError>> { - let secret = if let Some(secret) = secret { - Some(PasswordHash::new(secret)?) - } else { - None - }; - - if ty == ClientType::Confidential && secret.is_none() { - yeet!(CreateClientError::NoSecret.into()); - } - - if ty == ClientType::Public && trusted { - yeet!(CreateClientError::TrustedError.into()); - } - - for redirect_uri in redirect_uris { - if redirect_uri.scheme() != "https" { - yeet!(CreateClientError::NonHttpsUri.into()) - } - - if redirect_uri.fragment().is_some() { - yeet!(CreateClientError::UriFragment.into()) - } - } - - Ok(Self { - id, - alias: Box::from(alias), - ty, - secret, - allowed_scopes, - default_scopes, - redirect_uris: redirect_uris.into_iter().cloned().collect(), - trusted, - }) - } - - pub fn id(&self) -> Uuid { - self.id - } - - pub fn alias(&self) -> &str { - &self.alias - } - - pub fn client_type(&self) -> ClientType { - self.ty - } - - pub fn redirect_uris(&self) -> &[Url] { - &self.redirect_uris - } - - pub fn secret_hash(&self) -> Option<&[u8]> { - self.secret.as_ref().map(|s| s.hash()) - } - - pub fn secret_salt(&self) -> Option<&[u8]> { - self.secret.as_ref().map(|s| s.salt()) - } - - pub fn secret_version(&self) -> Option<u8> { - self.secret.as_ref().map(|s| s.version()) - } - - pub fn allowed_scopes(&self) -> String { - self.allowed_scopes.join(" ") - } - - pub fn default_scopes(&self) -> Option<String> { - self.default_scopes.clone().map(|s| s.join(" ")) - } - - pub fn is_trusted(&self) -> bool { - self.trusted - } - - pub fn check_secret(&self, secret: &str) -> Option<Result<bool, RawUnexpected>> { - self.secret.as_ref().map(|s| s.check_password(secret)) - } -} +use std::{hash::Hash, marker::PhantomData};
+
+use actix_web::{http::StatusCode, ResponseError};
+use exun::{Expect, RawUnexpected};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use crate::services::crypto::PasswordHash;
+
+/// There are two types of clients, based on their ability to maintain the
+/// security of their client credentials.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)]
+#[sqlx(rename_all = "lowercase")]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum ClientType {
+ /// A client that is capable of maintaining the confidentiality of their
+ /// credentials, or capable of secure client authentication using other
+ /// means. An example would be a secure server with restricted access to
+ /// the client credentials.
+ Confidential,
+ /// A client that is incapable of maintaining the confidentiality of their
+ /// credentials and cannot authenticate securely by any other means, such
+ /// as an installed application, or a web-browser based application.
+ Public,
+}
+
+#[derive(Debug, Clone)]
+pub struct Client {
+ id: Uuid,
+ ty: ClientType,
+ alias: Box<str>,
+ secret: Option<PasswordHash>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ redirect_uris: Box<[Url]>,
+ trusted: bool,
+}
+
+impl PartialEq for Client {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id
+ }
+}
+
+impl Eq for Client {}
+
+impl Hash for Client {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write_u128(self.id.as_u128())
+ }
+}
+
+#[derive(Debug, Clone, Copy, Error)]
+#[error("Confidential clients must have a secret, but it was not provided")]
+pub enum CreateClientError {
+ #[error("Confidential clients must have a secret, but it was not provided")]
+ NoSecret,
+ #[error("Only confidential clients may be trusted")]
+ TrustedError,
+ #[error("Redirect URIs must not include a fragment component")]
+ UriFragment,
+ #[error("Redirect URIs must use HTTPS")]
+ NonHttpsUri,
+}
+
+impl ResponseError for CreateClientError {
+ fn status_code(&self) -> StatusCode {
+ StatusCode::BAD_REQUEST
+ }
+}
+
+impl Client {
+ pub fn new(
+ id: Uuid,
+ alias: &str,
+ ty: ClientType,
+ secret: Option<&str>,
+ allowed_scopes: Box<[Box<str>]>,
+ default_scopes: Option<Box<[Box<str>]>>,
+ redirect_uris: &[Url],
+ trusted: bool,
+ ) -> Result<Self, Expect<CreateClientError>> {
+ let secret = if let Some(secret) = secret {
+ Some(PasswordHash::new(secret)?)
+ } else {
+ None
+ };
+
+ if ty == ClientType::Confidential && secret.is_none() {
+ yeet!(CreateClientError::NoSecret.into());
+ }
+
+ if ty == ClientType::Public && trusted {
+ yeet!(CreateClientError::TrustedError.into());
+ }
+
+ for redirect_uri in redirect_uris {
+ if redirect_uri.scheme() != "https" {
+ yeet!(CreateClientError::NonHttpsUri.into())
+ }
+
+ if redirect_uri.fragment().is_some() {
+ yeet!(CreateClientError::UriFragment.into())
+ }
+ }
+
+ Ok(Self {
+ id,
+ alias: Box::from(alias),
+ ty,
+ secret,
+ allowed_scopes,
+ default_scopes,
+ redirect_uris: redirect_uris.into_iter().cloned().collect(),
+ trusted,
+ })
+ }
+
+ pub fn id(&self) -> Uuid {
+ self.id
+ }
+
+ pub fn alias(&self) -> &str {
+ &self.alias
+ }
+
+ pub fn client_type(&self) -> ClientType {
+ self.ty
+ }
+
+ pub fn redirect_uris(&self) -> &[Url] {
+ &self.redirect_uris
+ }
+
+ pub fn secret_hash(&self) -> Option<&[u8]> {
+ self.secret.as_ref().map(|s| s.hash())
+ }
+
+ pub fn secret_salt(&self) -> Option<&[u8]> {
+ self.secret.as_ref().map(|s| s.salt())
+ }
+
+ pub fn secret_version(&self) -> Option<u8> {
+ self.secret.as_ref().map(|s| s.version())
+ }
+
+ pub fn allowed_scopes(&self) -> String {
+ self.allowed_scopes.join(" ")
+ }
+
+ pub fn default_scopes(&self) -> Option<String> {
+ self.default_scopes.clone().map(|s| s.join(" "))
+ }
+
+ pub fn is_trusted(&self) -> bool {
+ self.trusted
+ }
+
+ pub fn check_secret(&self, secret: &str) -> Option<Result<bool, RawUnexpected>> {
+ self.secret.as_ref().map(|s| s.check_password(secret))
+ }
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs index 633f846..1379893 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -1,2 +1,2 @@ -pub mod client; -pub mod user; +pub mod client;
+pub mod user;
diff --git a/src/models/user.rs b/src/models/user.rs index 8555ee2..493a267 100644 --- a/src/models/user.rs +++ b/src/models/user.rs @@ -1,49 +1,49 @@ -use std::hash::Hash; - -use exun::RawUnexpected; -use uuid::Uuid; - -use crate::services::crypto::PasswordHash; - -#[derive(Debug, Clone)] -pub struct User { - pub id: Uuid, - pub username: Box<str>, - pub password: PasswordHash, -} - -impl PartialEq for User { - fn eq(&self, other: &Self) -> bool { - self.id == other.id - } -} - -impl Eq for User {} - -impl Hash for User { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - state.write_u128(self.id.as_u128()) - } -} - -impl User { - pub fn username(&self) -> &str { - &self.username - } - - pub fn password_hash(&self) -> &[u8] { - self.password.hash() - } - - pub fn password_salt(&self) -> &[u8] { - self.password.salt() - } - - pub fn password_version(&self) -> u8 { - self.password.version() - } - - pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> { - self.password.check_password(password) - } -} +use std::hash::Hash;
+
+use exun::RawUnexpected;
+use uuid::Uuid;
+
+use crate::services::crypto::PasswordHash;
+
+#[derive(Debug, Clone)]
+pub struct User {
+ pub id: Uuid,
+ pub username: Box<str>,
+ pub password: PasswordHash,
+}
+
+impl PartialEq for User {
+ fn eq(&self, other: &Self) -> bool {
+ self.id == other.id
+ }
+}
+
+impl Eq for User {}
+
+impl Hash for User {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write_u128(self.id.as_u128())
+ }
+}
+
+impl User {
+ pub fn username(&self) -> &str {
+ &self.username
+ }
+
+ pub fn password_hash(&self) -> &[u8] {
+ self.password.hash()
+ }
+
+ pub fn password_salt(&self) -> &[u8] {
+ self.password.salt()
+ }
+
+ pub fn password_version(&self) -> u8 {
+ self.password.version()
+ }
+
+ pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
+ self.password.check_password(password)
+ }
+}
diff --git a/src/resources/languages.rs b/src/resources/languages.rs index 8ef7553..b01daf9 100644 --- a/src/resources/languages.rs +++ b/src/resources/languages.rs @@ -1,67 +1,67 @@ -use std::collections::HashMap; -use std::path::PathBuf; - -use actix_web::{get, web, HttpResponse, Scope}; -use exun::RawUnexpected; -use ini::{Ini, Properties}; -use raise::yeet; -use unic_langid::subtags::Language; - -#[derive(Debug, Clone, PartialEq)] -pub struct Translations { - languages: HashMap<Language, Properties>, -} - -pub fn initialize() -> Result<Translations, RawUnexpected> { - let mut translations = Translations { - languages: HashMap::new(), - }; - translations.refresh()?; - Ok(translations) -} - -impl Translations { - pub fn languages(&self) -> Box<[Language]> { - self.languages.keys().cloned().collect() - } - - pub fn get_message(&self, language: Language, key: &str) -> Option<String> { - Some(self.languages.get(&language)?.get(key)?.to_owned()) - } - - pub fn refresh(&mut self) -> Result<(), RawUnexpected> { - let mut languages = HashMap::with_capacity(1); - for entry in PathBuf::from("static/languages").read_dir()? { - let entry = entry?; - if entry.file_type()?.is_dir() { - continue; - } - - let path = entry.path(); - let path = path.to_string_lossy(); - let Some(language) = path.as_bytes().get(0..2) else { yeet!(RawUnexpected::msg(format!("{} not long enough to be a language name", path))) }; - let language = Language::from_bytes(language)?; - let messages = Ini::load_from_file(entry.path())?.general_section().clone(); - - languages.insert(language, messages); - } - - self.languages = languages; - Ok(()) - } -} - -#[get("")] -pub async fn all_languages(translations: web::Data<Translations>) -> HttpResponse { - HttpResponse::Ok().json( - translations - .languages() - .into_iter() - .map(|l| l.as_str()) - .collect::<Box<[&str]>>(), - ) -} - -pub fn languages() -> Scope { - web::scope("/languages").service(all_languages) -} +use std::collections::HashMap;
+use std::path::PathBuf;
+
+use actix_web::{get, web, HttpResponse, Scope};
+use exun::RawUnexpected;
+use ini::{Ini, Properties};
+use raise::yeet;
+use unic_langid::subtags::Language;
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct Translations {
+ languages: HashMap<Language, Properties>,
+}
+
+pub fn initialize() -> Result<Translations, RawUnexpected> {
+ let mut translations = Translations {
+ languages: HashMap::new(),
+ };
+ translations.refresh()?;
+ Ok(translations)
+}
+
+impl Translations {
+ pub fn languages(&self) -> Box<[Language]> {
+ self.languages.keys().cloned().collect()
+ }
+
+ pub fn get_message(&self, language: Language, key: &str) -> Option<String> {
+ Some(self.languages.get(&language)?.get(key)?.to_owned())
+ }
+
+ pub fn refresh(&mut self) -> Result<(), RawUnexpected> {
+ let mut languages = HashMap::with_capacity(1);
+ for entry in PathBuf::from("static/languages").read_dir()? {
+ let entry = entry?;
+ if entry.file_type()?.is_dir() {
+ continue;
+ }
+
+ let path = entry.path();
+ let path = path.to_string_lossy();
+ let Some(language) = path.as_bytes().get(0..2) else { yeet!(RawUnexpected::msg(format!("{} not long enough to be a language name", path))) };
+ let language = Language::from_bytes(language)?;
+ let messages = Ini::load_from_file(entry.path())?.general_section().clone();
+
+ languages.insert(language, messages);
+ }
+
+ self.languages = languages;
+ Ok(())
+ }
+}
+
+#[get("")]
+pub async fn all_languages(translations: web::Data<Translations>) -> HttpResponse {
+ HttpResponse::Ok().json(
+ translations
+ .languages()
+ .into_iter()
+ .map(|l| l.as_str())
+ .collect::<Box<[&str]>>(),
+ )
+}
+
+pub fn languages() -> Scope {
+ web::scope("/languages").service(all_languages)
+}
diff --git a/src/resources/mod.rs b/src/resources/mod.rs index 9251d2c..d9f14ba 100644 --- a/src/resources/mod.rs +++ b/src/resources/mod.rs @@ -1,4 +1,4 @@ -pub mod languages; -pub mod scripts; -pub mod style; -pub mod templates; +pub mod languages;
+pub mod scripts;
+pub mod style;
+pub mod templates;
diff --git a/src/resources/scripts.rs b/src/resources/scripts.rs index 1b27859..66b9693 100644 --- a/src/resources/scripts.rs +++ b/src/resources/scripts.rs @@ -1,38 +1,38 @@ -use std::path::Path; - -use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError}; -use exun::{Expect, ResultErrorExt}; -use path_clean::clean; -use raise::yeet; -use serde::Serialize; -use thiserror::Error; - -#[derive(Debug, Clone, Error, Serialize)] -pub enum LoadScriptError { - #[error("The requested script does not exist")] - FileNotFound(Box<Path>), -} - -impl ResponseError for LoadScriptError { - fn status_code(&self) -> StatusCode { - match self { - Self::FileNotFound(..) => StatusCode::NOT_FOUND, - } - } -} - -fn load(script: &str) -> Result<String, Expect<LoadScriptError>> { - let path = clean(format!("static/scripts/{}.js", script)); - if !path.exists() { - yeet!(LoadScriptError::FileNotFound(path.into()).into()); - } - let js = std::fs::read_to_string(format!("static/scripts/{}.js", script)).unexpect()?; - Ok(js) -} - -#[get("/{script}.js")] -pub async fn get_js(script: web::Path<Box<str>>) -> Result<HttpResponse, LoadScriptError> { - let js = load(&script).map_err(|e| e.unwrap())?; - let response = HttpResponse::Ok().content_type("text/javascript").body(js); - Ok(response) -} +use std::path::Path;
+
+use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
+use exun::{Expect, ResultErrorExt};
+use path_clean::clean;
+use raise::yeet;
+use serde::Serialize;
+use thiserror::Error;
+
+#[derive(Debug, Clone, Error, Serialize)]
+pub enum LoadScriptError {
+ #[error("The requested script does not exist")]
+ FileNotFound(Box<Path>),
+}
+
+impl ResponseError for LoadScriptError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::FileNotFound(..) => StatusCode::NOT_FOUND,
+ }
+ }
+}
+
+fn load(script: &str) -> Result<String, Expect<LoadScriptError>> {
+ let path = clean(format!("static/scripts/{}.js", script));
+ if !path.exists() {
+ yeet!(LoadScriptError::FileNotFound(path.into()).into());
+ }
+ let js = std::fs::read_to_string(format!("static/scripts/{}.js", script)).unexpect()?;
+ Ok(js)
+}
+
+#[get("/{script}.js")]
+pub async fn get_js(script: web::Path<Box<str>>) -> Result<HttpResponse, LoadScriptError> {
+ let js = load(&script).map_err(|e| e.unwrap())?;
+ let response = HttpResponse::Ok().content_type("text/javascript").body(js);
+ Ok(response)
+}
diff --git a/src/resources/style.rs b/src/resources/style.rs index 3ea56d2..8b21dc4 100644 --- a/src/resources/style.rs +++ b/src/resources/style.rs @@ -1,54 +1,54 @@ -use std::path::Path; - -use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError}; -use exun::{Expect, ResultErrorExt}; -use grass::OutputStyle; -use path_clean::clean; -use raise::yeet; -use serde::Serialize; -use thiserror::Error; - -fn output_style() -> OutputStyle { - if cfg!(debug_assertions) { - OutputStyle::Expanded - } else { - OutputStyle::Compressed - } -} - -fn options() -> grass::Options<'static> { - grass::Options::default() - .load_path("static/style") - .style(output_style()) -} - -#[derive(Debug, Clone, Error, Serialize)] -pub enum LoadStyleError { - #[error("The requested stylesheet was not found")] - FileNotFound(Box<Path>), -} - -impl ResponseError for LoadStyleError { - fn status_code(&self) -> StatusCode { - match self { - Self::FileNotFound(..) => StatusCode::NOT_FOUND, - } - } -} - -pub fn load(stylesheet: &str) -> Result<String, Expect<LoadStyleError>> { - let options = options(); - let path = clean(format!("static/style/{}.scss", stylesheet)); - if !path.exists() { - yeet!(LoadStyleError::FileNotFound(path.into()).into()); - } - let css = grass::from_path(format!("static/style/{}.scss", stylesheet), &options).unexpect()?; - Ok(css) -} - -#[get("/{stylesheet}.css")] -pub async fn get_css(stylesheet: web::Path<Box<str>>) -> Result<HttpResponse, LoadStyleError> { - let css = load(&stylesheet).map_err(|e| e.unwrap())?; - let response = HttpResponse::Ok().content_type("text/css").body(css); - Ok(response) -} +use std::path::Path;
+
+use actix_web::{get, http::StatusCode, web, HttpResponse, ResponseError};
+use exun::{Expect, ResultErrorExt};
+use grass::OutputStyle;
+use path_clean::clean;
+use raise::yeet;
+use serde::Serialize;
+use thiserror::Error;
+
+fn output_style() -> OutputStyle {
+ if cfg!(debug_assertions) {
+ OutputStyle::Expanded
+ } else {
+ OutputStyle::Compressed
+ }
+}
+
+fn options() -> grass::Options<'static> {
+ grass::Options::default()
+ .load_path("static/style")
+ .style(output_style())
+}
+
+#[derive(Debug, Clone, Error, Serialize)]
+pub enum LoadStyleError {
+ #[error("The requested stylesheet was not found")]
+ FileNotFound(Box<Path>),
+}
+
+impl ResponseError for LoadStyleError {
+ fn status_code(&self) -> StatusCode {
+ match self {
+ Self::FileNotFound(..) => StatusCode::NOT_FOUND,
+ }
+ }
+}
+
+pub fn load(stylesheet: &str) -> Result<String, Expect<LoadStyleError>> {
+ let options = options();
+ let path = clean(format!("static/style/{}.scss", stylesheet));
+ if !path.exists() {
+ yeet!(LoadStyleError::FileNotFound(path.into()).into());
+ }
+ let css = grass::from_path(format!("static/style/{}.scss", stylesheet), &options).unexpect()?;
+ Ok(css)
+}
+
+#[get("/{stylesheet}.css")]
+pub async fn get_css(stylesheet: web::Path<Box<str>>) -> Result<HttpResponse, LoadStyleError> {
+ let css = load(&stylesheet).map_err(|e| e.unwrap())?;
+ let response = HttpResponse::Ok().content_type("text/css").body(css);
+ Ok(response)
+}
diff --git a/src/resources/templates.rs b/src/resources/templates.rs index 9168fb9..baf2ee8 100644 --- a/src/resources/templates.rs +++ b/src/resources/templates.rs @@ -1,101 +1,101 @@ -use std::collections::HashMap; - -use exun::{RawUnexpected, ResultErrorExt}; -use raise::yeet; -use serde::Serialize; -use tera::{Function, Tera, Value}; -use unic_langid::subtags::Language; - -use crate::api::AuthorizationParameters; - -use super::languages; - -fn make_msg(language: Language, translations: languages::Translations) -> impl Function { - Box::new( - move |args: &HashMap<String, Value>| -> tera::Result<Value> { - let Some(key) = args.get("key") else { yeet!("No parameter 'key' provided".into()) }; - let Some(key) = key.as_str() else { yeet!(format!("{} is not a string", key).into()) }; - let Some(value) = translations.get_message(language, key) else { yeet!(format!("{} does not exist", key).into()) }; - Ok(Value::String(value)) - }, - ) -} - -fn extend_tera( - tera: &Tera, - language: Language, - translations: languages::Translations, -) -> Result<Tera, RawUnexpected> { - let mut new_tera = initialize()?; - new_tera.extend(tera)?; - new_tera.register_function("msg", make_msg(language, translations)); - Ok(new_tera) -} - -pub fn initialize() -> tera::Result<Tera> { - let tera = Tera::new("static/templates/*")?; - Ok(tera) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] -#[serde(rename_all = "camelCase")] -pub enum ErrorPage { - InvalidRequest, - ClientNotFound, - MissingRedirectUri, - InvalidRedirectUri, - InternalServerError, -} - -pub fn error_page( - tera: &Tera, - language: Language, - mut translations: languages::Translations, - error: ErrorPage, -) -> Result<String, RawUnexpected> { - translations.refresh()?; - let mut tera = extend_tera(tera, language, translations)?; - tera.full_reload()?; - - let error = serde_variant::to_variant_name(&error)?; - let header = format!("errorHeader_{error}"); - let message = format!("errorMessage_{error}"); - - let mut context = tera::Context::new(); - context.insert("lang", language.as_str()); - context.insert("errorHeader", &header); - context.insert("errormessage", &message); - - tera.render("error.html", &context).unexpect() -} - -pub fn login_page( - tera: &Tera, - params: &AuthorizationParameters, - language: Language, - mut translations: languages::Translations, -) -> Result<String, RawUnexpected> { - translations.refresh()?; - let mut tera = extend_tera(tera, language, translations)?; - tera.full_reload()?; - let mut context = tera::Context::new(); - context.insert("lang", language.as_str()); - context.insert("params", &serde_urlencoded::to_string(params)?); - tera.render("login.html", &context).unexpect() -} - -pub fn login_error_page( - tera: &Tera, - params: &AuthorizationParameters, - language: Language, - mut translations: languages::Translations, -) -> Result<String, RawUnexpected> { - translations.refresh()?; - let mut tera = extend_tera(tera, language, translations)?; - tera.full_reload()?; - let mut context = tera::Context::new(); - context.insert("lang", language.as_str()); - context.insert("params", &serde_urlencoded::to_string(params)?); - context.insert("errorMessage", "loginErrorMessage"); - tera.render("login.html", &context).unexpect() -} +use std::collections::HashMap;
+
+use exun::{RawUnexpected, ResultErrorExt};
+use raise::yeet;
+use serde::Serialize;
+use tera::{Function, Tera, Value};
+use unic_langid::subtags::Language;
+
+use crate::api::AuthorizationParameters;
+
+use super::languages;
+
+fn make_msg(language: Language, translations: languages::Translations) -> impl Function {
+ Box::new(
+ move |args: &HashMap<String, Value>| -> tera::Result<Value> {
+ let Some(key) = args.get("key") else { yeet!("No parameter 'key' provided".into()) };
+ let Some(key) = key.as_str() else { yeet!(format!("{} is not a string", key).into()) };
+ let Some(value) = translations.get_message(language, key) else { yeet!(format!("{} does not exist", key).into()) };
+ Ok(Value::String(value))
+ },
+ )
+}
+
+fn extend_tera(
+ tera: &Tera,
+ language: Language,
+ translations: languages::Translations,
+) -> Result<Tera, RawUnexpected> {
+ let mut new_tera = initialize()?;
+ new_tera.extend(tera)?;
+ new_tera.register_function("msg", make_msg(language, translations));
+ Ok(new_tera)
+}
+
+pub fn initialize() -> tera::Result<Tera> {
+ let tera = Tera::new("static/templates/*")?;
+ Ok(tera)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub enum ErrorPage {
+ InvalidRequest,
+ ClientNotFound,
+ MissingRedirectUri,
+ InvalidRedirectUri,
+ InternalServerError,
+}
+
+pub fn error_page(
+ tera: &Tera,
+ language: Language,
+ mut translations: languages::Translations,
+ error: ErrorPage,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+
+ let error = serde_variant::to_variant_name(&error)?;
+ let header = format!("errorHeader_{error}");
+ let message = format!("errorMessage_{error}");
+
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("errorHeader", &header);
+ context.insert("errormessage", &message);
+
+ tera.render("error.html", &context).unexpect()
+}
+
+pub fn login_page(
+ tera: &Tera,
+ params: &AuthorizationParameters,
+ language: Language,
+ mut translations: languages::Translations,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("params", &serde_urlencoded::to_string(params)?);
+ tera.render("login.html", &context).unexpect()
+}
+
+pub fn login_error_page(
+ tera: &Tera,
+ params: &AuthorizationParameters,
+ language: Language,
+ mut translations: languages::Translations,
+) -> Result<String, RawUnexpected> {
+ translations.refresh()?;
+ let mut tera = extend_tera(tera, language, translations)?;
+ tera.full_reload()?;
+ let mut context = tera::Context::new();
+ context.insert("lang", language.as_str());
+ context.insert("params", &serde_urlencoded::to_string(params)?);
+ context.insert("errorMessage", "loginErrorMessage");
+ tera.render("login.html", &context).unexpect()
+}
diff --git a/src/scopes/admin.rs b/src/scopes/admin.rs index 1e13b85..31e7880 100644 --- a/src/scopes/admin.rs +++ b/src/scopes/admin.rs @@ -1,28 +1,28 @@ -use std::fmt::{self, Display}; - -use crate::models::{client::Client, user::User}; - -use super::{Action, Scope}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct Admin; - -impl Display for Admin { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("admin") - } -} - -impl Scope for Admin { - fn parse_modifiers(_modifiers: &str) -> Result<Self, Box<str>> { - Ok(Self) - } - - fn has_user_permission(&self, _: &User, _: &Action<User>) -> bool { - true - } - - fn has_client_permission(&self, _: &User, _: &Action<Client>) -> bool { - true - } -} +use std::fmt::{self, Display};
+
+use crate::models::{client::Client, user::User};
+
+use super::{Action, Scope};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub struct Admin;
+
+impl Display for Admin {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.write_str("admin")
+ }
+}
+
+impl Scope for Admin {
+ fn parse_modifiers(_modifiers: &str) -> Result<Self, Box<str>> {
+ Ok(Self)
+ }
+
+ fn has_user_permission(&self, _: &User, _: &Action<User>) -> bool {
+ true
+ }
+
+ fn has_client_permission(&self, _: &User, _: &Action<Client>) -> bool {
+ true
+ }
+}
diff --git a/src/scopes/mod.rs b/src/scopes/mod.rs index fb7780f..25296fd 100644 --- a/src/scopes/mod.rs +++ b/src/scopes/mod.rs @@ -1,128 +1,128 @@ -use std::collections::HashSet; - -use self::admin::Admin; -use crate::models::{client::Client, user::User}; - -mod admin; - -/// The action which was attempted on a resource -pub enum Action<T> { - Create(T), - Read(T), - Update(T, T), - Delete(T), -} - -trait ScopeSuperSet { - fn is_superset_of(&self, other: &Self) -> bool; -} - -trait Scope: ToString { - /// Parse a scope of the format: `{Scope::NAME}:{modifiers}` - fn parse_modifiers(modifiers: &str) -> Result<Self, Box<str>> - where - Self: Sized; - - /// Returns `true` if and only if the given `user` is allowed to take the - /// given `action` with this scope - fn has_user_permission(&self, user: &User, action: &Action<User>) -> bool; - - // Returns `true` if and only if the given `user` is allowed to take the - /// given `action` with this scope - fn has_client_permission(&self, user: &User, action: &Action<Client>) -> bool; -} - -pub struct ParseScopeError { - scope: Box<str>, - error: ParseScopeErrorType, -} - -impl ParseScopeError { - fn invalid_type(scope: &str, scope_type: &str) -> Self { - let scope = scope.into(); - let error = ParseScopeErrorType::InvalidType(scope_type.into()); - Self { scope, error } - } -} - -pub enum ParseScopeErrorType { - InvalidType(Box<str>), - InvalidModifiers(Box<str>), -} - -fn parse_scope(scope: &str) -> Result<Box<dyn Scope>, ParseScopeError> { - let mut split = scope.split(':'); - let scope_type = split.next().unwrap(); - let _modifiers: String = split.collect(); - - match scope_type { - "admin" => Ok(Box::new(Admin)), - _ => Err(ParseScopeError::invalid_type(scope, scope_type)), - } -} - -fn parse_scopes(scopes: &str) -> Result<Vec<Box<dyn Scope>>, ParseScopeError> { - scopes - .split_whitespace() - .map(|scope| parse_scope(scope)) - .collect() -} - -fn parse_scopes_errors( - results: &[Result<Box<dyn Scope>, ParseScopeError>], -) -> Vec<&ParseScopeError> { - let mut errors = Vec::with_capacity(results.len()); - for result in results { - if let Err(pse) = result { - errors.push(pse) - } - } - - errors -} - -/// Returns `true` if and only if all values in `left_scopes` are contained in -/// `right_scopes`. -pub fn is_subset_of(left_scopes: &str, right_scopes: &str) -> bool { - let right_scopes: HashSet<&str> = right_scopes.split_whitespace().collect(); - - for scope in left_scopes.split_whitespace() { - if !right_scopes.contains(scope) { - return false; - } - } - - true -} - -pub fn has_user_permission( - user: User, - action: Action<User>, - client_scopes: &str, -) -> Result<bool, ParseScopeError> { - let scopes = parse_scopes(client_scopes)?; - - for scope in scopes { - if scope.has_user_permission(&user, &action) { - return Ok(true); - } - } - - Ok(false) -} - -pub fn has_client_permission( - user: User, - action: Action<Client>, - client_scopes: &str, -) -> Result<bool, ParseScopeError> { - let scopes = parse_scopes(client_scopes)?; - - for scope in scopes { - if scope.has_client_permission(&user, &action) { - return Ok(true); - } - } - - Ok(false) -} +use std::collections::HashSet;
+
+use self::admin::Admin;
+use crate::models::{client::Client, user::User};
+
+mod admin;
+
+/// The action which was attempted on a resource
+pub enum Action<T> {
+ Create(T),
+ Read(T),
+ Update(T, T),
+ Delete(T),
+}
+
+trait ScopeSuperSet {
+ fn is_superset_of(&self, other: &Self) -> bool;
+}
+
+trait Scope: ToString {
+ /// Parse a scope of the format: `{Scope::NAME}:{modifiers}`
+ fn parse_modifiers(modifiers: &str) -> Result<Self, Box<str>>
+ where
+ Self: Sized;
+
+ /// Returns `true` if and only if the given `user` is allowed to take the
+ /// given `action` with this scope
+ fn has_user_permission(&self, user: &User, action: &Action<User>) -> bool;
+
+ // Returns `true` if and only if the given `user` is allowed to take the
+ /// given `action` with this scope
+ fn has_client_permission(&self, user: &User, action: &Action<Client>) -> bool;
+}
+
+pub struct ParseScopeError {
+ scope: Box<str>,
+ error: ParseScopeErrorType,
+}
+
+impl ParseScopeError {
+ fn invalid_type(scope: &str, scope_type: &str) -> Self {
+ let scope = scope.into();
+ let error = ParseScopeErrorType::InvalidType(scope_type.into());
+ Self { scope, error }
+ }
+}
+
+pub enum ParseScopeErrorType {
+ InvalidType(Box<str>),
+ InvalidModifiers(Box<str>),
+}
+
+fn parse_scope(scope: &str) -> Result<Box<dyn Scope>, ParseScopeError> {
+ let mut split = scope.split(':');
+ let scope_type = split.next().unwrap();
+ let _modifiers: String = split.collect();
+
+ match scope_type {
+ "admin" => Ok(Box::new(Admin)),
+ _ => Err(ParseScopeError::invalid_type(scope, scope_type)),
+ }
+}
+
+fn parse_scopes(scopes: &str) -> Result<Vec<Box<dyn Scope>>, ParseScopeError> {
+ scopes
+ .split_whitespace()
+ .map(|scope| parse_scope(scope))
+ .collect()
+}
+
+fn parse_scopes_errors(
+ results: &[Result<Box<dyn Scope>, ParseScopeError>],
+) -> Vec<&ParseScopeError> {
+ let mut errors = Vec::with_capacity(results.len());
+ for result in results {
+ if let Err(pse) = result {
+ errors.push(pse)
+ }
+ }
+
+ errors
+}
+
+/// Returns `true` if and only if all values in `left_scopes` are contained in
+/// `right_scopes`.
+pub fn is_subset_of(left_scopes: &str, right_scopes: &str) -> bool {
+ let right_scopes: HashSet<&str> = right_scopes.split_whitespace().collect();
+
+ for scope in left_scopes.split_whitespace() {
+ if !right_scopes.contains(scope) {
+ return false;
+ }
+ }
+
+ true
+}
+
+pub fn has_user_permission(
+ user: User,
+ action: Action<User>,
+ client_scopes: &str,
+) -> Result<bool, ParseScopeError> {
+ let scopes = parse_scopes(client_scopes)?;
+
+ for scope in scopes {
+ if scope.has_user_permission(&user, &action) {
+ return Ok(true);
+ }
+ }
+
+ Ok(false)
+}
+
+pub fn has_client_permission(
+ user: User,
+ action: Action<Client>,
+ client_scopes: &str,
+) -> Result<bool, ParseScopeError> {
+ let scopes = parse_scopes(client_scopes)?;
+
+ for scope in scopes {
+ if scope.has_client_permission(&user, &action) {
+ return Ok(true);
+ }
+ }
+
+ Ok(false)
+}
diff --git a/src/services/authorization.rs b/src/services/authorization.rs index bfbbb5a..4e6ef35 100644 --- a/src/services/authorization.rs +++ b/src/services/authorization.rs @@ -1,82 +1,82 @@ -use actix_web::{ - error::ParseError, - http::header::{self, Header, HeaderName, HeaderValue, InvalidHeaderValue, TryIntoHeaderValue}, -}; -use base64::Engine; -use raise::yeet; - -#[derive(Clone)] -pub struct BasicAuthorization { - username: Box<str>, - password: Box<str>, -} - -impl TryIntoHeaderValue for BasicAuthorization { - type Error = InvalidHeaderValue; - - fn try_into_value(self) -> Result<HeaderValue, Self::Error> { - let username = self.username; - let password = self.password; - let utf8 = format!("{username}:{password}"); - let b64 = base64::engine::general_purpose::STANDARD.encode(utf8); - let value = format!("Basic {b64}"); - HeaderValue::from_str(&value) - } -} - -impl Header for BasicAuthorization { - fn name() -> HeaderName { - header::AUTHORIZATION - } - - fn parse<M: actix_web::HttpMessage>(msg: &M) -> Result<Self, actix_web::error::ParseError> { - let Some(value) = msg.headers().get(Self::name()) else { - yeet!(ParseError::Header) - }; - - let Ok(value) = value.to_str() else { - yeet!(ParseError::Header) - }; - - if !value.starts_with("Basic") { - yeet!(ParseError::Header); - } - - let value: String = value - .chars() - .skip(5) - .skip_while(|ch| ch.is_whitespace()) - .collect(); - - if value.is_empty() { - yeet!(ParseError::Header); - } - - let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(value) else { - yeet!(ParseError::Header) - }; - - let Ok(value) = String::from_utf8(bytes) else { - yeet!(ParseError::Header) - }; - - let mut parts = value.split(':'); - let username = Box::from(parts.next().unwrap()); - let Some(password) = parts.next() else { - yeet!(ParseError::Header) - }; - let password = Box::from(password); - - Ok(Self { username, password }) - } -} - -impl BasicAuthorization { - pub fn username(&self) -> &str { - &self.username - } - - pub fn password(&self) -> &str { - &self.password - } -} +use actix_web::{
+ error::ParseError,
+ http::header::{self, Header, HeaderName, HeaderValue, InvalidHeaderValue, TryIntoHeaderValue},
+};
+use base64::Engine;
+use raise::yeet;
+
+#[derive(Clone)]
+pub struct BasicAuthorization {
+ username: Box<str>,
+ password: Box<str>,
+}
+
+impl TryIntoHeaderValue for BasicAuthorization {
+ type Error = InvalidHeaderValue;
+
+ fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
+ let username = self.username;
+ let password = self.password;
+ let utf8 = format!("{username}:{password}");
+ let b64 = base64::engine::general_purpose::STANDARD.encode(utf8);
+ let value = format!("Basic {b64}");
+ HeaderValue::from_str(&value)
+ }
+}
+
+impl Header for BasicAuthorization {
+ fn name() -> HeaderName {
+ header::AUTHORIZATION
+ }
+
+ fn parse<M: actix_web::HttpMessage>(msg: &M) -> Result<Self, actix_web::error::ParseError> {
+ let Some(value) = msg.headers().get(Self::name()) else {
+ yeet!(ParseError::Header)
+ };
+
+ let Ok(value) = value.to_str() else {
+ yeet!(ParseError::Header)
+ };
+
+ if !value.starts_with("Basic") {
+ yeet!(ParseError::Header);
+ }
+
+ let value: String = value
+ .chars()
+ .skip(5)
+ .skip_while(|ch| ch.is_whitespace())
+ .collect();
+
+ if value.is_empty() {
+ yeet!(ParseError::Header);
+ }
+
+ let Ok(bytes) = base64::engine::general_purpose::STANDARD.decode(value) else {
+ yeet!(ParseError::Header)
+ };
+
+ let Ok(value) = String::from_utf8(bytes) else {
+ yeet!(ParseError::Header)
+ };
+
+ let mut parts = value.split(':');
+ let username = Box::from(parts.next().unwrap());
+ let Some(password) = parts.next() else {
+ yeet!(ParseError::Header)
+ };
+ let password = Box::from(password);
+
+ Ok(Self { username, password })
+ }
+}
+
+impl BasicAuthorization {
+ pub fn username(&self) -> &str {
+ &self.username
+ }
+
+ pub fn password(&self) -> &str {
+ &self.password
+ }
+}
diff --git a/src/services/config.rs b/src/services/config.rs index 6468126..932f38f 100644 --- a/src/services/config.rs +++ b/src/services/config.rs @@ -1,74 +1,74 @@ -use std::{ - fmt::{self, Display}, - str::FromStr, -}; - -use exun::RawUnexpected; -use parking_lot::RwLock; -use serde::Deserialize; -use thiserror::Error; -use url::Url; - -static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local); - -#[derive(Debug, Clone, Deserialize)] -pub struct Config { - pub id: Box<str>, - pub url: Url, -} - -pub fn get_config() -> Result<Config, RawUnexpected> { - let env = get_environment(); - let path = format!("static/config/{env}.toml"); - let string = std::fs::read_to_string(path)?; - let config = toml::from_str(&string)?; - Ok(config) -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum Environment { - Local, - Dev, - Staging, - Production, -} - -impl Display for Environment { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Local => f.write_str("local"), - Self::Dev => f.write_str("dev"), - Self::Staging => f.write_str("staging"), - Self::Production => f.write_str("prod"), - } - } -} - -#[derive(Debug, Clone, Error)] -#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")] -pub struct ParseEnvironmentError { - string: Box<str>, -} - -impl FromStr for Environment { - type Err = ParseEnvironmentError; - - fn from_str(s: &str) -> Result<Self, Self::Err> { - match s { - "local" => Ok(Self::Local), - "dev" => Ok(Self::Dev), - "staging" => Ok(Self::Staging), - "prod" => Ok(Self::Production), - _ => Err(ParseEnvironmentError { string: s.into() }), - } - } -} - -pub fn set_environment(env: Environment) { - let mut env_ptr = ENVIRONMENT.write(); - *env_ptr = env; -} - -fn get_environment() -> Environment { - ENVIRONMENT.read().clone() -} +use std::{
+ fmt::{self, Display},
+ str::FromStr,
+};
+
+use exun::RawUnexpected;
+use parking_lot::RwLock;
+use serde::Deserialize;
+use thiserror::Error;
+use url::Url;
+
+static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local);
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Config {
+ pub id: Box<str>,
+ pub url: Url,
+}
+
+pub fn get_config() -> Result<Config, RawUnexpected> {
+ let env = get_environment();
+ let path = format!("static/config/{env}.toml");
+ let string = std::fs::read_to_string(path)?;
+ let config = toml::from_str(&string)?;
+ Ok(config)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Environment {
+ Local,
+ Dev,
+ Staging,
+ Production,
+}
+
+impl Display for Environment {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Local => f.write_str("local"),
+ Self::Dev => f.write_str("dev"),
+ Self::Staging => f.write_str("staging"),
+ Self::Production => f.write_str("prod"),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")]
+pub struct ParseEnvironmentError {
+ string: Box<str>,
+}
+
+impl FromStr for Environment {
+ type Err = ParseEnvironmentError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "local" => Ok(Self::Local),
+ "dev" => Ok(Self::Dev),
+ "staging" => Ok(Self::Staging),
+ "prod" => Ok(Self::Production),
+ _ => Err(ParseEnvironmentError { string: s.into() }),
+ }
+ }
+}
+
+pub fn set_environment(env: Environment) {
+ let mut env_ptr = ENVIRONMENT.write();
+ *env_ptr = env;
+}
+
+fn get_environment() -> Environment {
+ ENVIRONMENT.read().clone()
+}
diff --git a/src/services/crypto.rs b/src/services/crypto.rs index 5fce403..0107374 100644 --- a/src/services/crypto.rs +++ b/src/services/crypto.rs @@ -1,97 +1,97 @@ -use std::hash::Hash; - -use argon2::{hash_raw, verify_raw}; -use exun::RawUnexpected; - -use crate::services::secrets::pepper; - -/// The configuration used for hashing and verifying passwords -/// -/// # Example -/// -/// ``` -/// use crate::services::secrets; -/// -/// let pepper = secrets::pepper(); -/// let config = config(&pepper); -/// ``` -fn config<'a>(pepper: &'a [u8]) -> argon2::Config<'a> { - argon2::Config { - hash_length: 32, - lanes: 4, - mem_cost: 5333, - time_cost: 4, - secret: pepper, - - ad: &[], - thread_mode: argon2::ThreadMode::Sequential, - variant: argon2::Variant::Argon2i, - version: argon2::Version::Version13, - } -} - -/// A password hash and salt for a user -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct PasswordHash { - hash: Box<[u8]>, - salt: Box<[u8]>, - version: u8, -} - -impl Hash for PasswordHash { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - state.write(&self.hash) - } -} - -impl PasswordHash { - /// Hash a password using Argon2 - pub fn new(password: &str) -> Result<Self, RawUnexpected> { - let password = password.as_bytes(); - - let salt: [u8; 32] = rand::random(); - let salt = Box::from(salt); - let pepper = pepper()?; - let hash = hash_raw(password, &salt, &config(&pepper))?.into_boxed_slice(); - - Ok(Self { - hash, - salt, - version: 0, - }) - } - - /// Create this structure from a given hash and salt - pub fn from_fields(hash: &[u8], salt: &[u8], version: u8) -> Self { - Self { - hash: Box::from(hash), - salt: Box::from(salt), - version, - } - } - - /// Get the password hash - pub fn hash(&self) -> &[u8] { - &self.hash - } - - /// Get the salt used for the hash - pub fn salt(&self) -> &[u8] { - &self.salt - } - - pub fn version(&self) -> u8 { - self.version - } - - /// Check if the given password is the one that was hashed - pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> { - let pepper = pepper()?; - Ok(verify_raw( - password.as_bytes(), - &self.salt, - &self.hash, - &config(&pepper), - )?) - } -} +use std::hash::Hash;
+
+use argon2::{hash_raw, verify_raw};
+use exun::RawUnexpected;
+
+use crate::services::secrets::pepper;
+
+/// The configuration used for hashing and verifying passwords
+///
+/// # Example
+///
+/// ```
+/// use crate::services::secrets;
+///
+/// let pepper = secrets::pepper();
+/// let config = config(&pepper);
+/// ```
+fn config<'a>(pepper: &'a [u8]) -> argon2::Config<'a> {
+ argon2::Config {
+ hash_length: 32,
+ lanes: 4,
+ mem_cost: 5333,
+ time_cost: 4,
+ secret: pepper,
+
+ ad: &[],
+ thread_mode: argon2::ThreadMode::Sequential,
+ variant: argon2::Variant::Argon2i,
+ version: argon2::Version::Version13,
+ }
+}
+
+/// A password hash and salt for a user
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct PasswordHash {
+ hash: Box<[u8]>,
+ salt: Box<[u8]>,
+ version: u8,
+}
+
+impl Hash for PasswordHash {
+ fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+ state.write(&self.hash)
+ }
+}
+
+impl PasswordHash {
+ /// Hash a password using Argon2
+ pub fn new(password: &str) -> Result<Self, RawUnexpected> {
+ let password = password.as_bytes();
+
+ let salt: [u8; 32] = rand::random();
+ let salt = Box::from(salt);
+ let pepper = pepper()?;
+ let hash = hash_raw(password, &salt, &config(&pepper))?.into_boxed_slice();
+
+ Ok(Self {
+ hash,
+ salt,
+ version: 0,
+ })
+ }
+
+ /// Create this structure from a given hash and salt
+ pub fn from_fields(hash: &[u8], salt: &[u8], version: u8) -> Self {
+ Self {
+ hash: Box::from(hash),
+ salt: Box::from(salt),
+ version,
+ }
+ }
+
+ /// Get the password hash
+ pub fn hash(&self) -> &[u8] {
+ &self.hash
+ }
+
+ /// Get the salt used for the hash
+ pub fn salt(&self) -> &[u8] {
+ &self.salt
+ }
+
+ pub fn version(&self) -> u8 {
+ self.version
+ }
+
+ /// Check if the given password is the one that was hashed
+ pub fn check_password(&self, password: &str) -> Result<bool, RawUnexpected> {
+ let pepper = pepper()?;
+ Ok(verify_raw(
+ password.as_bytes(),
+ &self.salt,
+ &self.hash,
+ &config(&pepper),
+ )?)
+ }
+}
diff --git a/src/services/db.rs b/src/services/db.rs index f811d79..e3cb48b 100644 --- a/src/services/db.rs +++ b/src/services/db.rs @@ -1,15 +1,15 @@ -use exun::{RawUnexpected, ResultErrorExt}; -use sqlx::MySqlPool; - -mod client; -mod jwt; -mod user; - -pub use self::jwt::*; -pub use client::*; -pub use user::*; - -/// Intialize the connection pool -pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> { - MySqlPool::connect(db_url).await.unexpect() -} +use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::MySqlPool;
+
+mod client;
+mod jwt;
+mod user;
+
+pub use self::jwt::*;
+pub use client::*;
+pub use user::*;
+
+/// Intialize the connection pool
+pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> {
+ MySqlPool::connect(db_url).await.unexpect()
+}
diff --git a/src/services/db/client.rs b/src/services/db/client.rs index b8942e9..1ad97b1 100644 --- a/src/services/db/client.rs +++ b/src/services/db/client.rs @@ -1,392 +1,392 @@ -use std::str::FromStr; - -use exun::{RawUnexpected, ResultErrorExt}; -use sqlx::{ - mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, FromRow, MySql, Transaction, -}; -use url::Url; -use uuid::Uuid; - -use crate::{ - models::client::{Client, ClientType}, - services::crypto::PasswordHash, -}; - -#[derive(Debug, Clone, FromRow)] -pub struct ClientRow { - pub id: Uuid, - pub alias: String, - pub client_type: ClientType, - pub allowed_scopes: String, - pub default_scopes: Option<String>, - pub is_trusted: bool, -} - -#[derive(Clone, FromRow)] -struct HashRow { - secret_hash: Option<Vec<u8>>, - secret_salt: Option<Vec<u8>>, - secret_version: Option<u32>, -} - -pub async fn client_id_exists<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<bool, RawUnexpected> { - query_scalar!( - r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`", - id - ) - .fetch_one(executor) - .await - .unexpect() -} - -pub async fn client_alias_exists<'c>( - executor: impl Executor<'c, Database = MySql>, - alias: &str, -) -> Result<bool, RawUnexpected> { - query_scalar!( - "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`", - alias - ) - .fetch_one(executor) - .await - .unexpect() -} - -pub async fn get_client_id_by_alias<'c>( - executor: impl Executor<'c, Database = MySql>, - alias: &str, -) -> Result<Option<Uuid>, RawUnexpected> { - query_scalar!( - "SELECT id as `id: Uuid` FROM clients WHERE alias = ?", - alias - ) - .fetch_optional(executor) - .await - .unexpect() -} - -pub async fn get_client_response<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<ClientRow>, RawUnexpected> { - let record = query_as!( - ClientRow, - r"SELECT id as `id: Uuid`, - alias, - type as `client_type: ClientType`, - allowed_scopes, - default_scopes, - trusted as `is_trusted: bool` - FROM clients WHERE id = ?", - id - ) - .fetch_optional(executor) - .await?; - - Ok(record) -} - -pub async fn get_client_alias<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<Box<str>>, RawUnexpected> { - let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id) - .fetch_optional(executor) - .await - .unexpect()?; - - Ok(alias.map(String::into_boxed_str)) -} - -pub async fn get_client_type<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<ClientType>, RawUnexpected> { - let ty = query_scalar!( - "SELECT type as `type: ClientType` FROM clients WHERE id = ?", - id - ) - .fetch_optional(executor) - .await - .unexpect()?; - - Ok(ty) -} - -pub async fn get_client_allowed_scopes<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<Box<str>>, RawUnexpected> { - let scopes = query_scalar!("SELECT allowed_scopes FROM clients WHERE id = ?", id) - .fetch_optional(executor) - .await?; - - Ok(scopes.map(Box::from)) -} - -pub async fn get_client_default_scopes<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<Option<Box<str>>>, RawUnexpected> { - let scopes = query_scalar!("SELECT default_scopes FROM clients WHERE id = ?", id) - .fetch_optional(executor) - .await?; - - Ok(scopes.map(|s| s.map(Box::from))) -} - -pub async fn get_client_secret<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<PasswordHash>, RawUnexpected> { - let hash = query_as!( - HashRow, - r"SELECT secret_hash, secret_salt, secret_version - FROM clients WHERE id = ?", - id - ) - .fetch_optional(executor) - .await?; - - let Some(hash) = hash else { return Ok(None) }; - let Some(version) = hash.secret_version else { return Ok(None) }; - let Some(salt) = hash.secret_hash else { return Ok(None) }; - let Some(hash) = hash.secret_salt else { return Ok(None) }; - - let hash = PasswordHash::from_fields(&hash, &salt, version as u8); - Ok(Some(hash)) -} - -pub async fn is_client_trusted<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Option<bool>, RawUnexpected> { - query_scalar!("SELECT trusted as `t: bool` FROM clients WHERE id = ?", id) - .fetch_optional(executor) - .await - .unexpect() -} - -pub async fn get_client_redirect_uris<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<Box<[Url]>, RawUnexpected> { - let uris = query_scalar!( - "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?", - id - ) - .fetch_all(executor) - .await - .unexpect()?; - - uris.into_iter() - .map(|s| Url::from_str(&s).unexpect()) - .collect() -} - -pub async fn client_has_redirect_uri<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - url: &Url, -) -> Result<bool, RawUnexpected> { - query_scalar!( - r"SELECT EXISTS( - SELECT redirect_uri - FROM client_redirect_uris - WHERE client_id = ? AND redirect_uri = ? - ) as `e: bool`", - id, - url.to_string() - ) - .fetch_one(executor) - .await - .unexpect() -} - -async fn delete_client_redirect_uris<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<(), sqlx::Error> { - query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id) - .execute(executor) - .await?; - Ok(()) -} - -async fn create_client_redirect_uris<'c>( - mut transaction: Transaction<'c, MySql>, - client_id: Uuid, - uris: &[Url], -) -> Result<(), sqlx::Error> { - for uri in uris { - query!( - r"INSERT INTO client_redirect_uris (client_id, redirect_uri) - VALUES ( ?, ?)", - client_id, - uri.to_string() - ) - .execute(&mut transaction) - .await?; - } - - transaction.commit().await?; - - Ok(()) -} - -pub async fn create_client<'c>( - mut transaction: Transaction<'c, MySql>, - client: &Client, -) -> Result<(), sqlx::Error> { - query!( - r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version, allowed_scopes, default_scopes) - VALUES ( ?, ?, ?, ?, ?, ?, ?, ?)", - client.id(), - client.alias(), - client.client_type(), - client.secret_hash(), - client.secret_salt(), - client.secret_version(), - client.allowed_scopes(), - client.default_scopes() - ) - .execute(&mut transaction) - .await?; - - create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?; - - Ok(()) -} - -pub async fn update_client<'c>( - mut transaction: Transaction<'c, MySql>, - client: &Client, -) -> Result<(), sqlx::Error> { - query!( - r"UPDATE clients SET - alias = ?, - type = ?, - secret_hash = ?, - secret_salt = ?, - secret_version = ?, - allowed_scopes = ?, - default_scopes = ? - WHERE id = ?", - client.client_type(), - client.alias(), - client.secret_hash(), - client.secret_salt(), - client.secret_version(), - client.allowed_scopes(), - client.default_scopes(), - client.id() - ) - .execute(&mut transaction) - .await?; - - update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?; - - Ok(()) -} - -pub async fn update_client_alias<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - alias: &str, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id) - .execute(executor) - .await -} - -pub async fn update_client_type<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - ty: ClientType, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!("UPDATE clients SET type = ? WHERE id = ?", ty, id) - .execute(executor) - .await -} - -pub async fn update_client_allowed_scopes<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - allowed_scopes: &str, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - "UPDATE clients SET allowed_scopes = ? WHERE id = ?", - allowed_scopes, - id - ) - .execute(executor) - .await -} - -pub async fn update_client_default_scopes<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - default_scopes: Option<String>, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - "UPDATE clients SET default_scopes = ? WHERE id = ?", - default_scopes, - id - ) - .execute(executor) - .await -} - -pub async fn update_client_trusted<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - is_trusted: bool, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - "UPDATE clients SET trusted = ? WHERE id = ?", - is_trusted, - id - ) - .execute(executor) - .await -} - -pub async fn update_client_redirect_uris<'c>( - mut transaction: Transaction<'c, MySql>, - id: Uuid, - uris: &[Url], -) -> Result<(), sqlx::Error> { - delete_client_redirect_uris(&mut transaction, id).await?; - create_client_redirect_uris(transaction, id, uris).await?; - Ok(()) -} - -pub async fn update_client_secret<'c>( - executor: impl Executor<'c, Database = MySql>, - id: Uuid, - secret: Option<PasswordHash>, -) -> Result<MySqlQueryResult, sqlx::Error> { - if let Some(secret) = secret { - query!( - "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?", - secret.hash(), - secret.salt(), - secret.version(), - id - ) - .execute(executor) - .await - } else { - query!( - r"UPDATE clients - SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL - WHERE id = ?", - id - ) - .execute(executor) - .await - } -} +use std::str::FromStr;
+
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::{
+ mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, FromRow, MySql, Transaction,
+};
+use url::Url;
+use uuid::Uuid;
+
+use crate::{
+ models::client::{Client, ClientType},
+ services::crypto::PasswordHash,
+};
+
+#[derive(Debug, Clone, FromRow)]
+pub struct ClientRow {
+ pub id: Uuid,
+ pub alias: String,
+ pub client_type: ClientType,
+ pub allowed_scopes: String,
+ pub default_scopes: Option<String>,
+ pub is_trusted: bool,
+}
+
+#[derive(Clone, FromRow)]
+struct HashRow {
+ secret_hash: Option<Vec<u8>>,
+ secret_salt: Option<Vec<u8>>,
+ secret_version: Option<u32>,
+}
+
+pub async fn client_id_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`",
+ id
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn client_alias_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`",
+ alias
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_id_by_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ alias: &str,
+) -> Result<Option<Uuid>, RawUnexpected> {
+ query_scalar!(
+ "SELECT id as `id: Uuid` FROM clients WHERE alias = ?",
+ alias
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_response<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientRow>, RawUnexpected> {
+ let record = query_as!(
+ ClientRow,
+ r"SELECT id as `id: Uuid`,
+ alias,
+ type as `client_type: ClientType`,
+ allowed_scopes,
+ default_scopes,
+ trusted as `is_trusted: bool`
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(record)
+}
+
+pub async fn get_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(alias.map(String::into_boxed_str))
+}
+
+pub async fn get_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<ClientType>, RawUnexpected> {
+ let ty = query_scalar!(
+ "SELECT type as `type: ClientType` FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await
+ .unexpect()?;
+
+ Ok(ty)
+}
+
+pub async fn get_client_allowed_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let scopes = query_scalar!("SELECT allowed_scopes FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(scopes.map(Box::from))
+}
+
+pub async fn get_client_default_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<Option<Box<str>>>, RawUnexpected> {
+ let scopes = query_scalar!("SELECT default_scopes FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await?;
+
+ Ok(scopes.map(|s| s.map(Box::from)))
+}
+
+pub async fn get_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<PasswordHash>, RawUnexpected> {
+ let hash = query_as!(
+ HashRow,
+ r"SELECT secret_hash, secret_salt, secret_version
+ FROM clients WHERE id = ?",
+ id
+ )
+ .fetch_optional(executor)
+ .await?;
+
+ let Some(hash) = hash else { return Ok(None) };
+ let Some(version) = hash.secret_version else { return Ok(None) };
+ let Some(salt) = hash.secret_hash else { return Ok(None) };
+ let Some(hash) = hash.secret_salt else { return Ok(None) };
+
+ let hash = PasswordHash::from_fields(&hash, &salt, version as u8);
+ Ok(Some(hash))
+}
+
+pub async fn is_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Option<bool>, RawUnexpected> {
+ query_scalar!("SELECT trusted as `t: bool` FROM clients WHERE id = ?", id)
+ .fetch_optional(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn get_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<Box<[Url]>, RawUnexpected> {
+ let uris = query_scalar!(
+ "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?",
+ id
+ )
+ .fetch_all(executor)
+ .await
+ .unexpect()?;
+
+ uris.into_iter()
+ .map(|s| Url::from_str(&s).unexpect())
+ .collect()
+}
+
+pub async fn client_has_redirect_uri<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ url: &Url,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ r"SELECT EXISTS(
+ SELECT redirect_uri
+ FROM client_redirect_uris
+ WHERE client_id = ? AND redirect_uri = ?
+ ) as `e: bool`",
+ id,
+ url.to_string()
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+async fn delete_client_redirect_uris<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<(), sqlx::Error> {
+ query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id)
+ .execute(executor)
+ .await?;
+ Ok(())
+}
+
+async fn create_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client_id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ for uri in uris {
+ query!(
+ r"INSERT INTO client_redirect_uris (client_id, redirect_uri)
+ VALUES ( ?, ?)",
+ client_id,
+ uri.to_string()
+ )
+ .execute(&mut transaction)
+ .await?;
+ }
+
+ transaction.commit().await?;
+
+ Ok(())
+}
+
+pub async fn create_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version, allowed_scopes, default_scopes)
+ VALUES ( ?, ?, ?, ?, ?, ?, ?, ?)",
+ client.id(),
+ client.alias(),
+ client.client_type(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ client.allowed_scopes(),
+ client.default_scopes()
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ client: &Client,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"UPDATE clients SET
+ alias = ?,
+ type = ?,
+ secret_hash = ?,
+ secret_salt = ?,
+ secret_version = ?,
+ allowed_scopes = ?,
+ default_scopes = ?
+ WHERE id = ?",
+ client.client_type(),
+ client.alias(),
+ client.secret_hash(),
+ client.secret_salt(),
+ client.secret_version(),
+ client.allowed_scopes(),
+ client.default_scopes(),
+ client.id()
+ )
+ .execute(&mut transaction)
+ .await?;
+
+ update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?;
+
+ Ok(())
+}
+
+pub async fn update_client_alias<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ alias: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_type<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ ty: ClientType,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!("UPDATE clients SET type = ? WHERE id = ?", ty, id)
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_allowed_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ allowed_scopes: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET allowed_scopes = ? WHERE id = ?",
+ allowed_scopes,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_default_scopes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ default_scopes: Option<String>,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET default_scopes = ? WHERE id = ?",
+ default_scopes,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_trusted<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ is_trusted: bool,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ "UPDATE clients SET trusted = ? WHERE id = ?",
+ is_trusted,
+ id
+ )
+ .execute(executor)
+ .await
+}
+
+pub async fn update_client_redirect_uris<'c>(
+ mut transaction: Transaction<'c, MySql>,
+ id: Uuid,
+ uris: &[Url],
+) -> Result<(), sqlx::Error> {
+ delete_client_redirect_uris(&mut transaction, id).await?;
+ create_client_redirect_uris(transaction, id, uris).await?;
+ Ok(())
+}
+
+pub async fn update_client_secret<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+ secret: Option<PasswordHash>,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ if let Some(secret) = secret {
+ query!(
+ "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?",
+ secret.hash(),
+ secret.salt(),
+ secret.version(),
+ id
+ )
+ .execute(executor)
+ .await
+ } else {
+ query!(
+ r"UPDATE clients
+ SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL
+ WHERE id = ?",
+ id
+ )
+ .execute(executor)
+ .await
+ }
+}
diff --git a/src/services/db/jwt.rs b/src/services/db/jwt.rs index b2f1367..73d6902 100644 --- a/src/services/db/jwt.rs +++ b/src/services/db/jwt.rs @@ -1,199 +1,199 @@ -use chrono::{DateTime, Utc}; -use exun::{RawUnexpected, ResultErrorExt}; -use sqlx::{query, query_scalar, Executor, MySql}; -use uuid::Uuid; - -use crate::services::jwt::RevokedRefreshTokenReason; - -pub async fn auth_code_exists<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, -) -> Result<bool, RawUnexpected> { - query_scalar!( - "SELECT EXISTS(SELECT jti FROM auth_codes WHERE jti = ?) as `e: bool`", - jti - ) - .fetch_one(executor) - .await - .unexpect() -} - -pub async fn access_token_exists<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, -) -> Result<bool, RawUnexpected> { - query_scalar!( - "SELECT EXISTS(SELECT jti FROM access_tokens WHERE jti = ?) as `e: bool`", - jti - ) - .fetch_one(executor) - .await - .unexpect() -} - -pub async fn refresh_token_exists<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, -) -> Result<bool, RawUnexpected> { - query_scalar!( - "SELECT EXISTS(SELECT jti FROM refresh_tokens WHERE jti = ?) as `e: bool`", - jti - ) - .fetch_one(executor) - .await - .unexpect() -} - -pub async fn refresh_token_revoked<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, -) -> Result<bool, RawUnexpected> { - let result = query_scalar!( - r"SELECT EXISTS( - SELECT revoked_reason FROM refresh_tokens WHERE jti = ? and revoked_reason IS NOT NULL - ) as `e: bool`", - jti - ) - .fetch_one(executor) - .await? - .unwrap_or(true); - - Ok(result) -} - -pub async fn create_auth_code<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, - exp: DateTime<Utc>, -) -> Result<(), sqlx::Error> { - query!( - r"INSERT INTO auth_codes (jti, exp) - VALUES ( ?, ?)", - jti, - exp - ) - .execute(executor) - .await?; - - Ok(()) -} - -pub async fn create_access_token<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, - auth_code: Option<Uuid>, - exp: DateTime<Utc>, -) -> Result<(), sqlx::Error> { - query!( - r"INSERT INTO access_tokens (jti, auth_code, exp) - VALUES ( ?, ?, ?)", - jti, - auth_code, - exp - ) - .execute(executor) - .await?; - - Ok(()) -} - -pub async fn create_refresh_token<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, - auth_code: Option<Uuid>, - exp: DateTime<Utc>, -) -> Result<(), sqlx::Error> { - query!( - r"INSERT INTO access_tokens (jti, auth_code, exp) - VALUES ( ?, ?, ?)", - jti, - auth_code, - exp - ) - .execute(executor) - .await?; - - Ok(()) -} - -pub async fn delete_auth_code<'c>( - executor: impl Executor<'c, Database = MySql>, - auth_code: Uuid, -) -> Result<bool, RawUnexpected> { - let result = query!("DELETE FROM auth_codes WHERE jti = ?", auth_code) - .execute(executor) - .await?; - - Ok(result.rows_affected() != 0) -} - -pub async fn delete_expired_auth_codes<'c>( - executor: impl Executor<'c, Database = MySql>, -) -> Result<(), RawUnexpected> { - query!("DELETE FROM auth_codes WHERE exp < ?", Utc::now()) - .execute(executor) - .await?; - - Ok(()) -} - -pub async fn delete_access_tokens_with_auth_code<'c>( - executor: impl Executor<'c, Database = MySql>, - auth_code: Uuid, -) -> Result<bool, RawUnexpected> { - let result = query!("DELETE FROM access_tokens WHERE auth_code = ?", auth_code) - .execute(executor) - .await?; - - Ok(result.rows_affected() != 0) -} - -pub async fn delete_expired_access_tokens<'c>( - executor: impl Executor<'c, Database = MySql>, -) -> Result<(), RawUnexpected> { - query!("DELETE FROM access_tokens WHERE exp < ?", Utc::now()) - .execute(executor) - .await?; - - Ok(()) -} - -pub async fn revoke_refresh_token<'c>( - executor: impl Executor<'c, Database = MySql>, - jti: Uuid, -) -> Result<bool, RawUnexpected> { - let result = query!( - "UPDATE refresh_tokens SET revoked_reason = ? WHERE jti = ?", - RevokedRefreshTokenReason::NewRefreshToken, - jti - ) - .execute(executor) - .await?; - - Ok(result.rows_affected() != 0) -} - -pub async fn revoke_refresh_tokens_with_auth_code<'c>( - executor: impl Executor<'c, Database = MySql>, - auth_code: Uuid, -) -> Result<bool, RawUnexpected> { - let result = query!( - "UPDATE refresh_tokens SET revoked_reason = ? WHERE auth_code = ?", - RevokedRefreshTokenReason::ReusedAuthorizationCode, - auth_code - ) - .execute(executor) - .await?; - - Ok(result.rows_affected() != 0) -} - -pub async fn delete_expired_refresh_tokens<'c>( - executor: impl Executor<'c, Database = MySql>, -) -> Result<(), RawUnexpected> { - query!("DELETE FROM refresh_tokens WHERE exp < ?", Utc::now()) - .execute(executor) - .await?; - - Ok(()) -} +use chrono::{DateTime, Utc};
+use exun::{RawUnexpected, ResultErrorExt};
+use sqlx::{query, query_scalar, Executor, MySql};
+use uuid::Uuid;
+
+use crate::services::jwt::RevokedRefreshTokenReason;
+
+pub async fn auth_code_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM auth_codes WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn access_token_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM access_tokens WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn refresh_token_exists<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ query_scalar!(
+ "SELECT EXISTS(SELECT jti FROM refresh_tokens WHERE jti = ?) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await
+ .unexpect()
+}
+
+pub async fn refresh_token_revoked<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query_scalar!(
+ r"SELECT EXISTS(
+ SELECT revoked_reason FROM refresh_tokens WHERE jti = ? and revoked_reason IS NOT NULL
+ ) as `e: bool`",
+ jti
+ )
+ .fetch_one(executor)
+ .await?
+ .unwrap_or(true);
+
+ Ok(result)
+}
+
+pub async fn create_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO auth_codes (jti, exp)
+ VALUES ( ?, ?)",
+ jti,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn create_access_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ auth_code: Option<Uuid>,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO access_tokens (jti, auth_code, exp)
+ VALUES ( ?, ?, ?)",
+ jti,
+ auth_code,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn create_refresh_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+ auth_code: Option<Uuid>,
+ exp: DateTime<Utc>,
+) -> Result<(), sqlx::Error> {
+ query!(
+ r"INSERT INTO access_tokens (jti, auth_code, exp)
+ VALUES ( ?, ?, ?)",
+ jti,
+ auth_code,
+ exp
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn delete_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!("DELETE FROM auth_codes WHERE jti = ?", auth_code)
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_auth_codes<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM auth_codes WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn delete_access_tokens_with_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!("DELETE FROM access_tokens WHERE auth_code = ?", auth_code)
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_access_tokens<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM access_tokens WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
+
+pub async fn revoke_refresh_token<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ jti: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!(
+ "UPDATE refresh_tokens SET revoked_reason = ? WHERE jti = ?",
+ RevokedRefreshTokenReason::NewRefreshToken,
+ jti
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn revoke_refresh_tokens_with_auth_code<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+ auth_code: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let result = query!(
+ "UPDATE refresh_tokens SET revoked_reason = ? WHERE auth_code = ?",
+ RevokedRefreshTokenReason::ReusedAuthorizationCode,
+ auth_code
+ )
+ .execute(executor)
+ .await?;
+
+ Ok(result.rows_affected() != 0)
+}
+
+pub async fn delete_expired_refresh_tokens<'c>(
+ executor: impl Executor<'c, Database = MySql>,
+) -> Result<(), RawUnexpected> {
+ query!("DELETE FROM refresh_tokens WHERE exp < ?", Utc::now())
+ .execute(executor)
+ .await?;
+
+ Ok(())
+}
diff --git a/src/services/db/user.rs b/src/services/db/user.rs index 09a09da..f85047a 100644 --- a/src/services/db/user.rs +++ b/src/services/db/user.rs @@ -1,236 +1,236 @@ -use exun::RawUnexpected; -use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql}; -use uuid::Uuid; - -use crate::{models::user::User, services::crypto::PasswordHash}; - -struct UserRow { - id: Uuid, - username: String, - password_hash: Vec<u8>, - password_salt: Vec<u8>, - password_version: u32, -} - -impl TryFrom<UserRow> for User { - type Error = RawUnexpected; - - fn try_from(row: UserRow) -> Result<Self, Self::Error> { - let password = PasswordHash::from_fields( - &row.password_hash, - &row.password_salt, - row.password_version as u8, - ); - let user = User { - id: row.id, - username: row.username.into_boxed_str(), - password, - }; - Ok(user) - } -} - -/// Check if a user with a given user ID exists -pub async fn user_id_exists<'c>( - conn: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<bool, RawUnexpected> { - let exists = query_scalar!( - r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#, - id - ) - .fetch_one(conn) - .await?; - - Ok(exists) -} - -/// Check if a given username is taken -pub async fn username_is_used<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<bool, RawUnexpected> { - let exists = query_scalar!( - r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#, - username - ) - .fetch_one(conn) - .await?; - - Ok(exists) -} - -/// Get a user from their ID -pub async fn get_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, -) -> Result<Option<User>, RawUnexpected> { - let record = query_as!( - UserRow, - r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version - FROM users WHERE id = ?", - user_id - ) - .fetch_optional(conn) - .await?; - - let Some(record) = record else { return Ok(None) }; - - Ok(Some(record.try_into()?)) -} - -/// Get a user from their username -pub async fn get_user_by_username<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<Option<User>, RawUnexpected> { - let record = query_as!( - UserRow, - r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version - FROM users WHERE username = ?", - username - ) - .fetch_optional(conn) - .await?; - - let Some(record) = record else { return Ok(None) }; - - Ok(Some(record.try_into()?)) -} - -/// Search the list of users for a given username -pub async fn search_users<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<Box<[User]>, RawUnexpected> { - let records = query_as!( - UserRow, - r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version - FROM users - WHERE LOCATE(?, username) != 0", - username, - ) - .fetch_all(conn) - .await?; - - Ok(records - .into_iter() - .map(|u| u.try_into()) - .collect::<Result<Box<[User]>, RawUnexpected>>()?) -} - -/// Search the list of users, only returning a certain range of results -pub async fn search_users_limit<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, - offset: u32, - limit: u32, -) -> Result<Box<[User]>, RawUnexpected> { - let records = query_as!( - UserRow, - r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version - FROM users - WHERE LOCATE(?, username) != 0 - LIMIT ? - OFFSET ?", - username, - offset, - limit - ) - .fetch_all(conn) - .await?; - - Ok(records - .into_iter() - .map(|u| u.try_into()) - .collect::<Result<Box<[User]>, RawUnexpected>>()?) -} - -/// Get the username of a user with a certain ID -pub async fn get_username<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, -) -> Result<Option<Box<str>>, RawUnexpected> { - let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id) - .fetch_optional(conn) - .await? - .map(String::into_boxed_str); - - Ok(username) -} - -/// Create a new user -pub async fn create_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user: &User, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"INSERT INTO users (id, username, password_hash, password_salt, password_version) - VALUES ( ?, ?, ?, ?, ?)", - user.id, - user.username(), - user.password_hash(), - user.password_salt(), - user.password_version() - ) - .execute(conn) - .await -} - -/// Update a user -pub async fn update_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user: &User, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET - username = ?, - password_hash = ?, - password_salt = ?, - password_version = ? - WHERE id = ?", - user.username(), - user.password_hash(), - user.password_salt(), - user.password_version(), - user.id - ) - .execute(conn) - .await -} - -/// Update the username of a user with the given ID -pub async fn update_username<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, - username: &str, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET username = ? WHERE id = ?", - username, - user_id - ) - .execute(conn) - .await -} - -/// Update the password of a user with the given ID -pub async fn update_password<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, - password: &PasswordHash, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET - password_hash = ?, - password_salt = ?, - password_version = ? - WHERE id = ?", - password.hash(), - password.salt(), - password.version(), - user_id - ) - .execute(conn) - .await -} +use exun::RawUnexpected;
+use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql};
+use uuid::Uuid;
+
+use crate::{models::user::User, services::crypto::PasswordHash};
+
+struct UserRow {
+ id: Uuid,
+ username: String,
+ password_hash: Vec<u8>,
+ password_salt: Vec<u8>,
+ password_version: u32,
+}
+
+impl TryFrom<UserRow> for User {
+ type Error = RawUnexpected;
+
+ fn try_from(row: UserRow) -> Result<Self, Self::Error> {
+ let password = PasswordHash::from_fields(
+ &row.password_hash,
+ &row.password_salt,
+ row.password_version as u8,
+ );
+ let user = User {
+ id: row.id,
+ username: row.username.into_boxed_str(),
+ password,
+ };
+ Ok(user)
+ }
+}
+
+/// Check if a user with a given user ID exists
+pub async fn user_id_exists<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ id: Uuid,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#,
+ id
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Check if a given username is taken
+pub async fn username_is_used<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<bool, RawUnexpected> {
+ let exists = query_scalar!(
+ r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#,
+ username
+ )
+ .fetch_one(conn)
+ .await?;
+
+ Ok(exists)
+}
+
+/// Get a user from their ID
+pub async fn get_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE id = ?",
+ user_id
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Get a user from their username
+pub async fn get_user_by_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Option<User>, RawUnexpected> {
+ let record = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users WHERE username = ?",
+ username
+ )
+ .fetch_optional(conn)
+ .await?;
+
+ let Some(record) = record else { return Ok(None) };
+
+ Ok(Some(record.try_into()?))
+}
+
+/// Search the list of users for a given username
+pub async fn search_users<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0",
+ username,
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Search the list of users, only returning a certain range of results
+pub async fn search_users_limit<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ username: &str,
+ offset: u32,
+ limit: u32,
+) -> Result<Box<[User]>, RawUnexpected> {
+ let records = query_as!(
+ UserRow,
+ r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version
+ FROM users
+ WHERE LOCATE(?, username) != 0
+ LIMIT ?
+ OFFSET ?",
+ username,
+ offset,
+ limit
+ )
+ .fetch_all(conn)
+ .await?;
+
+ Ok(records
+ .into_iter()
+ .map(|u| u.try_into())
+ .collect::<Result<Box<[User]>, RawUnexpected>>()?)
+}
+
+/// Get the username of a user with a certain ID
+pub async fn get_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+) -> Result<Option<Box<str>>, RawUnexpected> {
+ let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id)
+ .fetch_optional(conn)
+ .await?
+ .map(String::into_boxed_str);
+
+ Ok(username)
+}
+
+/// Create a new user
+pub async fn create_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"INSERT INTO users (id, username, password_hash, password_salt, password_version)
+ VALUES ( ?, ?, ?, ?, ?)",
+ user.id,
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version()
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update a user
+pub async fn update_user<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user: &User,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ username = ?,
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ user.username(),
+ user.password_hash(),
+ user.password_salt(),
+ user.password_version(),
+ user.id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the username of a user with the given ID
+pub async fn update_username<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ username: &str,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET username = ? WHERE id = ?",
+ username,
+ user_id
+ )
+ .execute(conn)
+ .await
+}
+
+/// Update the password of a user with the given ID
+pub async fn update_password<'c>(
+ conn: impl Executor<'c, Database = MySql>,
+ user_id: Uuid,
+ password: &PasswordHash,
+) -> Result<MySqlQueryResult, sqlx::Error> {
+ query!(
+ r"UPDATE users SET
+ password_hash = ?,
+ password_salt = ?,
+ password_version = ?
+ WHERE id = ?",
+ password.hash(),
+ password.salt(),
+ password.version(),
+ user_id
+ )
+ .execute(conn)
+ .await
+}
diff --git a/src/services/id.rs b/src/services/id.rs index 0c665ed..e1227e4 100644 --- a/src/services/id.rs +++ b/src/services/id.rs @@ -1,27 +1,27 @@ -use std::future::Future; - -use exun::RawUnexpected; -use sqlx::{Executor, MySql}; -use uuid::Uuid; - -/// Create a unique id, handling duplicate ID's. -/// -/// The given `unique_check` parameter returns `true` if the ID is used and -/// `false` otherwise. -pub async fn new_id< - 'c, - E: Executor<'c, Database = MySql> + Clone, - F: Future<Output = Result<bool, RawUnexpected>>, ->( - conn: E, - unique_check: impl Fn(E, Uuid) -> F, -) -> Result<Uuid, RawUnexpected> { - let uuid = loop { - let uuid = Uuid::new_v4(); - if !unique_check(conn.clone(), uuid).await? { - break uuid; - } - }; - - Ok(uuid) -} +use std::future::Future;
+
+use exun::RawUnexpected;
+use sqlx::{Executor, MySql};
+use uuid::Uuid;
+
+/// Create a unique id, handling duplicate ID's.
+///
+/// The given `unique_check` parameter returns `true` if the ID is used and
+/// `false` otherwise.
+pub async fn new_id<
+ 'c,
+ E: Executor<'c, Database = MySql> + Clone,
+ F: Future<Output = Result<bool, RawUnexpected>>,
+>(
+ conn: E,
+ unique_check: impl Fn(E, Uuid) -> F,
+) -> Result<Uuid, RawUnexpected> {
+ let uuid = loop {
+ let uuid = Uuid::new_v4();
+ if !unique_check(conn.clone(), uuid).await? {
+ break uuid;
+ }
+ };
+
+ Ok(uuid)
+}
diff --git a/src/services/jwt.rs b/src/services/jwt.rs index 16f5fa6..863eb83 100644 --- a/src/services/jwt.rs +++ b/src/services/jwt.rs @@ -1,291 +1,291 @@ -use chrono::{serde::ts_milliseconds, serde::ts_milliseconds_option, DateTime, Duration, Utc}; -use exun::{Expect, RawUnexpected, ResultErrorExt}; -use jwt::{SignWithKey, VerifyWithKey}; -use raise::yeet; -use serde::{Deserialize, Serialize}; -use sqlx::{Executor, MySql, MySqlPool}; -use thiserror::Error; -use url::Url; -use uuid::Uuid; - -use super::{db, id::new_id, secrets}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] -pub enum TokenType { - Authorization, - Access, - Refresh, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Claims { - iss: Url, - sub: Uuid, - aud: Box<[String]>, - #[serde(with = "ts_milliseconds")] - exp: DateTime<Utc>, - #[serde(with = "ts_milliseconds_option")] - nbf: Option<DateTime<Utc>>, - #[serde(with = "ts_milliseconds")] - iat: DateTime<Utc>, - jti: Uuid, - scope: Box<str>, - client_id: Uuid, - token_type: TokenType, - auth_code_id: Option<Uuid>, - redirect_uri: Option<Url>, -} - -#[derive(Debug, Clone, Copy, sqlx::Type)] -#[sqlx(rename_all = "kebab-case")] -pub enum RevokedRefreshTokenReason { - ReusedAuthorizationCode, - NewRefreshToken, -} - -impl Claims { - pub async fn auth_code<'c>( - db: &MySqlPool, - self_id: Url, - client_id: Uuid, - sub: Uuid, - scopes: &str, - redirect_uri: &Url, - ) -> Result<Self, RawUnexpected> { - let five_minutes = Duration::minutes(5); - - let id = new_id(db, db::auth_code_exists).await?; - let iat = Utc::now(); - let exp = iat + five_minutes; - - db::create_auth_code(db, id, exp).await?; - - let aud = [self_id.to_string(), client_id.to_string()].into(); - - Ok(Self { - iss: self_id, - sub, - aud, - exp, - nbf: None, - iat, - jti: id, - scope: scopes.into(), - client_id, - auth_code_id: Some(id), - token_type: TokenType::Authorization, - redirect_uri: Some(redirect_uri.clone()), - }) - } - - pub async fn access_token<'c>( - db: &MySqlPool, - auth_code_id: Option<Uuid>, - self_id: Url, - client_id: Uuid, - sub: Uuid, - duration: Duration, - scopes: &str, - ) -> Result<Self, RawUnexpected> { - let id = new_id(db, db::access_token_exists).await?; - let iat = Utc::now(); - let exp = iat + duration; - - db::create_access_token(db, id, auth_code_id, exp) - .await - .unexpect()?; - - let aud = [self_id.to_string(), client_id.to_string()].into(); - - Ok(Self { - iss: self_id, - sub, - aud, - exp, - nbf: None, - iat, - jti: id, - scope: scopes.into(), - client_id, - auth_code_id, - token_type: TokenType::Access, - redirect_uri: None, - }) - } - - pub async fn refresh_token( - db: &MySqlPool, - other_token: &Claims, - ) -> Result<Self, RawUnexpected> { - let one_day = Duration::days(1); - - let id = new_id(db, db::refresh_token_exists).await?; - let iat = Utc::now(); - let exp = other_token.exp + one_day; - - db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?; - - let mut claims = other_token.clone(); - claims.exp = exp; - claims.iat = iat; - claims.jti = id; - claims.token_type = TokenType::Refresh; - - Ok(claims) - } - - pub async fn refreshed_access_token( - db: &MySqlPool, - refresh_token: &Claims, - exp_time: Duration, - ) -> Result<Self, RawUnexpected> { - let id = new_id(db, db::access_token_exists).await?; - let iat = Utc::now(); - let exp = iat + exp_time; - - db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?; - - let mut claims = refresh_token.clone(); - claims.exp = exp; - claims.iat = iat; - claims.jti = id; - claims.token_type = TokenType::Access; - - Ok(claims) - } - - pub fn id(&self) -> Uuid { - self.jti - } - - pub fn subject(&self) -> Uuid { - self.sub - } - - pub fn expires_in(&self) -> i64 { - (self.exp - Utc::now()).num_seconds() - } - - pub fn scopes(&self) -> &str { - &self.scope - } - - pub fn to_jwt(&self) -> Result<Box<str>, RawUnexpected> { - let key = secrets::signing_key()?; - let jwt = self.sign_with_key(&key)?.into_boxed_str(); - Ok(jwt) - } -} - -#[derive(Debug, Error)] -pub enum VerifyJwtError { - #[error("{0}")] - ParseJwtError(#[from] jwt::Error), - #[error("The issuer for this token is incorrect")] - IncorrectIssuer, - #[error("This bearer token was intended for a different client")] - WrongClient, - #[error("The given audience parameter does not contain this issuer")] - BadAudience, - #[error("The redirect URI doesn't match what's in the token")] - IncorrectRedirectUri, - #[error("The token is expired")] - ExpiredToken, - #[error("The token cannot be used yet")] - NotYet, - #[error("The bearer token has been revoked")] - JwtRevoked, -} - -fn verify_jwt( - token: &str, - self_id: &Url, - client_id: Option<Uuid>, -) -> Result<Claims, Expect<VerifyJwtError>> { - let key = secrets::signing_key()?; - let claims: Claims = token - .verify_with_key(&key) - .map_err(|e| VerifyJwtError::from(e))?; - - if &claims.iss != self_id { - yeet!(VerifyJwtError::IncorrectIssuer.into()) - } - - if let Some(client_id) = client_id { - if claims.client_id != client_id { - yeet!(VerifyJwtError::WrongClient.into()) - } - } - - if !claims.aud.contains(&self_id.to_string()) { - yeet!(VerifyJwtError::BadAudience.into()) - } - - let now = Utc::now(); - - if now > claims.exp { - yeet!(VerifyJwtError::ExpiredToken.into()) - } - - if let Some(nbf) = claims.nbf { - if now < nbf { - yeet!(VerifyJwtError::NotYet.into()) - } - } - - Ok(claims) -} - -pub async fn verify_auth_code<'c>( - db: &MySqlPool, - token: &str, - self_id: &Url, - client_id: Uuid, - redirect_uri: Url, -) -> Result<Claims, Expect<VerifyJwtError>> { - let claims = verify_jwt(token, self_id, Some(client_id))?; - - if let Some(claimed_uri) = &claims.redirect_uri { - if claimed_uri.clone() != redirect_uri { - yeet!(VerifyJwtError::IncorrectRedirectUri.into()); - } - } - - if db::delete_auth_code(db, claims.jti).await? { - db::delete_access_tokens_with_auth_code(db, claims.jti).await?; - db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?; - yeet!(VerifyJwtError::JwtRevoked.into()); - } - - Ok(claims) -} - -pub async fn verify_access_token<'c>( - db: impl Executor<'c, Database = MySql>, - token: &str, - self_id: &Url, - client_id: Uuid, -) -> Result<Claims, Expect<VerifyJwtError>> { - let claims = verify_jwt(token, self_id, Some(client_id))?; - - if !db::access_token_exists(db, claims.jti).await? { - yeet!(VerifyJwtError::JwtRevoked.into()) - } - - Ok(claims) -} - -pub async fn verify_refresh_token<'c>( - db: impl Executor<'c, Database = MySql>, - token: &str, - self_id: &Url, - client_id: Option<Uuid>, -) -> Result<Claims, Expect<VerifyJwtError>> { - let claims = verify_jwt(token, self_id, client_id)?; - - if db::refresh_token_revoked(db, claims.jti).await? { - yeet!(VerifyJwtError::JwtRevoked.into()) - } - - Ok(claims) -} +use chrono::{serde::ts_milliseconds, serde::ts_milliseconds_option, DateTime, Duration, Utc};
+use exun::{Expect, RawUnexpected, ResultErrorExt};
+use jwt::{SignWithKey, VerifyWithKey};
+use raise::yeet;
+use serde::{Deserialize, Serialize};
+use sqlx::{Executor, MySql, MySqlPool};
+use thiserror::Error;
+use url::Url;
+use uuid::Uuid;
+
+use super::{db, id::new_id, secrets};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub enum TokenType {
+ Authorization,
+ Access,
+ Refresh,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Claims {
+ iss: Url,
+ sub: Uuid,
+ aud: Box<[String]>,
+ #[serde(with = "ts_milliseconds")]
+ exp: DateTime<Utc>,
+ #[serde(with = "ts_milliseconds_option")]
+ nbf: Option<DateTime<Utc>>,
+ #[serde(with = "ts_milliseconds")]
+ iat: DateTime<Utc>,
+ jti: Uuid,
+ scope: Box<str>,
+ client_id: Uuid,
+ token_type: TokenType,
+ auth_code_id: Option<Uuid>,
+ redirect_uri: Option<Url>,
+}
+
+#[derive(Debug, Clone, Copy, sqlx::Type)]
+#[sqlx(rename_all = "kebab-case")]
+pub enum RevokedRefreshTokenReason {
+ ReusedAuthorizationCode,
+ NewRefreshToken,
+}
+
+impl Claims {
+ pub async fn auth_code<'c>(
+ db: &MySqlPool,
+ self_id: Url,
+ client_id: Uuid,
+ sub: Uuid,
+ scopes: &str,
+ redirect_uri: &Url,
+ ) -> Result<Self, RawUnexpected> {
+ let five_minutes = Duration::minutes(5);
+
+ let id = new_id(db, db::auth_code_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + five_minutes;
+
+ db::create_auth_code(db, id, exp).await?;
+
+ let aud = [self_id.to_string(), client_id.to_string()].into();
+
+ Ok(Self {
+ iss: self_id,
+ sub,
+ aud,
+ exp,
+ nbf: None,
+ iat,
+ jti: id,
+ scope: scopes.into(),
+ client_id,
+ auth_code_id: Some(id),
+ token_type: TokenType::Authorization,
+ redirect_uri: Some(redirect_uri.clone()),
+ })
+ }
+
+ pub async fn access_token<'c>(
+ db: &MySqlPool,
+ auth_code_id: Option<Uuid>,
+ self_id: Url,
+ client_id: Uuid,
+ sub: Uuid,
+ duration: Duration,
+ scopes: &str,
+ ) -> Result<Self, RawUnexpected> {
+ let id = new_id(db, db::access_token_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + duration;
+
+ db::create_access_token(db, id, auth_code_id, exp)
+ .await
+ .unexpect()?;
+
+ let aud = [self_id.to_string(), client_id.to_string()].into();
+
+ Ok(Self {
+ iss: self_id,
+ sub,
+ aud,
+ exp,
+ nbf: None,
+ iat,
+ jti: id,
+ scope: scopes.into(),
+ client_id,
+ auth_code_id,
+ token_type: TokenType::Access,
+ redirect_uri: None,
+ })
+ }
+
+ pub async fn refresh_token(
+ db: &MySqlPool,
+ other_token: &Claims,
+ ) -> Result<Self, RawUnexpected> {
+ let one_day = Duration::days(1);
+
+ let id = new_id(db, db::refresh_token_exists).await?;
+ let iat = Utc::now();
+ let exp = other_token.exp + one_day;
+
+ db::create_refresh_token(db, id, other_token.auth_code_id, exp).await?;
+
+ let mut claims = other_token.clone();
+ claims.exp = exp;
+ claims.iat = iat;
+ claims.jti = id;
+ claims.token_type = TokenType::Refresh;
+
+ Ok(claims)
+ }
+
+ pub async fn refreshed_access_token(
+ db: &MySqlPool,
+ refresh_token: &Claims,
+ exp_time: Duration,
+ ) -> Result<Self, RawUnexpected> {
+ let id = new_id(db, db::access_token_exists).await?;
+ let iat = Utc::now();
+ let exp = iat + exp_time;
+
+ db::create_access_token(db, id, refresh_token.auth_code_id, exp).await?;
+
+ let mut claims = refresh_token.clone();
+ claims.exp = exp;
+ claims.iat = iat;
+ claims.jti = id;
+ claims.token_type = TokenType::Access;
+
+ Ok(claims)
+ }
+
+ pub fn id(&self) -> Uuid {
+ self.jti
+ }
+
+ pub fn subject(&self) -> Uuid {
+ self.sub
+ }
+
+ pub fn expires_in(&self) -> i64 {
+ (self.exp - Utc::now()).num_seconds()
+ }
+
+ pub fn scopes(&self) -> &str {
+ &self.scope
+ }
+
+ pub fn to_jwt(&self) -> Result<Box<str>, RawUnexpected> {
+ let key = secrets::signing_key()?;
+ let jwt = self.sign_with_key(&key)?.into_boxed_str();
+ Ok(jwt)
+ }
+}
+
+#[derive(Debug, Error)]
+pub enum VerifyJwtError {
+ #[error("{0}")]
+ ParseJwtError(#[from] jwt::Error),
+ #[error("The issuer for this token is incorrect")]
+ IncorrectIssuer,
+ #[error("This bearer token was intended for a different client")]
+ WrongClient,
+ #[error("The given audience parameter does not contain this issuer")]
+ BadAudience,
+ #[error("The redirect URI doesn't match what's in the token")]
+ IncorrectRedirectUri,
+ #[error("The token is expired")]
+ ExpiredToken,
+ #[error("The token cannot be used yet")]
+ NotYet,
+ #[error("The bearer token has been revoked")]
+ JwtRevoked,
+}
+
+fn verify_jwt(
+ token: &str,
+ self_id: &Url,
+ client_id: Option<Uuid>,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let key = secrets::signing_key()?;
+ let claims: Claims = token
+ .verify_with_key(&key)
+ .map_err(|e| VerifyJwtError::from(e))?;
+
+ if &claims.iss != self_id {
+ yeet!(VerifyJwtError::IncorrectIssuer.into())
+ }
+
+ if let Some(client_id) = client_id {
+ if claims.client_id != client_id {
+ yeet!(VerifyJwtError::WrongClient.into())
+ }
+ }
+
+ if !claims.aud.contains(&self_id.to_string()) {
+ yeet!(VerifyJwtError::BadAudience.into())
+ }
+
+ let now = Utc::now();
+
+ if now > claims.exp {
+ yeet!(VerifyJwtError::ExpiredToken.into())
+ }
+
+ if let Some(nbf) = claims.nbf {
+ if now < nbf {
+ yeet!(VerifyJwtError::NotYet.into())
+ }
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_auth_code<'c>(
+ db: &MySqlPool,
+ token: &str,
+ self_id: &Url,
+ client_id: Uuid,
+ redirect_uri: Url,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, Some(client_id))?;
+
+ if let Some(claimed_uri) = &claims.redirect_uri {
+ if claimed_uri.clone() != redirect_uri {
+ yeet!(VerifyJwtError::IncorrectRedirectUri.into());
+ }
+ }
+
+ if db::delete_auth_code(db, claims.jti).await? {
+ db::delete_access_tokens_with_auth_code(db, claims.jti).await?;
+ db::revoke_refresh_tokens_with_auth_code(db, claims.jti).await?;
+ yeet!(VerifyJwtError::JwtRevoked.into());
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_access_token<'c>(
+ db: impl Executor<'c, Database = MySql>,
+ token: &str,
+ self_id: &Url,
+ client_id: Uuid,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, Some(client_id))?;
+
+ if !db::access_token_exists(db, claims.jti).await? {
+ yeet!(VerifyJwtError::JwtRevoked.into())
+ }
+
+ Ok(claims)
+}
+
+pub async fn verify_refresh_token<'c>(
+ db: impl Executor<'c, Database = MySql>,
+ token: &str,
+ self_id: &Url,
+ client_id: Option<Uuid>,
+) -> Result<Claims, Expect<VerifyJwtError>> {
+ let claims = verify_jwt(token, self_id, client_id)?;
+
+ if db::refresh_token_revoked(db, claims.jti).await? {
+ yeet!(VerifyJwtError::JwtRevoked.into())
+ }
+
+ Ok(claims)
+}
diff --git a/src/services/mod.rs b/src/services/mod.rs index de08b58..4c69367 100644 --- a/src/services/mod.rs +++ b/src/services/mod.rs @@ -1,7 +1,7 @@ -pub mod authorization; -pub mod config; -pub mod crypto; -pub mod db; -pub mod id; -pub mod jwt; -pub mod secrets; +pub mod authorization;
+pub mod config;
+pub mod crypto;
+pub mod db;
+pub mod id;
+pub mod jwt;
+pub mod secrets;
diff --git a/src/services/secrets.rs b/src/services/secrets.rs index 241b2c5..e1d4992 100644 --- a/src/services/secrets.rs +++ b/src/services/secrets.rs @@ -1,24 +1,24 @@ -use std::env; - -use exun::*; -use hmac::{Hmac, Mac}; -use sha2::Sha256; - -/// This is a secret salt, needed for creating passwords. It's used as an extra -/// layer of security, on top of the salt that's already used. -pub fn pepper() -> Result<Box<[u8]>, RawUnexpected> { - let pepper = env::var("SECRET_SALT")?; - let pepper = hex::decode(pepper)?; - Ok(pepper.into_boxed_slice()) -} - -/// The URL to the MySQL database -pub fn database_url() -> Result<String, RawUnexpected> { - env::var("DATABASE_URL").unexpect() -} - -pub fn signing_key() -> Result<Hmac<Sha256>, RawUnexpected> { - let key = env::var("PRIVATE_KEY")?; - let key = Hmac::<Sha256>::new_from_slice(key.as_bytes())?; - Ok(key) -} +use std::env;
+
+use exun::*;
+use hmac::{Hmac, Mac};
+use sha2::Sha256;
+
+/// This is a secret salt, needed for creating passwords. It's used as an extra
+/// layer of security, on top of the salt that's already used.
+pub fn pepper() -> Result<Box<[u8]>, RawUnexpected> {
+ let pepper = env::var("SECRET_SALT")?;
+ let pepper = hex::decode(pepper)?;
+ Ok(pepper.into_boxed_slice())
+}
+
+/// The URL to the MySQL database
+pub fn database_url() -> Result<String, RawUnexpected> {
+ env::var("DATABASE_URL").unexpect()
+}
+
+pub fn signing_key() -> Result<Hmac<Sha256>, RawUnexpected> {
+ let key = env::var("PRIVATE_KEY")?;
+ let key = Hmac::<Sha256>::new_from_slice(key.as_bytes())?;
+ Ok(key)
+}
|
