summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormrw1593 <botahamec@outlook.com>2023-06-30 19:27:33 -0400
committermrw1593 <botahamec@outlook.com>2023-06-30 19:27:33 -0400
commit55cfb8187cb814e17a2a99d02bfd9296fc01dcc2 (patch)
treec5f7ed60c8a814addd60b1cfb843fb9a107f1458 /src
parent9058b01d6c0e3d1e9e485a537258a312ccfc841c (diff)
Added config file
Diffstat (limited to 'src')
-rw-r--r--src/api/oauth.rs28
-rw-r--r--src/main.rs17
-rw-r--r--src/services/config.rs74
-rw-r--r--src/services/jwt.rs20
-rw-r--r--src/services/mod.rs1
5 files changed, 118 insertions, 22 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index ef40637..fe1c361 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -20,7 +20,7 @@ use crate::models::client::ClientType;
use crate::resources::{languages, templates};
use crate::scopes;
use crate::services::jwt::VerifyJwtError;
-use crate::services::{authorization, db, jwt};
+use crate::services::{authorization, config, db, jwt};
const REALLY_BAD_ERROR_PAGE: &str = "<!DOCTYPE html><html><head><title>Internal Server Error</title></head><body>Internal Server Error</body></html>";
@@ -243,7 +243,12 @@ async fn authorize(
let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
return HttpResponse::NotFound().content_type("text/html").body(page);
};
- let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let Ok(config) = config::get_config() else {
+ let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return HttpResponse::InternalServerError().content_type("text/html").body(page);
+ };
+
+ let self_id = config.id;
let state = req.state.clone();
// get redirect uri
@@ -284,7 +289,7 @@ async fn authorize(
match req.response_type {
ResponseType::Code => {
// create auth code
- let code = jwt::Claims::auth_code(db, self_id, client_id, &scope, &redirect_uri)
+ let code = jwt::Claims::auth_code(db, &self_id, client_id, &scope, &redirect_uri)
.await
.unwrap();
let code = code.to_jwt().unwrap();
@@ -302,7 +307,7 @@ async fn authorize(
// create access token
let duration = Duration::hours(1);
let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope)
+ jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope)
.await
.unwrap();
@@ -628,8 +633,9 @@ async fn token(
let Ok(request) = request else {
return TokenError::invalid_request().error_response();
};
+ let config = config::get_config().unwrap();
- let self_id = Url::parse("www.google.com").unwrap(); // TODO find the actual value
+ let self_id = config.id;
let duration = Duration::hours(1);
let token_type = Box::from("bearer");
let cache_control = header::CacheControl(vec![header::CacheDirective::NoStore]);
@@ -646,9 +652,7 @@ async fn token(
// validate auth code
let claims =
- match jwt::verify_auth_code(db, &code, self_id.clone(), client_id, redirect_uri)
- .await
- {
+ match jwt::verify_auth_code(db, &code, &self_id, client_id, redirect_uri).await {
Ok(claims) => claims,
Err(err) => {
let err = err.unwrap();
@@ -673,7 +677,7 @@ async fn token(
let access_token = jwt::Claims::access_token(
db,
Some(claims.id()),
- self_id,
+ &self_id,
client_id,
duration,
claims.scopes(),
@@ -751,7 +755,7 @@ async fn token(
}
let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope)
+ jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope)
.await
.unwrap();
let refresh_token = jwt::Claims::refresh_token(db, &access_token).await.unwrap();
@@ -815,7 +819,7 @@ async fn token(
}
let access_token =
- jwt::Claims::access_token(db, None, self_id, client_id, duration, &scope)
+ jwt::Claims::access_token(db, None, &self_id, client_id, duration, &scope)
.await
.unwrap();
@@ -851,7 +855,7 @@ async fn token(
}
let claims =
- match jwt::verify_refresh_token(db, &refresh_token, self_id, client_id).await {
+ match jwt::verify_refresh_token(db, &refresh_token, &self_id, client_id).await {
Ok(claims) => claims,
Err(e) => {
let e = e.unwrap();
diff --git a/src/main.rs b/src/main.rs
index da740be..e946161 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,6 +5,7 @@ use actix_web::middleware::{ErrorHandlerResponse, ErrorHandlers, Logger, Normali
use actix_web::web::Data;
use actix_web::{dev, App, HttpServer};
+use bpaf::Bpaf;
use exun::*;
mod api;
@@ -44,12 +45,28 @@ async fn delete_expired_tokens(db: MySqlPool) {
}
}
+#[derive(Debug, Clone, Bpaf)]
+#[bpaf(options, version)]
+struct Opts {
+ /// The environment that the server is running in. Must be one of: local,
+ /// dev, staging, prod.
+ #[bpaf(
+ env("LOCKDAGGER_ENVIRONMENT"),
+ fallback(config::Environment::Local),
+ display_fallback
+ )]
+ env: config::Environment,
+}
+
#[actix_web::main]
async fn main() -> Result<(), RawUnexpected> {
// load the environment file, but only in debug mode
#[cfg(debug_assertions)]
dotenv::dotenv()?;
+ let args = opts().run();
+ config::set_environment(args.env);
+
// initialize the database
let db_url = secrets::database_url()?;
let sql_pool = db::initialize(&db_url).await?;
diff --git a/src/services/config.rs b/src/services/config.rs
new file mode 100644
index 0000000..6468126
--- /dev/null
+++ b/src/services/config.rs
@@ -0,0 +1,74 @@
+use std::{
+ fmt::{self, Display},
+ str::FromStr,
+};
+
+use exun::RawUnexpected;
+use parking_lot::RwLock;
+use serde::Deserialize;
+use thiserror::Error;
+use url::Url;
+
+static ENVIRONMENT: RwLock<Environment> = RwLock::new(Environment::Local);
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct Config {
+ pub id: Box<str>,
+ pub url: Url,
+}
+
+pub fn get_config() -> Result<Config, RawUnexpected> {
+ let env = get_environment();
+ let path = format!("static/config/{env}.toml");
+ let string = std::fs::read_to_string(path)?;
+ let config = toml::from_str(&string)?;
+ Ok(config)
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum Environment {
+ Local,
+ Dev,
+ Staging,
+ Production,
+}
+
+impl Display for Environment {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Local => f.write_str("local"),
+ Self::Dev => f.write_str("dev"),
+ Self::Staging => f.write_str("staging"),
+ Self::Production => f.write_str("prod"),
+ }
+ }
+}
+
+#[derive(Debug, Clone, Error)]
+#[error("Expected one of the following environments: local, dev, staging, prod. Found {string}")]
+pub struct ParseEnvironmentError {
+ string: Box<str>,
+}
+
+impl FromStr for Environment {
+ type Err = ParseEnvironmentError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ match s {
+ "local" => Ok(Self::Local),
+ "dev" => Ok(Self::Dev),
+ "staging" => Ok(Self::Staging),
+ "prod" => Ok(Self::Production),
+ _ => Err(ParseEnvironmentError { string: s.into() }),
+ }
+ }
+}
+
+pub fn set_environment(env: Environment) {
+ let mut env_ptr = ENVIRONMENT.write();
+ *env_ptr = env;
+}
+
+fn get_environment() -> Environment {
+ ENVIRONMENT.read().clone()
+}
diff --git a/src/services/jwt.rs b/src/services/jwt.rs
index 488e0ac..86252c4 100644
--- a/src/services/jwt.rs
+++ b/src/services/jwt.rs
@@ -19,7 +19,7 @@ pub enum TokenType {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
- iss: Url,
+ iss: Box<str>,
aud: Option<Box<[String]>>,
#[serde(with = "ts_milliseconds")]
exp: DateTime<Utc>,
@@ -45,7 +45,7 @@ pub enum RevokedRefreshTokenReason {
impl Claims {
pub async fn auth_code<'c>(
db: &MySqlPool,
- self_id: Url,
+ self_id: &str,
client_id: Uuid,
scopes: &str,
redirect_uri: &Url,
@@ -59,7 +59,7 @@ impl Claims {
db::create_auth_code(db, id, exp).await?;
Ok(Self {
- iss: self_id,
+ iss: Box::from(self_id),
aud: None,
exp,
nbf: None,
@@ -76,7 +76,7 @@ impl Claims {
pub async fn access_token<'c>(
db: &MySqlPool,
auth_code_id: Option<Uuid>,
- self_id: Url,
+ self_id: &str,
client_id: Uuid,
duration: Duration,
scopes: &str,
@@ -90,7 +90,7 @@ impl Claims {
.unexpect()?;
Ok(Self {
- iss: self_id,
+ iss: Box::from(self_id),
aud: None,
exp,
nbf: None,
@@ -186,7 +186,7 @@ pub enum VerifyJwtError {
fn verify_jwt(
token: &str,
- self_id: Url,
+ self_id: &str,
client_id: Option<Uuid>,
) -> Result<Claims, Expect<VerifyJwtError>> {
let key = secrets::signing_key()?;
@@ -194,7 +194,7 @@ fn verify_jwt(
.verify_with_key(&key)
.map_err(|e| VerifyJwtError::from(e))?;
- if claims.iss != self_id {
+ if claims.iss != self_id.into() {
yeet!(VerifyJwtError::IncorrectIssuer.into())
}
@@ -228,7 +228,7 @@ fn verify_jwt(
pub async fn verify_auth_code<'c>(
db: &MySqlPool,
token: &str,
- self_id: Url,
+ self_id: &str,
client_id: Uuid,
redirect_uri: Url,
) -> Result<Claims, Expect<VerifyJwtError>> {
@@ -252,7 +252,7 @@ pub async fn verify_auth_code<'c>(
pub async fn verify_access_token<'c>(
db: impl Executor<'c, Database = MySql>,
token: &str,
- self_id: Url,
+ self_id: &str,
client_id: Uuid,
) -> Result<Claims, Expect<VerifyJwtError>> {
let claims = verify_jwt(token, self_id, Some(client_id))?;
@@ -267,7 +267,7 @@ pub async fn verify_access_token<'c>(
pub async fn verify_refresh_token<'c>(
db: impl Executor<'c, Database = MySql>,
token: &str,
- self_id: Url,
+ self_id: &str,
client_id: Option<Uuid>,
) -> Result<Claims, Expect<VerifyJwtError>> {
let claims = verify_jwt(token, self_id, client_id)?;
diff --git a/src/services/mod.rs b/src/services/mod.rs
index 5339594..de08b58 100644
--- a/src/services/mod.rs
+++ b/src/services/mod.rs
@@ -1,4 +1,5 @@
pub mod authorization;
+pub mod config;
pub mod crypto;
pub mod db;
pub mod id;