diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/api/clients.rs | 311 | ||||
| -rw-r--r-- | src/api/mod.rs | 3 | ||||
| -rw-r--r-- | src/api/oauth.rs | 24 | ||||
| -rw-r--r-- | src/api/users.rs | 31 | ||||
| -rw-r--r-- | src/main.rs | 1 | ||||
| -rw-r--r-- | src/models/client.rs | 37 | ||||
| -rw-r--r-- | src/services/db.rs | 242 | ||||
| -rw-r--r-- | src/services/db/client.rs | 236 | ||||
| -rw-r--r-- | src/services/db/user.rs | 236 | ||||
| -rw-r--r-- | src/services/id.rs | 20 |
10 files changed, 878 insertions, 263 deletions
diff --git a/src/api/clients.rs b/src/api/clients.rs new file mode 100644 index 0000000..7e8ca35 --- /dev/null +++ b/src/api/clients.rs @@ -0,0 +1,311 @@ +use actix_web::http::{header, StatusCode}; +use actix_web::{get, post, put, web, HttpResponse, ResponseError, Scope}; +use raise::yeet; +use serde::Deserialize; +use sqlx::MySqlPool; +use thiserror::Error; +use url::Url; +use uuid::Uuid; + +use crate::models::client::{Client, ClientType, NoSecretError}; +use crate::services::crypto::PasswordHash; +use crate::services::{db, id}; + +#[derive(Debug, Clone, Copy, Error)] +#[error("No client with the given client ID was found")] +struct ClientNotFound { + id: Uuid, +} + +impl ResponseError for ClientNotFound { + fn status_code(&self) -> StatusCode { + StatusCode::NOT_FOUND + } +} + +impl ClientNotFound { + fn new(id: Uuid) -> Self { + Self { id } + } +} + +#[get("/{client_id}")] +async fn get_client( + client_id: web::Path<Uuid>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, ClientNotFound> { + let db = db.as_ref(); + let id = *client_id; + + let Some(client) = db::get_client_response(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + let redirect_uris_link = format!("</clients/{client_id}/redirect-uris>; rel=\"redirect-uris\""); + let response = HttpResponse::Ok() + .append_header((header::LINK, redirect_uris_link)) + .json(client); + Ok(response) +} + +#[get("/{client_id}/alias")] +async fn get_client_alias( + client_id: web::Path<Uuid>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, ClientNotFound> { + let db = db.as_ref(); + let id = *client_id; + + let Some(alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + Ok(HttpResponse::Ok().json(alias)) +} + +#[get("/{client_id}/type")] +async fn get_client_type( + client_id: web::Path<Uuid>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, ClientNotFound> { + let db = db.as_ref(); + let id = *client_id; + + let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id)) + }; + + Ok(HttpResponse::Ok().json(client_type)) +} + +#[get("/{client_id}/redirect-uris")] +async fn get_client_redirect_uris( + client_id: web::Path<Uuid>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, ClientNotFound> { + let db = db.as_ref(); + let id = *client_id; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id)) + }; + + let redirect_uris = db::get_client_redirect_uris(db, id).await.unwrap(); + + Ok(HttpResponse::Ok().json(redirect_uris)) +} + +#[derive(Debug, Clone, Deserialize)] +struct ClientRequest { + alias: Box<str>, + ty: ClientType, + redirect_uris: Box<[Url]>, + secret: Option<Box<str>>, +} + +#[derive(Debug, Clone, Error)] +#[error("The given client alias is already taken")] +struct AliasTakenError { + alias: Box<str>, +} + +impl ResponseError for AliasTakenError { + fn status_code(&self) -> StatusCode { + StatusCode::CONFLICT + } +} + +impl AliasTakenError { + fn new(alias: &str) -> Self { + Self { + alias: Box::from(alias), + } + } +} + +#[post("")] +async fn create_client( + body: web::Json<ClientRequest>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let alias = &body.alias; + + if db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + let id = id::new_id(db, db::client_id_exists).await.unwrap(); + let client = Client::new( + id, + &alias, + body.ty, + body.secret.as_deref(), + &body.redirect_uris, + ) + .map_err(|e| e.unwrap())?; + + let transaction = db.begin().await.unwrap(); + db::create_client(transaction, &client).await.unwrap(); + + let response = HttpResponse::Created() + .insert_header((header::LOCATION, format!("clients/{id}"))) + .finish(); + Ok(response) +} + +#[derive(Debug, Clone, Error)] +enum UpdateClientError { + #[error(transparent)] + NotFound(#[from] ClientNotFound), + #[error(transparent)] + NoSecret(#[from] NoSecretError), + #[error(transparent)] + AliasTaken(#[from] AliasTakenError), +} + +impl ResponseError for UpdateClientError { + fn status_code(&self) -> StatusCode { + match self { + Self::NotFound(e) => e.status_code(), + Self::NoSecret(e) => e.status_code(), + Self::AliasTaken(e) => e.status_code(), + } + } +} + +#[put("/{id}")] +async fn update_client( + id: web::Path<Uuid>, + body: web::Json<ClientRequest>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let id = *id; + let alias = &body.alias; + + let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + if old_alias != alias.clone() && db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + let client = Client::new( + id, + &alias, + body.ty, + body.secret.as_deref(), + &body.redirect_uris, + ) + .map_err(|e| e.unwrap())?; + + let transaction = db.begin().await.unwrap(); + db::update_client(transaction, &client).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{id}/alias")] +async fn update_client_alias( + id: web::Path<Uuid>, + body: web::Json<Box<str>>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let id = *id; + let alias = body.0; + + let Some(old_alias) = db::get_client_alias(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + if old_alias == alias { + return Ok(HttpResponse::NoContent().finish()); + } + if db::client_alias_exists(db, &alias).await.unwrap() { + yeet!(AliasTakenError::new(&alias).into()); + } + + db::update_client_alias(db, id, &alias).await.unwrap(); + + let response = HttpResponse::NoContent().finish(); + Ok(response) +} + +#[put("/{id}/type")] +async fn update_client_type( + id: web::Path<Uuid>, + body: web::Json<ClientType>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let id = *id; + let ty = body.0; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + db::update_client_type(db, id, ty).await.unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("/{id}/redirect-uris")] +async fn update_client_redirect_uris( + id: web::Path<Uuid>, + body: web::Json<Box<[Url]>>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let id = *id; + + if !db::client_id_exists(db, id).await.unwrap() { + yeet!(ClientNotFound::new(id).into()); + } + + let transaction = db.begin().await.unwrap(); + db::update_client_redirect_uris(transaction, id, &body.0) + .await + .unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +#[put("{id}/secret")] +async fn update_client_secret( + id: web::Path<Uuid>, + body: web::Json<Option<Box<str>>>, + db: web::Data<MySqlPool>, +) -> Result<HttpResponse, UpdateClientError> { + let db = db.get_ref(); + let id = *id; + + let Some(client_type) = db::get_client_type(db, id).await.unwrap() else { + yeet!(ClientNotFound::new(id).into()) + }; + + if client_type == ClientType::Confidential && body.is_none() { + yeet!(NoSecretError::new().into()) + } + + let secret = body.0.map(|s| PasswordHash::new(&s).unwrap()); + db::update_client_secret(db, id, secret).await.unwrap(); + + Ok(HttpResponse::NoContent().finish()) +} + +pub fn service() -> Scope { + web::scope("/clients") + .service(get_client) + .service(get_client_alias) + .service(get_client_type) + .service(get_client_redirect_uris) + .service(create_client) + .service(update_client) + .service(update_client_alias) + .service(update_client_type) + .service(update_client_redirect_uris) + .service(update_client_secret) +} diff --git a/src/api/mod.rs b/src/api/mod.rs index 7627a60..3d74be8 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,7 +1,10 @@ mod liveops; mod ops; mod users; +mod oauth; +mod clients; pub use liveops::service as liveops; pub use ops::service as ops; pub use users::service as users; +pub use clients::service as clients; diff --git a/src/api/oauth.rs b/src/api/oauth.rs new file mode 100644 index 0000000..9e0e5c6 --- /dev/null +++ b/src/api/oauth.rs @@ -0,0 +1,24 @@ +use std::collections::HashMap; + +use actix_web::{web, HttpResponse}; +use serde::Deserialize; +use url::Url; +use uuid::Uuid; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)] +#[serde(rename_all = "snake_case")] +enum ResponseType { + Code, + Token, +} + +#[derive(Debug, Clone, Deserialize)] +struct AuthorizationParameters { + response_type: ResponseType, + client_id: Uuid, + redirect_uri: Url, + state: Box<str>, + + #[serde(flatten)] + additional_parameters: HashMap<Box<str>, Box<str>>, +} diff --git a/src/api/users.rs b/src/api/users.rs index 2b67663..2cd70c0 100644 --- a/src/api/users.rs +++ b/src/api/users.rs @@ -12,6 +12,7 @@ use crate::services::{db, id}; /// Just a username. No password hash, because that'd be tempting fate. #[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] struct UserResponse { id: Uuid, username: Box<str>, @@ -27,6 +28,7 @@ impl From<User> for UserResponse { } #[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] struct SearchUsers { username: Option<Box<str>>, limit: Option<u32>, @@ -82,14 +84,14 @@ async fn get_user( ) -> Result<HttpResponse, UserNotFoundError> { let conn = conn.get_ref(); - let user_id = user_id.to_owned(); - let user = db::get_user(conn, user_id).await.unwrap(); + let id = user_id.to_owned(); + let username = db::get_username(conn, id).await.unwrap(); - let Some(user) = user else { - yeet!(UserNotFoundError {user_id}); + let Some(username) = username else { + yeet!(UserNotFoundError { user_id: id }); }; - let response: UserResponse = user.into(); + let response = UserResponse { id, username }; let response = HttpResponse::Ok().json(response); Ok(response) } @@ -114,6 +116,7 @@ async fn get_username( /// A request to create or update user information #[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] struct UserRequest { username: Box<str>, password: Box<str>, @@ -138,7 +141,7 @@ async fn create_user( ) -> Result<HttpResponse, UsernameTakenError> { let conn = conn.get_ref(); - let user_id = id::new_user_id(conn).await.unwrap(); + let user_id = id::new_id(conn, db::user_id_exists).await.unwrap(); let username = body.username.clone(); let password = PasswordHash::new(&body.password).unwrap(); @@ -152,7 +155,7 @@ async fn create_user( password, }; - db::new_user(conn, &user).await.unwrap(); + db::create_user(conn, &user).await.unwrap(); let response = HttpResponse::Created() .insert_header((header::LOCATION, format!("users/{user_id}"))) @@ -171,8 +174,8 @@ enum UpdateUserError { impl ResponseError for UpdateUserError { fn status_code(&self) -> StatusCode { match self { - Self::UsernameTaken(..) => StatusCode::CONFLICT, - Self::NotFound(..) => StatusCode::NOT_FOUND, + Self::UsernameTaken(e) => e.status_code(), + Self::NotFound(e) => e.status_code(), } } } @@ -206,10 +209,7 @@ async fn update_user( db::update_user(conn, &user).await.unwrap(); - let response = HttpResponse::NoContent() - .insert_header((header::LOCATION, format!("users/{user_id}"))) - .finish(); - + let response = HttpResponse::NoContent().finish(); Ok(response) } @@ -235,10 +235,7 @@ async fn update_username( db::update_username(conn, user_id, &body).await.unwrap(); - let response = HttpResponse::NoContent() - .insert_header((header::LOCATION, format!("users/{user_id}/username"))) - .finish(); - + let response = HttpResponse::NoContent().finish(); Ok(response) } diff --git a/src/main.rs b/src/main.rs index 7b25dd1..aca5977 100644 --- a/src/main.rs +++ b/src/main.rs @@ -56,6 +56,7 @@ async fn main() -> Result<(), RawUnexpected> { // api services .service(api::liveops()) .service(api::users()) + .service(api::clients()) .service(api::ops()) }) .shutdown_timeout(1) diff --git a/src/models/client.rs b/src/models/client.rs index a7df936..44079de 100644 --- a/src/models/client.rs +++ b/src/models/client.rs @@ -1,7 +1,10 @@ use std::{hash::Hash, marker::PhantomData}; +use actix_web::{http::StatusCode, ResponseError}; use exun::{Expect, RawUnexpected}; use raise::yeet; +use serde::{Deserialize, Serialize}; +use sqlx::FromRow; use thiserror::Error; use url::Url; use uuid::Uuid; @@ -10,8 +13,9 @@ use crate::services::crypto::PasswordHash; /// There are two types of clients, based on their ability to maintain the /// security of their client credentials. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, sqlx::Type)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, sqlx::Type)] #[sqlx(rename_all = "lowercase")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum ClientType { /// A client that is capable of maintaining the confidentiality of their /// credentials, or capable of secure client authentication using other @@ -26,12 +30,21 @@ pub enum ClientType { #[derive(Debug, Clone)] pub struct Client { - ty: ClientType, id: Uuid, + ty: ClientType, + alias: Box<str>, secret: Option<PasswordHash>, redirect_uris: Box<[Url]>, } +#[derive(Debug, Clone, Serialize, FromRow)] +#[serde(rename_all = "camelCase")] +pub struct ClientResponse { + pub id: Uuid, + pub alias: String, + pub client_type: ClientType, +} + impl PartialEq for Client { fn eq(&self, other: &Self) -> bool { self.id == other.id @@ -52,8 +65,14 @@ pub struct NoSecretError { _phantom: PhantomData<()>, } +impl ResponseError for NoSecretError { + fn status_code(&self) -> StatusCode { + StatusCode::BAD_REQUEST + } +} + impl NoSecretError { - fn new() -> Self { + pub(crate) fn new() -> Self { Self { _phantom: PhantomData, } @@ -61,8 +80,9 @@ impl NoSecretError { } impl Client { - pub fn new_public( + pub fn new( id: Uuid, + alias: &str, ty: ClientType, secret: Option<&str>, redirect_uris: &[Url], @@ -79,6 +99,7 @@ impl Client { Ok(Self { id, + alias: Box::from(alias), ty: ClientType::Public, secret, redirect_uris: redirect_uris.into_iter().cloned().collect(), @@ -89,10 +110,18 @@ impl Client { self.id } + pub fn alias(&self) -> &str { + &self.alias + } + pub fn client_type(&self) -> ClientType { self.ty } + pub fn redirect_uris(&self) -> &[Url] { + &self.redirect_uris + } + pub fn secret_hash(&self) -> Option<&[u8]> { self.secret.as_ref().map(|s| s.hash()) } diff --git a/src/services/db.rs b/src/services/db.rs index 79df260..9789e51 100644 --- a/src/services/db.rs +++ b/src/services/db.rs @@ -1,243 +1,13 @@ -use exun::*; -use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql, MySqlPool}; -use uuid::Uuid; +use exun::{RawUnexpected, ResultErrorExt}; +use sqlx::MySqlPool; -use crate::models::user::User; +mod client; +mod user; -use super::crypto::PasswordHash; - -struct UserRow { - id: Vec<u8>, - username: String, - password_hash: Vec<u8>, - password_salt: Vec<u8>, - password_version: u32, -} - -impl TryFrom<UserRow> for User { - type Error = RawUnexpected; - - fn try_from(row: UserRow) -> Result<Self, Self::Error> { - let password = PasswordHash::from_fields( - &row.password_hash, - &row.password_salt, - row.password_version as u8, - ); - let user = User { - id: Uuid::from_slice(&row.id)?, - username: row.username.into_boxed_str(), - password, - }; - Ok(user) - } -} +pub use client::*; +pub use user::*; /// Intialize the connection pool pub async fn initialize(db_url: &str) -> Result<MySqlPool, RawUnexpected> { MySqlPool::connect(db_url).await.unexpect() } - -/// Check if a user with a given user ID exists -pub async fn user_id_exists<'c>( - conn: impl Executor<'c, Database = MySql>, - id: Uuid, -) -> Result<bool, RawUnexpected> { - let exists = query_scalar!( - r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as "e: bool""#, - id - ) - .fetch_one(conn) - .await?; - - Ok(exists) -} - -/// Check if a given username is taken -pub async fn username_is_used<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<bool, RawUnexpected> { - let exists = query_scalar!( - r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#, - username - ) - .fetch_one(conn) - .await?; - - Ok(exists) -} - -/// Get a user from their ID -pub async fn get_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, -) -> Result<Option<User>, RawUnexpected> { - let record = query_as!( - UserRow, - r"SELECT id, username, password_hash, password_salt, password_version - FROM users WHERE id = ?", - user_id - ) - .fetch_optional(conn) - .await?; - - let Some(record) = record else { return Ok(None) }; - - Ok(Some(record.try_into()?)) -} - -/// Get a user from their username -pub async fn get_user_by_username<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<Option<User>, RawUnexpected> { - let record = query_as!( - UserRow, - r"SELECT id, username, password_hash, password_salt, password_version - FROM users WHERE username = ?", - username - ) - .fetch_optional(conn) - .await?; - - let Some(record) = record else { return Ok(None) }; - - Ok(Some(record.try_into()?)) -} - -/// Search the list of users for a given username -pub async fn search_users<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, -) -> Result<Box<[User]>, RawUnexpected> { - let records = query_as!( - UserRow, - r"SELECT id, username, password_hash, password_salt, password_version - FROM users - WHERE LOCATE(?, username) != 0", - username, - ) - .fetch_all(conn) - .await?; - - Ok(records - .into_iter() - .map(|u| u.try_into()) - .collect::<Result<Box<[User]>, RawUnexpected>>()?) -} - -/// Search the list of users, only returning a certain range of results -pub async fn search_users_limit<'c>( - conn: impl Executor<'c, Database = MySql>, - username: &str, - offset: u32, - limit: u32, -) -> Result<Box<[User]>, RawUnexpected> { - let records = query_as!( - UserRow, - r"SELECT id, username, password_hash, password_salt, password_version - FROM users - WHERE LOCATE(?, username) != 0 - LIMIT ? - OFFSET ?", - username, - offset, - limit - ) - .fetch_all(conn) - .await?; - - Ok(records - .into_iter() - .map(|u| u.try_into()) - .collect::<Result<Box<[User]>, RawUnexpected>>()?) -} - -/// Get the username of a user with a certain ID -pub async fn get_username<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, -) -> Result<Option<Box<str>>, RawUnexpected> { - let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id) - .fetch_optional(conn) - .await? - .map(String::into_boxed_str); - - Ok(username) -} - -/// Create a new user -pub async fn new_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user: &User, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"INSERT INTO users (id, username, password_hash, password_salt, password_version) - VALUES (?, ?, ?, ?, ?)", - user.id, - user.username(), - user.password_hash(), - user.password_salt(), - user.password_version() - ) - .execute(conn) - .await -} - -/// Update a user -pub async fn update_user<'c>( - conn: impl Executor<'c, Database = MySql>, - user: &User, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET - username = ?, - password_hash = ?, - password_salt = ?, - password_version = ? - WHERE id = ?", - user.username(), - user.password_hash(), - user.password_salt(), - user.password_version(), - user.id - ) - .execute(conn) - .await -} - -/// Update the username of a user with the given ID -pub async fn update_username<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, - username: &str, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET username = ? WHERE id = ?", - username, - user_id - ) - .execute(conn) - .await -} - -/// Update the password of a user with the given ID -pub async fn update_password<'c>( - conn: impl Executor<'c, Database = MySql>, - user_id: Uuid, - password: &PasswordHash, -) -> Result<MySqlQueryResult, sqlx::Error> { - query!( - r"UPDATE users SET - password_hash = ?, - password_salt = ?, - password_version = ? - WHERE id = ?", - password.hash(), - password.salt(), - password.version(), - user_id - ) - .execute(conn) - .await -} diff --git a/src/services/db/client.rs b/src/services/db/client.rs new file mode 100644 index 0000000..d1531be --- /dev/null +++ b/src/services/db/client.rs @@ -0,0 +1,236 @@ +use std::str::FromStr; + +use exun::{RawUnexpected, ResultErrorExt}; +use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql, Transaction}; +use url::Url; +use uuid::Uuid; + +use crate::{ + models::client::{Client, ClientResponse, ClientType}, + services::crypto::PasswordHash, +}; + +pub async fn client_id_exists<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<bool, RawUnexpected> { + query_scalar!( + r"SELECT EXISTS(SELECT id FROM clients WHERE id = ?) as `e: bool`", + id + ) + .fetch_one(executor) + .await + .unexpect() +} + +pub async fn client_alias_exists<'c>( + executor: impl Executor<'c, Database = MySql>, + alias: &str, +) -> Result<bool, RawUnexpected> { + query_scalar!( + "SELECT EXISTS(SELECT alias FROM clients WHERE alias = ?) as `e: bool`", + alias + ) + .fetch_one(executor) + .await + .unexpect() +} + +pub async fn get_client_response<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Option<ClientResponse>, RawUnexpected> { + let record = query_as!( + ClientResponse, + r"SELECT id as `id: Uuid`, + alias, + type as `client_type: ClientType` + FROM clients WHERE id = ?", + id + ) + .fetch_optional(executor) + .await?; + + Ok(record) +} + +pub async fn get_client_alias<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Option<Box<str>>, RawUnexpected> { + let alias = query_scalar!("SELECT alias FROM clients WHERE id = ?", id) + .fetch_optional(executor) + .await + .unexpect()?; + + Ok(alias.map(String::into_boxed_str)) +} + +pub async fn get_client_type<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Option<ClientType>, RawUnexpected> { + let ty = query_scalar!( + "SELECT type as `type: ClientType` FROM clients WHERE id = ?", + id + ) + .fetch_optional(executor) + .await + .unexpect()?; + + Ok(ty) +} + +pub async fn get_client_redirect_uris<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<Box<[Url]>, RawUnexpected> { + let uris = query_scalar!( + "SELECT redirect_uri FROM client_redirect_uris WHERE client_id = ?", + id + ) + .fetch_all(executor) + .await + .unexpect()?; + + uris.into_iter() + .map(|s| Url::from_str(&s).unexpect()) + .collect() +} + +async fn delete_client_redirect_uris<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<(), sqlx::Error> { + query!("DELETE FROM client_redirect_uris WHERE client_id = ?", id) + .execute(executor) + .await?; + Ok(()) +} + +async fn create_client_redirect_uris<'c>( + mut transaction: Transaction<'c, MySql>, + client_id: Uuid, + uris: &[Url], +) -> Result<(), sqlx::Error> { + for uri in uris { + query!( + r"INSERT INTO client_redirect_uris (client_id, redirect_uri) + VALUES ( ?, ?)", + client_id, + uri.to_string() + ) + .execute(&mut transaction) + .await?; + } + + transaction.commit().await?; + + Ok(()) +} + +pub async fn create_client<'c>( + mut transaction: Transaction<'c, MySql>, + client: &Client, +) -> Result<(), sqlx::Error> { + query!( + r"INSERT INTO clients (id, alias, type, secret_hash, secret_salt, secret_version) + VALUES ( ?, ?, ?, ?, ?, ?)", + client.id(), + client.alias(), + client.client_type(), + client.secret_hash(), + client.secret_salt(), + client.secret_version(), + ) + .execute(&mut transaction) + .await?; + + create_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?; + + Ok(()) +} + +pub async fn update_client<'c>( + mut transaction: Transaction<'c, MySql>, + client: &Client, +) -> Result<(), sqlx::Error> { + query!( + r"UPDATE clients SET + alias = ?, + type = ?, + secret_hash = ?, + secret_salt = ?, + secret_version = ? + WHERE id = ?", + client.client_type(), + client.alias(), + client.secret_hash(), + client.secret_salt(), + client.secret_version(), + client.id() + ) + .execute(&mut transaction) + .await?; + + update_client_redirect_uris(transaction, client.id(), client.redirect_uris()).await?; + + Ok(()) +} + +pub async fn update_client_alias<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, + alias: &str, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!("UPDATE clients SET alias = ? WHERE id = ?", alias, id) + .execute(executor) + .await +} + +pub async fn update_client_type<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, + ty: ClientType, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!("UPDATE clients SET type = ? WHERE id = ?", ty, id) + .execute(executor) + .await +} + +pub async fn update_client_redirect_uris<'c>( + mut transaction: Transaction<'c, MySql>, + id: Uuid, + uris: &[Url], +) -> Result<(), sqlx::Error> { + delete_client_redirect_uris(&mut transaction, id).await?; + create_client_redirect_uris(transaction, id, uris).await?; + Ok(()) +} + +pub async fn update_client_secret<'c>( + executor: impl Executor<'c, Database = MySql>, + id: Uuid, + secret: Option<PasswordHash>, +) -> Result<MySqlQueryResult, sqlx::Error> { + if let Some(secret) = secret { + query!( + "UPDATE clients SET secret_hash = ?, secret_salt = ?, secret_version = ? WHERE id = ?", + secret.hash(), + secret.salt(), + secret.version(), + id + ) + .execute(executor) + .await + } else { + query!( + r"UPDATE clients + SET secret_hash = NULL, secret_salt = NULL, secret_version = NULL + WHERE id = ?", + id + ) + .execute(executor) + .await + } +} diff --git a/src/services/db/user.rs b/src/services/db/user.rs new file mode 100644 index 0000000..09a09da --- /dev/null +++ b/src/services/db/user.rs @@ -0,0 +1,236 @@ +use exun::RawUnexpected; +use sqlx::{mysql::MySqlQueryResult, query, query_as, query_scalar, Executor, MySql}; +use uuid::Uuid; + +use crate::{models::user::User, services::crypto::PasswordHash}; + +struct UserRow { + id: Uuid, + username: String, + password_hash: Vec<u8>, + password_salt: Vec<u8>, + password_version: u32, +} + +impl TryFrom<UserRow> for User { + type Error = RawUnexpected; + + fn try_from(row: UserRow) -> Result<Self, Self::Error> { + let password = PasswordHash::from_fields( + &row.password_hash, + &row.password_salt, + row.password_version as u8, + ); + let user = User { + id: row.id, + username: row.username.into_boxed_str(), + password, + }; + Ok(user) + } +} + +/// Check if a user with a given user ID exists +pub async fn user_id_exists<'c>( + conn: impl Executor<'c, Database = MySql>, + id: Uuid, +) -> Result<bool, RawUnexpected> { + let exists = query_scalar!( + r#"SELECT EXISTS(SELECT id FROM users WHERE id = ?) as `e: bool`"#, + id + ) + .fetch_one(conn) + .await?; + + Ok(exists) +} + +/// Check if a given username is taken +pub async fn username_is_used<'c>( + conn: impl Executor<'c, Database = MySql>, + username: &str, +) -> Result<bool, RawUnexpected> { + let exists = query_scalar!( + r#"SELECT EXISTS(SELECT id FROM users WHERE username = ?) as "e: bool""#, + username + ) + .fetch_one(conn) + .await?; + + Ok(exists) +} + +/// Get a user from their ID +pub async fn get_user<'c>( + conn: impl Executor<'c, Database = MySql>, + user_id: Uuid, +) -> Result<Option<User>, RawUnexpected> { + let record = query_as!( + UserRow, + r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version + FROM users WHERE id = ?", + user_id + ) + .fetch_optional(conn) + .await?; + + let Some(record) = record else { return Ok(None) }; + + Ok(Some(record.try_into()?)) +} + +/// Get a user from their username +pub async fn get_user_by_username<'c>( + conn: impl Executor<'c, Database = MySql>, + username: &str, +) -> Result<Option<User>, RawUnexpected> { + let record = query_as!( + UserRow, + r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version + FROM users WHERE username = ?", + username + ) + .fetch_optional(conn) + .await?; + + let Some(record) = record else { return Ok(None) }; + + Ok(Some(record.try_into()?)) +} + +/// Search the list of users for a given username +pub async fn search_users<'c>( + conn: impl Executor<'c, Database = MySql>, + username: &str, +) -> Result<Box<[User]>, RawUnexpected> { + let records = query_as!( + UserRow, + r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version + FROM users + WHERE LOCATE(?, username) != 0", + username, + ) + .fetch_all(conn) + .await?; + + Ok(records + .into_iter() + .map(|u| u.try_into()) + .collect::<Result<Box<[User]>, RawUnexpected>>()?) +} + +/// Search the list of users, only returning a certain range of results +pub async fn search_users_limit<'c>( + conn: impl Executor<'c, Database = MySql>, + username: &str, + offset: u32, + limit: u32, +) -> Result<Box<[User]>, RawUnexpected> { + let records = query_as!( + UserRow, + r"SELECT id as `id: Uuid`, username, password_hash, password_salt, password_version + FROM users + WHERE LOCATE(?, username) != 0 + LIMIT ? + OFFSET ?", + username, + offset, + limit + ) + .fetch_all(conn) + .await?; + + Ok(records + .into_iter() + .map(|u| u.try_into()) + .collect::<Result<Box<[User]>, RawUnexpected>>()?) +} + +/// Get the username of a user with a certain ID +pub async fn get_username<'c>( + conn: impl Executor<'c, Database = MySql>, + user_id: Uuid, +) -> Result<Option<Box<str>>, RawUnexpected> { + let username = query_scalar!(r"SELECT username FROM users where id = ?", user_id) + .fetch_optional(conn) + .await? + .map(String::into_boxed_str); + + Ok(username) +} + +/// Create a new user +pub async fn create_user<'c>( + conn: impl Executor<'c, Database = MySql>, + user: &User, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!( + r"INSERT INTO users (id, username, password_hash, password_salt, password_version) + VALUES ( ?, ?, ?, ?, ?)", + user.id, + user.username(), + user.password_hash(), + user.password_salt(), + user.password_version() + ) + .execute(conn) + .await +} + +/// Update a user +pub async fn update_user<'c>( + conn: impl Executor<'c, Database = MySql>, + user: &User, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!( + r"UPDATE users SET + username = ?, + password_hash = ?, + password_salt = ?, + password_version = ? + WHERE id = ?", + user.username(), + user.password_hash(), + user.password_salt(), + user.password_version(), + user.id + ) + .execute(conn) + .await +} + +/// Update the username of a user with the given ID +pub async fn update_username<'c>( + conn: impl Executor<'c, Database = MySql>, + user_id: Uuid, + username: &str, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!( + r"UPDATE users SET username = ? WHERE id = ?", + username, + user_id + ) + .execute(conn) + .await +} + +/// Update the password of a user with the given ID +pub async fn update_password<'c>( + conn: impl Executor<'c, Database = MySql>, + user_id: Uuid, + password: &PasswordHash, +) -> Result<MySqlQueryResult, sqlx::Error> { + query!( + r"UPDATE users SET + password_hash = ?, + password_salt = ?, + password_version = ? + WHERE id = ?", + password.hash(), + password.salt(), + password.version(), + user_id + ) + .execute(conn) + .await +} diff --git a/src/services/id.rs b/src/services/id.rs index 7970c60..0c665ed 100644 --- a/src/services/id.rs +++ b/src/services/id.rs @@ -1,16 +1,24 @@ +use std::future::Future; + use exun::RawUnexpected; use sqlx::{Executor, MySql}; use uuid::Uuid; -use super::db; - -/// Create a unique user id, handling duplicate ID's -pub async fn new_user_id<'c>( - conn: impl Executor<'c, Database = MySql> + Clone, +/// Create a unique id, handling duplicate ID's. +/// +/// The given `unique_check` parameter returns `true` if the ID is used and +/// `false` otherwise. +pub async fn new_id< + 'c, + E: Executor<'c, Database = MySql> + Clone, + F: Future<Output = Result<bool, RawUnexpected>>, +>( + conn: E, + unique_check: impl Fn(E, Uuid) -> F, ) -> Result<Uuid, RawUnexpected> { let uuid = loop { let uuid = Uuid::new_v4(); - if !db::user_id_exists(conn.clone(), uuid).await? { + if !unique_check(conn.clone(), uuid).await? { break uuid; } }; |
