mas_oidc_client/requests/
authorization_code.rs1use std::{collections::HashSet, num::NonZeroU32};
12
13use base64ct::{Base64UrlUnpadded, Encoding};
14use chrono::{DateTime, Utc};
15use language_tags::LanguageTag;
16use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
17use mas_jose::claims::{self, TokenHash};
18use oauth2_types::{
19    pkce,
20    prelude::CodeChallengeMethodExt,
21    requests::{
22        AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
23        Display, Prompt, ResponseMode,
24    },
25    scope::{OPENID, Scope},
26};
27use rand::{
28    Rng,
29    distributions::{Alphanumeric, DistString},
30};
31use serde::Serialize;
32use url::Url;
33
34use super::jose::JwtVerificationData;
35use crate::{
36    error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError},
37    requests::{jose::verify_id_token, token::request_access_token},
38    types::{IdToken, client_credentials::ClientCredentials},
39};
40
41#[derive(Debug, Clone)]
43pub struct AuthorizationRequestData {
44    pub client_id: String,
46
47    pub scope: Scope,
52
53    pub redirect_uri: Url,
57
58    pub code_challenge_methods_supported: Option<Vec<PkceCodeChallengeMethod>>,
63
64    pub display: Option<Display>,
67
68    pub prompt: Option<Vec<Prompt>>,
73
74    pub max_age: Option<NonZeroU32>,
77
78    pub ui_locales: Option<Vec<LanguageTag>>,
80
81    pub id_token_hint: Option<String>,
85
86    pub login_hint: Option<String>,
89
90    pub acr_values: Option<HashSet<String>>,
92
93    pub response_mode: Option<ResponseMode>,
95}
96
97impl AuthorizationRequestData {
98    #[must_use]
101    pub fn new(client_id: String, scope: Scope, redirect_uri: Url) -> Self {
102        Self {
103            client_id,
104            scope,
105            redirect_uri,
106            code_challenge_methods_supported: None,
107            display: None,
108            prompt: None,
109            max_age: None,
110            ui_locales: None,
111            id_token_hint: None,
112            login_hint: None,
113            acr_values: None,
114            response_mode: None,
115        }
116    }
117
118    #[must_use]
121    pub fn with_code_challenge_methods_supported(
122        mut self,
123        code_challenge_methods_supported: Vec<PkceCodeChallengeMethod>,
124    ) -> Self {
125        self.code_challenge_methods_supported = Some(code_challenge_methods_supported);
126        self
127    }
128
129    #[must_use]
131    pub fn with_display(mut self, display: Display) -> Self {
132        self.display = Some(display);
133        self
134    }
135
136    #[must_use]
138    pub fn with_prompt(mut self, prompt: Vec<Prompt>) -> Self {
139        self.prompt = Some(prompt);
140        self
141    }
142
143    #[must_use]
145    pub fn with_max_age(mut self, max_age: NonZeroU32) -> Self {
146        self.max_age = Some(max_age);
147        self
148    }
149
150    #[must_use]
152    pub fn with_ui_locales(mut self, ui_locales: Vec<LanguageTag>) -> Self {
153        self.ui_locales = Some(ui_locales);
154        self
155    }
156
157    #[must_use]
159    pub fn with_id_token_hint(mut self, id_token_hint: String) -> Self {
160        self.id_token_hint = Some(id_token_hint);
161        self
162    }
163
164    #[must_use]
166    pub fn with_login_hint(mut self, login_hint: String) -> Self {
167        self.login_hint = Some(login_hint);
168        self
169    }
170
171    #[must_use]
173    pub fn with_acr_values(mut self, acr_values: HashSet<String>) -> Self {
174        self.acr_values = Some(acr_values);
175        self
176    }
177
178    #[must_use]
180    pub fn with_response_mode(mut self, response_mode: ResponseMode) -> Self {
181        self.response_mode = Some(response_mode);
182        self
183    }
184}
185
186#[derive(Debug, Clone, PartialEq, Eq)]
189pub struct AuthorizationValidationData {
190    pub state: String,
192
193    pub nonce: String,
195
196    pub redirect_uri: Url,
198
199    pub code_challenge_verifier: Option<String>,
201}
202
203#[derive(Clone, Serialize)]
204struct FullAuthorizationRequest {
205    #[serde(flatten)]
206    inner: AuthorizationRequest,
207
208    #[serde(flatten, skip_serializing_if = "Option::is_none")]
209    pkce: Option<pkce::AuthorizationRequest>,
210}
211
212fn build_authorization_request(
214    authorization_data: AuthorizationRequestData,
215    rng: &mut impl Rng,
216) -> Result<(FullAuthorizationRequest, AuthorizationValidationData), AuthorizationError> {
217    let AuthorizationRequestData {
218        client_id,
219        mut scope,
220        redirect_uri,
221        code_challenge_methods_supported,
222        display,
223        prompt,
224        max_age,
225        ui_locales,
226        id_token_hint,
227        login_hint,
228        acr_values,
229        response_mode,
230    } = authorization_data;
231
232    let state = Alphanumeric.sample_string(rng, 16);
234    let nonce = Alphanumeric.sample_string(rng, 16);
235
236    let (pkce, code_challenge_verifier) = if code_challenge_methods_supported
238        .iter()
239        .any(|methods| methods.contains(&PkceCodeChallengeMethod::S256))
240    {
241        let mut verifier = [0u8; 32];
242        rng.fill(&mut verifier);
243
244        let method = PkceCodeChallengeMethod::S256;
245        let verifier = Base64UrlUnpadded::encode_string(&verifier);
246        let code_challenge = method.compute_challenge(&verifier)?.into();
247
248        let pkce = pkce::AuthorizationRequest {
249            code_challenge_method: method,
250            code_challenge,
251        };
252
253        (Some(pkce), Some(verifier))
254    } else {
255        (None, None)
256    };
257
258    scope.insert(OPENID);
259
260    let auth_request = FullAuthorizationRequest {
261        inner: AuthorizationRequest {
262            response_type: OAuthAuthorizationEndpointResponseType::Code.into(),
263            client_id,
264            redirect_uri: Some(redirect_uri.clone()),
265            scope,
266            state: Some(state.clone()),
267            response_mode,
268            nonce: Some(nonce.clone()),
269            display,
270            prompt,
271            max_age,
272            ui_locales,
273            id_token_hint,
274            login_hint,
275            acr_values,
276            request: None,
277            request_uri: None,
278            registration: None,
279        },
280        pkce,
281    };
282
283    let auth_data = AuthorizationValidationData {
284        state,
285        nonce,
286        redirect_uri,
287        code_challenge_verifier,
288    };
289
290    Ok((auth_request, auth_data))
291}
292
293pub fn build_authorization_url(
324    authorization_endpoint: Url,
325    authorization_data: AuthorizationRequestData,
326    rng: &mut impl Rng,
327) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
328    tracing::debug!(
329        scope = ?authorization_data.scope,
330        "Authorizing..."
331    );
332
333    let (authorization_request, validation_data) =
334        build_authorization_request(authorization_data, rng)?;
335
336    let authorization_query = serde_urlencoded::to_string(authorization_request)?;
337
338    let mut authorization_url = authorization_endpoint;
339
340    let mut full_query = authorization_url
342        .query()
343        .map(ToOwned::to_owned)
344        .unwrap_or_default();
345    if !full_query.is_empty() {
346        full_query.push('&');
347    }
348    full_query.push_str(&authorization_query);
349
350    authorization_url.set_query(Some(&full_query));
351
352    Ok((authorization_url, validation_data))
353}
354
355#[allow(clippy::too_many_arguments)]
393#[tracing::instrument(skip_all, fields(token_endpoint))]
394pub async fn access_token_with_authorization_code(
395    http_client: &reqwest::Client,
396    client_credentials: ClientCredentials,
397    token_endpoint: &Url,
398    code: String,
399    validation_data: AuthorizationValidationData,
400    id_token_verification_data: Option<JwtVerificationData<'_>>,
401    now: DateTime<Utc>,
402    rng: &mut impl Rng,
403) -> Result<(AccessTokenResponse, Option<IdToken<'static>>), TokenAuthorizationCodeError> {
404    tracing::debug!("Exchanging authorization code for access token...");
405
406    let token_response = request_access_token(
407        http_client,
408        client_credentials,
409        token_endpoint,
410        AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
411            code: code.clone(),
412            redirect_uri: Some(validation_data.redirect_uri),
413            code_verifier: validation_data.code_challenge_verifier,
414        }),
415        now,
416        rng,
417    )
418    .await?;
419
420    let id_token = if let Some(verification_data) = id_token_verification_data {
421        let signing_alg = verification_data.signing_algorithm;
422
423        let id_token = token_response
424            .id_token
425            .as_deref()
426            .ok_or(IdTokenError::MissingIdToken)?;
427
428        let id_token = verify_id_token(id_token, verification_data, None, now)?;
429
430        let mut claims = id_token.payload().clone();
431
432        claims::AT_HASH
434            .extract_optional_with_options(
435                &mut claims,
436                TokenHash::new(signing_alg, &token_response.access_token),
437            )
438            .map_err(IdTokenError::from)?;
439
440        claims::C_HASH
442            .extract_optional_with_options(&mut claims, TokenHash::new(signing_alg, &code))
443            .map_err(IdTokenError::from)?;
444
445        claims::NONCE
447            .extract_required_with_options(&mut claims, validation_data.nonce.as_str())
448            .map_err(IdTokenError::from)?;
449
450        Some(id_token.into_owned())
451    } else {
452        None
453    };
454
455    Ok((token_response, id_token))
456}