diff options
Diffstat (limited to 'src/auth')
| -rw-r--r-- | src/auth/device_code.rs | 103 | ||||
| -rw-r--r-- | src/auth/oauth.rs | 8 | ||||
| -rw-r--r-- | src/auth/oauth/device_code.rs | 243 | ||||
| -rw-r--r-- | src/auth/oauth/refresh.rs | 0 |
4 files changed, 251 insertions, 103 deletions
diff --git a/src/auth/device_code.rs b/src/auth/device_code.rs deleted file mode 100644 index 087ff27..0000000 --- a/src/auth/device_code.rs +++ /dev/null @@ -1,103 +0,0 @@ -use std::ops::Add;
-use std::time::Duration;
-use futures::TryFutureExt;
-use reqwest::Client;
-use serde::{Deserialize, Serialize};
-use tokio::time::{Instant, MissedTickBehavior};
-use super::AuthError;
-use crate::util::USER_AGENT;
-
-pub struct DeviceCodeAuthBuilder {
- client_id: Option<String>,
- scope: Option<String>,
- url: Option<String>
-}
-
-#[derive(Serialize, Debug)]
-struct DeviceCodeRequest {
- client_id: String,
- scope: String,
- response_type: String
-}
-
-#[derive(Deserialize, Debug)]
-struct DeviceCodeResponse {
- device_code: String,
- user_code: String,
- verification_uri: String,
- expires_in: u64,
- interval: u64,
- message: Option<String>
-}
-
-impl DeviceCodeAuthBuilder {
- pub fn new() -> DeviceCodeAuthBuilder {
- DeviceCodeAuthBuilder {
- client_id: None,
- scope: None,
- url: None
- }
- }
-
- pub fn client_id(mut self, client_id: &str) -> Self {
- self.client_id = Some(client_id.to_owned());
- self
- }
-
- pub fn scope(mut self, scope: &str) -> Self {
- self.scope = Some(scope.to_owned());
- self
- }
-
- pub fn url(mut self, url: &str) -> Self {
- self.url = Some(url.to_owned());
- self
- }
-
- pub async fn begin(self, client: Client) -> Result<DeviceCodeAuth, AuthError> {
- let scope = self.scope.expect("scope is not optional");
- let client_id = self.client_id.expect("client_id is not optional");
- let url = self.url.expect("url is not optional");
-
- let device_code: DeviceCodeResponse = client.post(&url)
- .header(reqwest::header::USER_AGENT, USER_AGENT)
- .header(reqwest::header::ACCEPT, "application/json")
- .form(&DeviceCodeRequest {
- client_id,
- scope,
- response_type: "device_code".into()
- })
- .send().await
- .and_then(|r| r.error_for_status())
- .map_err(|e| AuthError::Request { what: "requesting device code auth", error: e })?
- .json().await.map_err(|e| AuthError::Request { what: "receiving device code auth", error: e })?;
-
- let now = Instant::now();
- Ok(DeviceCodeAuth {
- client,
- start: now,
- interval: Duration::from_secs(device_code.interval + 1),
- expire_time: now.add(Duration::from_secs(device_code.expires_in)),
- info: dbg!(device_code)
- })
- }
-}
-
-pub struct DeviceCodeAuth {
- client: Client,
- start: Instant,
- interval: Duration,
- expire_time: Instant,
- info: DeviceCodeResponse
-}
-
-impl DeviceCodeAuth {
- async fn drive(&self) {
- let mut i = tokio::time::interval_at(self.start, self.interval);
- i.set_missed_tick_behavior(MissedTickBehavior::Skip);
-
- while self.expire_time.elapsed().is_zero() {
-
- }
- }
-}
diff --git a/src/auth/oauth.rs b/src/auth/oauth.rs new file mode 100644 index 0000000..6d4da77 --- /dev/null +++ b/src/auth/oauth.rs @@ -0,0 +1,8 @@ +pub mod device_code; +mod refresh; + +pub struct AccessTokenWithRefresh { + pub access_token: String, + pub refresh_token: Option<String> +} + diff --git a/src/auth/oauth/device_code.rs b/src/auth/oauth/device_code.rs new file mode 100644 index 0000000..4462aa7 --- /dev/null +++ b/src/auth/oauth/device_code.rs @@ -0,0 +1,243 @@ +use std::collections::BTreeSet;
+use std::fmt::{Display, Formatter};
+use std::ops::Add;
+use std::time::Duration;
+use chrono::{DateTime, Utc};
+use log::{debug, trace};
+use reqwest::{Client, StatusCode};
+use serde::{Deserialize, Serialize};
+use tokio::time::{Instant, MissedTickBehavior};
+use crate::auth::AuthError;
+use crate::auth::oauth::AccessTokenWithRefresh;
+use crate::util::USER_AGENT;
+
+pub struct DeviceCodeAuthBuilder {
+ client_id: Option<String>,
+ scopes: Vec<(String, bool)>,
+ code_request_url: Option<String>,
+ check_url: Option<String>
+}
+
+#[derive(Serialize, Debug)]
+struct DeviceCodeRequest<'s> {
+ client_id: &'s str,
+ scope: &'s str,
+
+ #[serde(skip_serializing_if = "Option::is_none")]
+ response_type: Option<&'s str>
+}
+
+#[derive(Deserialize, Debug)]
+struct DeviceCodeResponse {
+ device_code: String,
+ user_code: String,
+ verification_uri: String,
+ expires_in: u64,
+ interval: u64,
+ message: Option<String>
+}
+
+trait StringJointerton {
+ fn joinerton(self, sep: char) -> String;
+}
+
+impl<I, S> StringJointerton for I
+where
+ I: Iterator<Item = S>,
+ S: AsRef<str>
+{
+ fn joinerton(self, sep: char) -> String {
+ self.fold(String::new(), |mut acc, s| {
+ if !acc.is_empty() { acc.push(sep) }
+ acc.push_str(s.as_ref());
+ acc
+ })
+ }
+}
+
+impl DeviceCodeAuthBuilder {
+ pub fn new() -> DeviceCodeAuthBuilder {
+ DeviceCodeAuthBuilder {
+ client_id: None,
+ scopes: Vec::new(),
+ code_request_url: None,
+ check_url: None
+ }
+ }
+
+ pub fn client_id(mut self, client_id: &str) -> Self {
+ self.client_id = Some(client_id.to_owned());
+ self
+ }
+
+ pub fn add_scope(mut self, scope: &str, required: bool) -> Self {
+ assert!(!scope.contains(' '), "multiple scopes should be passed using separate calls");
+ self.scopes.push((scope.to_owned(), required));
+ self
+ }
+
+ pub fn code_request_url(mut self, url: &str) -> Self {
+ self.code_request_url = Some(url.to_owned());
+ self
+ }
+
+ pub fn check_url(mut self, url: &str) -> Self {
+ self.check_url = Some(url.to_owned());
+ self
+ }
+
+ pub async fn begin(self, client: Client) -> Result<DeviceCodeAuth, AuthError> {
+ let client_id = self.client_id.expect("client_id is not optional");
+ let code_req_url = self.code_request_url.expect("code url is not optional");
+ let check_url = self.check_url.expect("check url is not optional");
+
+ let device_code: DeviceCodeResponse = client.post(&code_req_url)
+ .header(reqwest::header::USER_AGENT, USER_AGENT)
+ .header(reqwest::header::ACCEPT, "application/json")
+ .form(&DeviceCodeRequest {
+ client_id: client_id.as_str(),
+ scope: self.scopes.iter().map(|(scope, _)| scope.as_str()).joinerton(' ').as_str(),
+ response_type: Some("device_code")
+ })
+ .send().await
+ .and_then(|r| r.error_for_status())
+ .map_err(|e| AuthError::Request { what: "requesting device code auth", error: e })?
+ .json().await.map_err(|e| AuthError::Request { what: "receiving device code auth", error: e })?;
+
+ let now = Instant::now();
+ Ok(DeviceCodeAuth {
+ client,
+ client_id,
+ check_url,
+ scopes: self.scopes.into_iter().map(|(scope, _)| scope).collect(),
+ start: now,
+ interval: Duration::from_secs(device_code.interval + 1),
+ expire_time: now.add(Duration::from_secs(device_code.expires_in)),
+ info: dbg!(device_code)
+ })
+ }
+}
+
+pub struct DeviceCodeAuth {
+ client: Client,
+ client_id: String,
+ check_url: String,
+ scopes: Vec<String>,
+ start: Instant,
+ interval: Duration,
+ expire_time: Instant,
+ info: DeviceCodeResponse
+}
+
+#[derive(Serialize, Debug)]
+struct DeviceCodeTokenRequest<'s> {
+ grant_type: &'s str,
+ client_id: &'s str,
+ device_code: &'s str
+}
+
+#[derive(Deserialize, Debug, PartialEq, Eq)]
+#[serde(rename_all = "snake_case")]
+enum DeviceCodeErrorKind {
+ AuthorizationPending,
+ AuthorizationDeclined,
+ BadVerificationCode,
+ ExpiredToken,
+
+ #[serde(other)]
+ Unknown
+}
+
+// https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code#successful-authentication-response
+#[derive(Deserialize, Debug)]
+#[serde(untagged)]
+enum DeviceCodeTokenResponse {
+ Success {
+ // idc about token_type (we'll just assume it's always "Bearer" or else we don't know how to handle it)
+ // idc about expires_in (we are not storing this token for any length of time)
+ scope: Option<String>,
+ access_token: String,
+ refresh_token: Option<String>
+ },
+ Error {
+ error: DeviceCodeErrorKind,
+ error_description: Option<String>
+ }
+}
+
+impl DeviceCodeAuth {
+ pub fn get_user_code(&self) -> &str {
+ self.info.user_code.as_str()
+ }
+
+ pub fn get_user_link(&self) -> &str {
+ self.info.verification_uri.as_str()
+ }
+
+ pub fn get_user_message(&self) -> Option<&str> {
+ self.info.message.as_ref().map(String::as_str)
+ }
+
+ pub async fn drive(&self) -> Result<AccessTokenWithRefresh, AuthError> {
+ let mut i = tokio::time::interval_at(self.start, self.interval);
+ i.set_missed_tick_behavior(MissedTickBehavior::Skip);
+
+ let req = DeviceCodeTokenRequest {
+ grant_type: "urn:ietf:params:oauth:grant-type:device_code",
+ client_id: self.client_id.as_str(),
+ device_code: self.info.device_code.as_str()
+ };
+
+ while self.expire_time.elapsed().is_zero() {
+ i.tick().await;
+
+ let res: DeviceCodeTokenResponse = self.client.get(&self.check_url)
+ .header(reqwest::header::USER_AGENT, USER_AGENT)
+ .header(reqwest::header::ACCEPT, "application/json")
+ .form(&req)
+ .send().await
+ .map_err(|e| AuthError::Request { what: "device code heartbeat", error: e })?
+ .json().await.map_err(|e| AuthError::Request { what: "decoding device code response", error: e })?;
+
+ match res {
+ DeviceCodeTokenResponse::Success { scope, access_token, refresh_token } => {
+ if let Some(ref scope) = scope {
+ let scopes_granted = scope.split(' ').collect::<BTreeSet<_>>();
+ for scope in self.scopes.iter() {
+ if !scopes_granted.contains(scope.as_str()) {
+ return Err(AuthError::Internal(format!("not granted required scope: {}", scope)));
+ }
+ }
+ }
+
+ return Ok(AccessTokenWithRefresh {
+ access_token,
+ refresh_token
+ })
+ },
+
+ // the user hasn't done anything yet. continue polling
+ DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::AuthorizationPending => {
+ debug!("Authorization for device code pending (it has been {:?})...", Instant::now().duration_since(self.start));
+ continue;
+ },
+
+ // the device code expired :(
+ DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::ExpiredToken => {
+ debug!("Device auth interaction timeout (1)");
+ return Err(AuthError::Timeout);
+ }
+
+ DeviceCodeTokenResponse::Error { error, error_description } => {
+ debug!("Authorization for device code error: {:?} ({})",
+ error,
+ error_description.as_ref().map(|s| s.as_str()).unwrap_or("no description"));
+ }
+ }
+ }
+
+ debug!("Device auth interaction timeout (2)");
+
+ Err(AuthError::Timeout)
+ }
+}
diff --git a/src/auth/oauth/refresh.rs b/src/auth/oauth/refresh.rs new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/auth/oauth/refresh.rs |
