feat: verify state string in callback
This commit is contained in:
parent
db80e29d26
commit
e4191c75f3
43
src/auth.rs
43
src/auth.rs
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::error::TeslatteError::{CouldNotFindCallbackCode, CouldNotFindState};
|
||||||
use crate::{Api, TeslatteError};
|
use crate::{Api, TeslatteError};
|
||||||
use derive_more::{Display, FromStr};
|
use derive_more::{Display, FromStr};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
@ -22,6 +23,11 @@ pub struct Credentials {
|
||||||
pub refresh_token: Option<RefreshToken>,
|
pub refresh_token: Option<RefreshToken>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Callback {
|
||||||
|
code: String,
|
||||||
|
state: String,
|
||||||
|
}
|
||||||
|
|
||||||
impl Api {
|
impl Api {
|
||||||
/// Currently the only way to "authenticate" to an access token for this library.
|
/// Currently the only way to "authenticate" to an access token for this library.
|
||||||
pub async fn from_interactive_url() -> Result<Api, TeslatteError> {
|
pub async fn from_interactive_url() -> Result<Api, TeslatteError> {
|
||||||
|
@ -29,8 +35,16 @@ impl Api {
|
||||||
dbg!(&login_form);
|
dbg!(&login_form);
|
||||||
let callback_url =
|
let callback_url =
|
||||||
ask_input("Enter the URL of the 404 error page after you've logged in: ");
|
ask_input("Enter the URL of the 404 error page after you've logged in: ");
|
||||||
let callback_code = Self::extract_callback_code_from_url(&callback_url)?;
|
|
||||||
let bearer = Self::exchange_auth_for_bearer(&login_form.code, &callback_code).await?;
|
let callback = Self::extract_callback_from_url(&callback_url)?;
|
||||||
|
if callback.state != login_form.state {
|
||||||
|
return Err(TeslatteError::StateMismatch {
|
||||||
|
request: login_form.state,
|
||||||
|
callback: callback.state,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let bearer = Self::exchange_auth_for_bearer(&login_form.code, &callback.code).await?;
|
||||||
let access_token = AccessToken(bearer.access_token);
|
let access_token = AccessToken(bearer.access_token);
|
||||||
let refresh_token = RefreshToken(bearer.refresh_token);
|
let refresh_token = RefreshToken(bearer.refresh_token);
|
||||||
Ok(Api::new(access_token, Some(refresh_token)))
|
Ok(Api::new(access_token, Some(refresh_token)))
|
||||||
|
@ -151,14 +165,24 @@ impl Api {
|
||||||
url.to_string()
|
url.to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_callback_code_from_url(callback_url: &str) -> Result<String, TeslatteError> {
|
fn extract_callback_from_url(callback_url: &str) -> Result<Callback, TeslatteError> {
|
||||||
Ok(Url::parse(callback_url)
|
let url =
|
||||||
.map_err(TeslatteError::UserDidNotSupplyValidCallbackUrl)?
|
Url::parse(callback_url).map_err(TeslatteError::UserDidNotSupplyValidCallbackUrl)?;
|
||||||
.query_pairs()
|
let pairs = url.query_pairs().collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let code = pairs
|
||||||
|
.iter()
|
||||||
.find(|(k, _)| k == "code")
|
.find(|(k, _)| k == "code")
|
||||||
.map(|kv| kv.1)
|
.map(|(_, v)| v.to_string())
|
||||||
.ok_or(TeslatteError::CouldNotFindCallbackCode)?
|
.ok_or(CouldNotFindCallbackCode)?;
|
||||||
.to_string())
|
|
||||||
|
let state = pairs
|
||||||
|
.iter()
|
||||||
|
.find(|(k, _)| k == "state")
|
||||||
|
.map(|(_, v)| v.to_string())
|
||||||
|
.ok_or(CouldNotFindState)?;
|
||||||
|
|
||||||
|
Ok(Callback { code, state })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,6 +205,7 @@ pub struct RefreshTokenResponse {
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub struct LoginForm {
|
pub struct LoginForm {
|
||||||
|
#[allow(dead_code)]
|
||||||
url: String,
|
url: String,
|
||||||
code: Code,
|
code: Code,
|
||||||
state: String,
|
state: String,
|
||||||
|
|
|
@ -36,6 +36,14 @@ pub enum TeslatteError {
|
||||||
#[error("Callback URL did not contain a callback code.")]
|
#[error("Callback URL did not contain a callback code.")]
|
||||||
CouldNotFindCallbackCode,
|
CouldNotFindCallbackCode,
|
||||||
|
|
||||||
|
#[error("Callback URL did not contain the state.")]
|
||||||
|
CouldNotFindState,
|
||||||
|
|
||||||
|
#[error(
|
||||||
|
"State in the callback URL did not match the state in the request: {request} != {callback}"
|
||||||
|
)]
|
||||||
|
StateMismatch { request: String, callback: String },
|
||||||
|
|
||||||
#[error("Could not convert \"{0}\" to an EnergySiteId.")]
|
#[error("Could not convert \"{0}\" to an EnergySiteId.")]
|
||||||
DecodeEnergySiteIdError(String),
|
DecodeEnergySiteIdError(String),
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue