summaryrefslogtreecommitdiff
path: root/src/api/oauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/api/oauth.rs')
-rw-r--r--src/api/oauth.rs90
1 files changed, 51 insertions, 39 deletions
diff --git a/src/api/oauth.rs b/src/api/oauth.rs
index aee0ed4..f1aa012 100644
--- a/src/api/oauth.rs
+++ b/src/api/oauth.rs
@@ -235,21 +235,20 @@ async fn authorize(
credentials: web::Json<AuthorizeCredentials>,
tera: web::Data<Tera>,
translations: web::Data<languages::Translations>,
-) -> HttpResponse {
- // TODO handle internal server error
+) -> Result<HttpResponse, AuthorizeError> {
// TODO protect against brute force attacks
let db = db.get_ref();
let Ok(client_id) = db::get_client_id_by_alias(db, &req.client_id).await else {
let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::InternalServerError().content_type("text/html").body(page);
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
};
let Some(client_id) = client_id else {
let page = error_page(&tera, &translations, templates::ErrorPage::ClientNotFound).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::NotFound().content_type("text/html").body(page);
+ return Ok(HttpResponse::NotFound().content_type("text/html").body(page));
};
let Ok(config) = config::get_config() else {
let page = error_page(&tera, &translations, templates::ErrorPage::InternalServerError).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::InternalServerError().content_type("text/html").body(page);
+ return Ok(HttpResponse::InternalServerError().content_type("text/html").body(page));
};
let self_id = config.url;
@@ -264,9 +263,9 @@ async fn authorize(
.unwrap_or(templates::ErrorPage::InternalServerError);
let page = error_page(&tera, &translations, e)
.unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::BadRequest()
+ return Ok(HttpResponse::BadRequest()
.content_type("text/html")
- .body(page);
+ .body(page));
}
};
@@ -277,16 +276,19 @@ async fn authorize(
{
let language = Language::from_str("en").unwrap();
let translations = translations.get_ref().clone();
- let page = templates::login_error_page(&tera, &req, language, translations).unwrap();
- return HttpResponse::Ok().content_type("text/html").body(page);
+ let page = templates::login_error_page(&tera, &req, language, translations).unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
+ return Ok(HttpResponse::Ok().content_type("text/html").body(page));
};
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
// get scope
let scope = match get_scope(&req.scope, db, client_id, &redirect_uri, &state).await {
Ok(scope) => scope,
Err(e) => {
- let e = e.unwrap();
- return e.error_response();
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
}
};
@@ -296,17 +298,18 @@ async fn authorize(
let code =
jwt::Claims::auth_code(db, self_id, client_id, user_id, &scope, &redirect_uri)
.await
- .unwrap();
- let code = code.to_jwt().unwrap();
+ .map_err(|_| internal_server_error.clone())?;
+ let code = code.to_jwt().map_err(|_| internal_server_error.clone())?;
let response = AuthCodeResponse { code, state };
- let query = Some(serde_urlencoded::to_string(response).unwrap());
+ let query =
+ Some(serde_urlencoded::to_string(response).map_err(|_| internal_server_error)?);
let query = query.as_deref();
redirect_uri.set_query(query);
- HttpResponse::Found()
+ Ok(HttpResponse::Found()
.append_header((header::LOCATION, redirect_uri.as_str()))
- .finish()
+ .finish())
}
ResponseType::Token => {
// create access token
@@ -314,9 +317,11 @@ async fn authorize(
let access_token =
jwt::Claims::access_token(db, None, self_id, client_id, user_id, duration, &scope)
.await
- .unwrap();
+ .map_err(|_| internal_server_error.clone())?;
- let access_token = access_token.to_jwt().unwrap();
+ let access_token = access_token
+ .to_jwt()
+ .map_err(|_| internal_server_error.clone())?;
let expires_in = duration.num_seconds();
let token_type = "bearer";
let response = AuthTokenResponse {
@@ -327,15 +332,17 @@ async fn authorize(
state,
};
- let fragment = Some(serde_urlencoded::to_string(response).unwrap());
+ let fragment = Some(
+ serde_urlencoded::to_string(response).map_err(|_| internal_server_error.clone())?,
+ );
let fragment = fragment.as_deref();
redirect_uri.set_fragment(fragment);
- HttpResponse::Found()
+ Ok(HttpResponse::Found()
.append_header((header::LOCATION, redirect_uri.as_str()))
- .finish()
+ .finish())
}
- _ => AuthorizeError::invalid_scope(redirect_uri, state).error_response(),
+ _ => Err(AuthorizeError::invalid_scope(redirect_uri, state)),
}
}
@@ -345,13 +352,12 @@ async fn authorize_page(
tera: web::Data<Tera>,
translations: web::Data<languages::Translations>,
request: HttpRequest,
-) -> HttpResponse {
- // TODO handle internal server error
+) -> Result<HttpResponse, AuthorizeError> {
let Ok(language) = Language::from_str("en") else {
let page = String::from(REALLY_BAD_ERROR_PAGE);
- return HttpResponse::InternalServerError()
+ return Ok(HttpResponse::InternalServerError()
.content_type("text/html")
- .body(page);
+ .body(page));
};
let translations = translations.get_ref().clone();
@@ -364,9 +370,9 @@ async fn authorize_page(
templates::ErrorPage::InvalidRequest,
)
.unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::BadRequest()
+ return Ok(HttpResponse::BadRequest()
.content_type("text/html")
- .body(page);
+ .body(page));
};
let db = db.get_ref();
@@ -378,9 +384,9 @@ async fn authorize_page(
templates::ErrorPage::InternalServerError,
)
.unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::InternalServerError()
+ return Ok(HttpResponse::InternalServerError()
.content_type("text/html")
- .body(page);
+ .body(page));
};
let Some(client_id) = client_id else {
let page = templates::error_page(
@@ -390,9 +396,9 @@ async fn authorize_page(
templates::ErrorPage::ClientNotFound,
)
.unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::NotFound()
+ return Ok(HttpResponse::NotFound()
.content_type("text/html")
- .body(page);
+ .body(page));
};
// verify redirect uri
@@ -404,31 +410,37 @@ async fn authorize_page(
.unwrap_or(templates::ErrorPage::InternalServerError);
let page = error_page(&tera, &translations, e)
.unwrap_or_else(|_| String::from(REALLY_BAD_ERROR_PAGE));
- return HttpResponse::BadRequest()
+ return Ok(HttpResponse::BadRequest()
.content_type("text/html")
- .body(page);
+ .body(page));
}
};
+ let state = &params.state;
+ let internal_server_error =
+ AuthorizeError::internal_server_error(redirect_uri.clone(), state.clone());
+
// verify scope
let _ = match get_scope(&params.scope, db, client_id, &redirect_uri, &params.state).await {
Ok(scope) => scope,
Err(e) => {
- let e = e.unwrap();
- return e.error_response();
+ let e = e.expected().unwrap_or(internal_server_error);
+ return Err(e);
}
};
// verify response type
if params.response_type == ResponseType::Unsupported {
- return AuthorizeError::unsupported_response_type(redirect_uri, params.state)
- .error_response();
+ return Err(AuthorizeError::unsupported_response_type(
+ redirect_uri,
+ params.state,
+ ));
}
// TODO find a better way of doing languages
let language = Language::from_str("en").unwrap();
let page = templates::login_page(&tera, &params, language, translations).unwrap();
- HttpResponse::Ok().content_type("text/html").body(page)
+ Ok(HttpResponse::Ok().content_type("text/html").body(page))
}
#[derive(Clone, Deserialize)]