feat: verify state string in callback

This commit is contained in:
gak 2023-08-29 11:57:36 +10:00
parent db80e29d26
commit e4191c75f3
No known key found for this signature in database
2 changed files with 42 additions and 9 deletions

View file

@ -1,3 +1,4 @@
use crate::error::TeslatteError::{CouldNotFindCallbackCode, CouldNotFindState};
use crate::{Api, TeslatteError};
use derive_more::{Display, FromStr};
use rand::Rng;
@ -22,6 +23,11 @@ pub struct Credentials {
pub refresh_token: Option<RefreshToken>,
}
struct Callback {
code: String,
state: String,
}
impl Api {
/// Currently the only way to "authenticate" to an access token for this library.
pub async fn from_interactive_url() -> Result<Api, TeslatteError> {
@ -29,8 +35,16 @@ impl Api {
dbg!(&login_form);
let callback_url =
ask_input("Enter the URL of the 404 error page after you've logged in: ");
let callback_code = Self::extract_callback_code_from_url(&callback_url)?;
let bearer = Self::exchange_auth_for_bearer(&login_form.code, &callback_code).await?;
let callback = Self::extract_callback_from_url(&callback_url)?;
if callback.state != login_form.state {
return Err(TeslatteError::StateMismatch {
request: login_form.state,
callback: callback.state,
});
}
let bearer = Self::exchange_auth_for_bearer(&login_form.code, &callback.code).await?;
let access_token = AccessToken(bearer.access_token);
let refresh_token = RefreshToken(bearer.refresh_token);
Ok(Api::new(access_token, Some(refresh_token)))
@ -151,14 +165,24 @@ impl Api {
url.to_string()
}
fn extract_callback_code_from_url(callback_url: &str) -> Result<String, TeslatteError> {
Ok(Url::parse(callback_url)
.map_err(TeslatteError::UserDidNotSupplyValidCallbackUrl)?
.query_pairs()
fn extract_callback_from_url(callback_url: &str) -> Result<Callback, TeslatteError> {
let url =
Url::parse(callback_url).map_err(TeslatteError::UserDidNotSupplyValidCallbackUrl)?;
let pairs = url.query_pairs().collect::<Vec<_>>();
let code = pairs
.iter()
.find(|(k, _)| k == "code")
.map(|kv| kv.1)
.ok_or(TeslatteError::CouldNotFindCallbackCode)?
.to_string())
.map(|(_, v)| v.to_string())
.ok_or(CouldNotFindCallbackCode)?;
let state = pairs
.iter()
.find(|(k, _)| k == "state")
.map(|(_, v)| v.to_string())
.ok_or(CouldNotFindState)?;
Ok(Callback { code, state })
}
}
@ -181,6 +205,7 @@ pub struct RefreshTokenResponse {
#[derive(Debug, Default)]
pub struct LoginForm {
#[allow(dead_code)]
url: String,
code: Code,
state: String,

View file

@ -36,6 +36,14 @@ pub enum TeslatteError {
#[error("Callback URL did not contain a callback code.")]
CouldNotFindCallbackCode,
#[error("Callback URL did not contain the state.")]
CouldNotFindState,
#[error(
"State in the callback URL did not match the state in the request: {request} != {callback}"
)]
StateMismatch { request: String, callback: String },
#[error("Could not convert \"{0}\" to an EnergySiteId.")]
DecodeEnergySiteIdError(String),