From 00675738a4012580c6db10d9e1115732559bdaf3 Mon Sep 17 00:00:00 2001 From: bigfoot547 Date: Wed, 29 Jan 2025 03:27:31 -0600 Subject: do device code auth --- src/auth/device_code.rs | 103 ------------------ src/auth/oauth.rs | 8 ++ src/auth/oauth/device_code.rs | 243 ++++++++++++++++++++++++++++++++++++++++++ src/auth/oauth/refresh.rs | 0 4 files changed, 251 insertions(+), 103 deletions(-) delete mode 100644 src/auth/device_code.rs create mode 100644 src/auth/oauth.rs create mode 100644 src/auth/oauth/device_code.rs create mode 100644 src/auth/oauth/refresh.rs (limited to 'src/auth') diff --git a/src/auth/device_code.rs b/src/auth/device_code.rs deleted file mode 100644 index 087ff27..0000000 --- a/src/auth/device_code.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::ops::Add; -use std::time::Duration; -use futures::TryFutureExt; -use reqwest::Client; -use serde::{Deserialize, Serialize}; -use tokio::time::{Instant, MissedTickBehavior}; -use super::AuthError; -use crate::util::USER_AGENT; - -pub struct DeviceCodeAuthBuilder { - client_id: Option, - scope: Option, - url: Option -} - -#[derive(Serialize, Debug)] -struct DeviceCodeRequest { - client_id: String, - scope: String, - response_type: String -} - -#[derive(Deserialize, Debug)] -struct DeviceCodeResponse { - device_code: String, - user_code: String, - verification_uri: String, - expires_in: u64, - interval: u64, - message: Option -} - -impl DeviceCodeAuthBuilder { - pub fn new() -> DeviceCodeAuthBuilder { - DeviceCodeAuthBuilder { - client_id: None, - scope: None, - url: None - } - } - - pub fn client_id(mut self, client_id: &str) -> Self { - self.client_id = Some(client_id.to_owned()); - self - } - - pub fn scope(mut self, scope: &str) -> Self { - self.scope = Some(scope.to_owned()); - self - } - - pub fn url(mut self, url: &str) -> Self { - self.url = Some(url.to_owned()); - self - } - - pub async fn begin(self, client: Client) -> Result { - let scope = self.scope.expect("scope is not optional"); - let client_id = self.client_id.expect("client_id is not optional"); - let url = self.url.expect("url is not optional"); - - let device_code: DeviceCodeResponse = client.post(&url) - .header(reqwest::header::USER_AGENT, USER_AGENT) - .header(reqwest::header::ACCEPT, "application/json") - .form(&DeviceCodeRequest { - client_id, - scope, - response_type: "device_code".into() - }) - .send().await - .and_then(|r| r.error_for_status()) - .map_err(|e| AuthError::Request { what: "requesting device code auth", error: e })? - .json().await.map_err(|e| AuthError::Request { what: "receiving device code auth", error: e })?; - - let now = Instant::now(); - Ok(DeviceCodeAuth { - client, - start: now, - interval: Duration::from_secs(device_code.interval + 1), - expire_time: now.add(Duration::from_secs(device_code.expires_in)), - info: dbg!(device_code) - }) - } -} - -pub struct DeviceCodeAuth { - client: Client, - start: Instant, - interval: Duration, - expire_time: Instant, - info: DeviceCodeResponse -} - -impl DeviceCodeAuth { - async fn drive(&self) { - let mut i = tokio::time::interval_at(self.start, self.interval); - i.set_missed_tick_behavior(MissedTickBehavior::Skip); - - while self.expire_time.elapsed().is_zero() { - - } - } -} diff --git a/src/auth/oauth.rs b/src/auth/oauth.rs new file mode 100644 index 0000000..6d4da77 --- /dev/null +++ b/src/auth/oauth.rs @@ -0,0 +1,8 @@ +pub mod device_code; +mod refresh; + +pub struct AccessTokenWithRefresh { + pub access_token: String, + pub refresh_token: Option +} + diff --git a/src/auth/oauth/device_code.rs b/src/auth/oauth/device_code.rs new file mode 100644 index 0000000..4462aa7 --- /dev/null +++ b/src/auth/oauth/device_code.rs @@ -0,0 +1,243 @@ +use std::collections::BTreeSet; +use std::fmt::{Display, Formatter}; +use std::ops::Add; +use std::time::Duration; +use chrono::{DateTime, Utc}; +use log::{debug, trace}; +use reqwest::{Client, StatusCode}; +use serde::{Deserialize, Serialize}; +use tokio::time::{Instant, MissedTickBehavior}; +use crate::auth::AuthError; +use crate::auth::oauth::AccessTokenWithRefresh; +use crate::util::USER_AGENT; + +pub struct DeviceCodeAuthBuilder { + client_id: Option, + scopes: Vec<(String, bool)>, + code_request_url: Option, + check_url: Option +} + +#[derive(Serialize, Debug)] +struct DeviceCodeRequest<'s> { + client_id: &'s str, + scope: &'s str, + + #[serde(skip_serializing_if = "Option::is_none")] + response_type: Option<&'s str> +} + +#[derive(Deserialize, Debug)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + expires_in: u64, + interval: u64, + message: Option +} + +trait StringJointerton { + fn joinerton(self, sep: char) -> String; +} + +impl StringJointerton for I +where + I: Iterator, + S: AsRef +{ + fn joinerton(self, sep: char) -> String { + self.fold(String::new(), |mut acc, s| { + if !acc.is_empty() { acc.push(sep) } + acc.push_str(s.as_ref()); + acc + }) + } +} + +impl DeviceCodeAuthBuilder { + pub fn new() -> DeviceCodeAuthBuilder { + DeviceCodeAuthBuilder { + client_id: None, + scopes: Vec::new(), + code_request_url: None, + check_url: None + } + } + + pub fn client_id(mut self, client_id: &str) -> Self { + self.client_id = Some(client_id.to_owned()); + self + } + + pub fn add_scope(mut self, scope: &str, required: bool) -> Self { + assert!(!scope.contains(' '), "multiple scopes should be passed using separate calls"); + self.scopes.push((scope.to_owned(), required)); + self + } + + pub fn code_request_url(mut self, url: &str) -> Self { + self.code_request_url = Some(url.to_owned()); + self + } + + pub fn check_url(mut self, url: &str) -> Self { + self.check_url = Some(url.to_owned()); + self + } + + pub async fn begin(self, client: Client) -> Result { + let client_id = self.client_id.expect("client_id is not optional"); + let code_req_url = self.code_request_url.expect("code url is not optional"); + let check_url = self.check_url.expect("check url is not optional"); + + let device_code: DeviceCodeResponse = client.post(&code_req_url) + .header(reqwest::header::USER_AGENT, USER_AGENT) + .header(reqwest::header::ACCEPT, "application/json") + .form(&DeviceCodeRequest { + client_id: client_id.as_str(), + scope: self.scopes.iter().map(|(scope, _)| scope.as_str()).joinerton(' ').as_str(), + response_type: Some("device_code") + }) + .send().await + .and_then(|r| r.error_for_status()) + .map_err(|e| AuthError::Request { what: "requesting device code auth", error: e })? + .json().await.map_err(|e| AuthError::Request { what: "receiving device code auth", error: e })?; + + let now = Instant::now(); + Ok(DeviceCodeAuth { + client, + client_id, + check_url, + scopes: self.scopes.into_iter().map(|(scope, _)| scope).collect(), + start: now, + interval: Duration::from_secs(device_code.interval + 1), + expire_time: now.add(Duration::from_secs(device_code.expires_in)), + info: dbg!(device_code) + }) + } +} + +pub struct DeviceCodeAuth { + client: Client, + client_id: String, + check_url: String, + scopes: Vec, + start: Instant, + interval: Duration, + expire_time: Instant, + info: DeviceCodeResponse +} + +#[derive(Serialize, Debug)] +struct DeviceCodeTokenRequest<'s> { + grant_type: &'s str, + client_id: &'s str, + device_code: &'s str +} + +#[derive(Deserialize, Debug, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +enum DeviceCodeErrorKind { + AuthorizationPending, + AuthorizationDeclined, + BadVerificationCode, + ExpiredToken, + + #[serde(other)] + Unknown +} + +// https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code#successful-authentication-response +#[derive(Deserialize, Debug)] +#[serde(untagged)] +enum DeviceCodeTokenResponse { + Success { + // idc about token_type (we'll just assume it's always "Bearer" or else we don't know how to handle it) + // idc about expires_in (we are not storing this token for any length of time) + scope: Option, + access_token: String, + refresh_token: Option + }, + Error { + error: DeviceCodeErrorKind, + error_description: Option + } +} + +impl DeviceCodeAuth { + pub fn get_user_code(&self) -> &str { + self.info.user_code.as_str() + } + + pub fn get_user_link(&self) -> &str { + self.info.verification_uri.as_str() + } + + pub fn get_user_message(&self) -> Option<&str> { + self.info.message.as_ref().map(String::as_str) + } + + pub async fn drive(&self) -> Result { + let mut i = tokio::time::interval_at(self.start, self.interval); + i.set_missed_tick_behavior(MissedTickBehavior::Skip); + + let req = DeviceCodeTokenRequest { + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + client_id: self.client_id.as_str(), + device_code: self.info.device_code.as_str() + }; + + while self.expire_time.elapsed().is_zero() { + i.tick().await; + + let res: DeviceCodeTokenResponse = self.client.get(&self.check_url) + .header(reqwest::header::USER_AGENT, USER_AGENT) + .header(reqwest::header::ACCEPT, "application/json") + .form(&req) + .send().await + .map_err(|e| AuthError::Request { what: "device code heartbeat", error: e })? + .json().await.map_err(|e| AuthError::Request { what: "decoding device code response", error: e })?; + + match res { + DeviceCodeTokenResponse::Success { scope, access_token, refresh_token } => { + if let Some(ref scope) = scope { + let scopes_granted = scope.split(' ').collect::>(); + for scope in self.scopes.iter() { + if !scopes_granted.contains(scope.as_str()) { + return Err(AuthError::Internal(format!("not granted required scope: {}", scope))); + } + } + } + + return Ok(AccessTokenWithRefresh { + access_token, + refresh_token + }) + }, + + // the user hasn't done anything yet. continue polling + DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::AuthorizationPending => { + debug!("Authorization for device code pending (it has been {:?})...", Instant::now().duration_since(self.start)); + continue; + }, + + // the device code expired :( + DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::ExpiredToken => { + debug!("Device auth interaction timeout (1)"); + return Err(AuthError::Timeout); + } + + DeviceCodeTokenResponse::Error { error, error_description } => { + debug!("Authorization for device code error: {:?} ({})", + error, + error_description.as_ref().map(|s| s.as_str()).unwrap_or("no description")); + } + } + } + + debug!("Device auth interaction timeout (2)"); + + Err(AuthError::Timeout) + } +} diff --git a/src/auth/oauth/refresh.rs b/src/auth/oauth/refresh.rs new file mode 100644 index 0000000..e69de29 -- cgit v1.2.3-70-g09d2