use std::collections::BTreeSet; use std::fmt::{Display, Formatter}; use std::ops::Add; use std::time::Duration; use chrono::{DateTime, Utc}; use log::{debug, trace}; use reqwest::{Client, StatusCode}; use serde::{Deserialize, Serialize}; use tokio::time::{Instant, MissedTickBehavior}; use crate::auth::AuthError; use crate::auth::oauth::AccessTokenWithRefresh; use crate::util::USER_AGENT; pub struct DeviceCodeAuthBuilder { client_id: Option, scopes: Vec<(String, bool)>, code_request_url: Option, check_url: Option } #[derive(Serialize, Debug)] struct DeviceCodeRequest<'s> { client_id: &'s str, scope: &'s str, #[serde(skip_serializing_if = "Option::is_none")] response_type: Option<&'s str> } #[derive(Deserialize, Debug)] struct DeviceCodeResponse { device_code: String, user_code: String, verification_uri: String, expires_in: u64, interval: u64, message: Option } trait StringJointerton { fn joinerton(self, sep: char) -> String; } impl StringJointerton for I where I: Iterator, S: AsRef { fn joinerton(self, sep: char) -> String { self.fold(String::new(), |mut acc, s| { if !acc.is_empty() { acc.push(sep) } acc.push_str(s.as_ref()); acc }) } } impl DeviceCodeAuthBuilder { pub fn new() -> DeviceCodeAuthBuilder { DeviceCodeAuthBuilder { client_id: None, scopes: Vec::new(), code_request_url: None, check_url: None } } pub fn client_id(mut self, client_id: &str) -> Self { self.client_id = Some(client_id.to_owned()); self } pub fn add_scope(mut self, scope: &str, required: bool) -> Self { assert!(!scope.contains(' '), "multiple scopes should be passed using separate calls"); self.scopes.push((scope.to_owned(), required)); self } pub fn code_request_url(mut self, url: &str) -> Self { self.code_request_url = Some(url.to_owned()); self } pub fn check_url(mut self, url: &str) -> Self { self.check_url = Some(url.to_owned()); self } pub async fn begin(self, client: Client) -> Result { let client_id = self.client_id.expect("client_id is not optional"); let code_req_url = self.code_request_url.expect("code url is not optional"); let check_url = self.check_url.expect("check url is not optional"); let device_code: DeviceCodeResponse = client.post(&code_req_url) .header(reqwest::header::USER_AGENT, USER_AGENT) .header(reqwest::header::ACCEPT, "application/json") .form(&DeviceCodeRequest { client_id: client_id.as_str(), scope: self.scopes.iter().map(|(scope, _)| scope.as_str()).joinerton(' ').as_str(), response_type: Some("device_code") }) .send().await .and_then(|r| r.error_for_status()) .map_err(|e| AuthError::Request { what: "requesting device code auth", error: e })? .json().await.map_err(|e| AuthError::Request { what: "receiving device code auth", error: e })?; let now = Instant::now(); Ok(DeviceCodeAuth { client, client_id, check_url, scopes: self.scopes.into_iter().map(|(scope, _)| scope).collect(), start: now, interval: Duration::from_secs(device_code.interval + 1), expire_time: now.add(Duration::from_secs(device_code.expires_in)), info: dbg!(device_code) }) } } pub struct DeviceCodeAuth { client: Client, client_id: String, check_url: String, scopes: Vec, start: Instant, interval: Duration, expire_time: Instant, info: DeviceCodeResponse } #[derive(Serialize, Debug)] struct DeviceCodeTokenRequest<'s> { grant_type: &'s str, client_id: &'s str, device_code: &'s str } #[derive(Deserialize, Debug, PartialEq, Eq)] #[serde(rename_all = "snake_case")] enum DeviceCodeErrorKind { AuthorizationPending, AuthorizationDeclined, BadVerificationCode, ExpiredToken, #[serde(other)] Unknown } // https://learn.microsoft.com/en-us/entra/identity-platform/v2-oauth2-device-code#successful-authentication-response #[derive(Deserialize, Debug)] #[serde(untagged)] enum DeviceCodeTokenResponse { Success { // idc about token_type (we'll just assume it's always "Bearer" or else we don't know how to handle it) // idc about expires_in (we are not storing this token for any length of time) scope: Option, access_token: String, refresh_token: Option }, Error { error: DeviceCodeErrorKind, error_description: Option } } impl DeviceCodeAuth { pub fn get_user_code(&self) -> &str { self.info.user_code.as_str() } pub fn get_user_link(&self) -> &str { self.info.verification_uri.as_str() } pub fn get_user_message(&self) -> Option<&str> { self.info.message.as_ref().map(String::as_str) } pub async fn drive(&self) -> Result { let mut i = tokio::time::interval_at(self.start, self.interval); i.set_missed_tick_behavior(MissedTickBehavior::Skip); let req = DeviceCodeTokenRequest { grant_type: "urn:ietf:params:oauth:grant-type:device_code", client_id: self.client_id.as_str(), device_code: self.info.device_code.as_str() }; while self.expire_time.elapsed().is_zero() { i.tick().await; let res: DeviceCodeTokenResponse = self.client.get(&self.check_url) .header(reqwest::header::USER_AGENT, USER_AGENT) .header(reqwest::header::ACCEPT, "application/json") .form(&req) .send().await .map_err(|e| AuthError::Request { what: "device code heartbeat", error: e })? .json().await.map_err(|e| AuthError::Request { what: "decoding device code response", error: e })?; match res { DeviceCodeTokenResponse::Success { scope, access_token, refresh_token } => { if let Some(ref scope) = scope { let scopes_granted = scope.split(' ').collect::>(); for scope in self.scopes.iter() { if !scopes_granted.contains(scope.as_str()) { return Err(AuthError::Internal(format!("not granted required scope: {}", scope))); } } } return Ok(AccessTokenWithRefresh { access_token, refresh_token }) }, // the user hasn't done anything yet. continue polling DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::AuthorizationPending => { debug!("Authorization for device code pending (it has been {:?})...", Instant::now().duration_since(self.start)); continue; }, // the device code expired :( DeviceCodeTokenResponse::Error { error, .. } if error == DeviceCodeErrorKind::ExpiredToken => { debug!("Device auth interaction timeout (1)"); return Err(AuthError::Timeout); } DeviceCodeTokenResponse::Error { error, error_description } => { debug!("Authorization for device code error: {:?} ({})", error, error_description.as_ref().map(|s| s.as_str()).unwrap_or("no description")); } } } debug!("Device auth interaction timeout (2)"); Err(AuthError::Timeout) } }