mas_config/sections/
upstream_oauth2.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
6
7use std::collections::BTreeMap;
8
9use camino::Utf8PathBuf;
10use mas_iana::jose::JsonWebSignatureAlg;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize, de::Error};
13use serde_with::skip_serializing_none;
14use ulid::Ulid;
15use url::Url;
16
17use crate::ConfigurationSection;
18
// Top-level value of the `upstream_oauth2` configuration section (see the
// `ConfigurationSection::PATH` constant on the impl below). The derived
// `Default` yields an empty provider list.
//
// NOTE(review): `JsonSchema` is derived here, so the `///` text presumably ends
// up in the generated schema — keep its wording stable unless the schema is
// regenerated too.
/// Upstream OAuth 2.0 providers configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
pub struct UpstreamOAuth2Config {
    /// List of OAuth 2.0 providers
    pub providers: Vec<Provider>,
}
25
26impl UpstreamOAuth2Config {
27    /// Returns true if the configuration is the default one
28    pub(crate) fn is_default(&self) -> bool {
29        self.providers.is_empty()
30    }
31}
32
33impl ConfigurationSection for UpstreamOAuth2Config {
34    const PATH: Option<&'static str> = Some("upstream_oauth2");
35
36    fn validate(
37        &self,
38        figment: &figment::Figment,
39    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
40        for (index, provider) in self.providers.iter().enumerate() {
41            let annotate = |mut error: figment::Error| {
42                error.metadata = figment
43                    .find_metadata(&format!("{root}.providers", root = Self::PATH.unwrap()))
44                    .cloned();
45                error.profile = Some(figment::Profile::Default);
46                error.path = vec![
47                    Self::PATH.unwrap().to_owned(),
48                    "providers".to_owned(),
49                    index.to_string(),
50                ];
51                error
52            };
53
54            if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
55                && provider.issuer.is_none()
56            {
57                return Err(annotate(figment::Error::custom(
58                    "The `issuer` field is required when discovery is enabled",
59                ))
60                .into());
61            }
62
63            match provider.token_endpoint_auth_method {
64                TokenAuthMethod::None
65                | TokenAuthMethod::PrivateKeyJwt
66                | TokenAuthMethod::SignInWithApple => {
67                    if provider.client_secret.is_some() {
68                        return Err(annotate(figment::Error::custom(
69                            "Unexpected field `client_secret` for the selected authentication method",
70                        )).into());
71                    }
72                }
73                TokenAuthMethod::ClientSecretBasic
74                | TokenAuthMethod::ClientSecretPost
75                | TokenAuthMethod::ClientSecretJwt => {
76                    if provider.client_secret.is_none() {
77                        return Err(annotate(figment::Error::missing_field("client_secret")).into());
78                    }
79                }
80            }
81
82            match provider.token_endpoint_auth_method {
83                TokenAuthMethod::None
84                | TokenAuthMethod::ClientSecretBasic
85                | TokenAuthMethod::ClientSecretPost
86                | TokenAuthMethod::SignInWithApple => {
87                    if provider.token_endpoint_auth_signing_alg.is_some() {
88                        return Err(annotate(figment::Error::custom(
89                            "Unexpected field `token_endpoint_auth_signing_alg` for the selected authentication method",
90                        )).into());
91                    }
92                }
93                TokenAuthMethod::ClientSecretJwt | TokenAuthMethod::PrivateKeyJwt => {
94                    if provider.token_endpoint_auth_signing_alg.is_none() {
95                        return Err(annotate(figment::Error::missing_field(
96                            "token_endpoint_auth_signing_alg",
97                        ))
98                        .into());
99                    }
100                }
101            }
102
103            match provider.token_endpoint_auth_method {
104                TokenAuthMethod::SignInWithApple => {
105                    if provider.sign_in_with_apple.is_none() {
106                        return Err(
107                            annotate(figment::Error::missing_field("sign_in_with_apple")).into(),
108                        );
109                    }
110                }
111
112                _ => {
113                    if provider.sign_in_with_apple.is_some() {
114                        return Err(annotate(figment::Error::custom(
115                            "Unexpected field `sign_in_with_apple` for the selected authentication method",
116                        )).into());
117                    }
118                }
119            }
120
121            if matches!(
122                provider.claims_imports.localpart.on_conflict,
123                OnConflict::Add
124            ) && !matches!(
125                provider.claims_imports.localpart.action,
126                ImportAction::Force | ImportAction::Require
127            ) {
128                return Err(annotate(figment::Error::custom(
129                    "The field `action` must be either `force` or `require` when `on_conflict` is set to `add`",
130                )).into());
131            }
132        }
133
134        Ok(())
135    }
136}
137
// Serialised with serde's `snake_case` renaming, so the variants are spelled
// `query` and `form_post` in the configuration.
/// The response mode we ask the provider to use for the callback
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// `query`: The provider will send the response as a query string in the
    /// URL search parameters
    Query,

    /// `form_post`: The provider will send the response as a POST request with
    /// the response parameters in the request body
    ///
    /// <https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html>
    FormPost,
}
152
// Serialised with serde's `snake_case` renaming. Which `Provider` fields each
// method requires or rejects is not encoded in this type; it is enforced in
// `UpstreamOAuth2Config::validate`.
/// Authentication methods used against the OAuth 2.0 provider
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// `none`: No authentication
    None,

    /// `client_secret_basic`: `client_id` and `client_secret` used as basic
    /// authorization credentials
    ClientSecretBasic,

    /// `client_secret_post`: `client_id` and `client_secret` sent in the
    /// request body
    ClientSecretPost,

    /// `client_secret_jwt`: a `client_assertion` sent in the request body and
    /// signed using the `client_secret`
    ClientSecretJwt,

    /// `private_key_jwt`: a `client_assertion` sent in the request body and
    /// signed by an asymmetric key
    PrivateKeyJwt,

    /// `sign_in_with_apple`: a special method for Signin with Apple
    SignInWithApple,
}
179
// Serialised in lowercase (`ignore`, `suggest`, `force`, `require`). `Ignore`
// is the `Default` variant.
/// How to handle a claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// Ignore the claim
    #[default]
    Ignore,

    /// Suggest the claim value, but allow the user to change it
    Suggest,

    /// Force the claim value, but don't fail if it is missing
    Force,

    /// Force the claim value, and fail if it is missing
    Require,
}
197
198impl ImportAction {
199    #[allow(clippy::trivially_copy_pass_by_ref)]
200    const fn is_default(&self) -> bool {
201        matches!(self, ImportAction::Ignore)
202    }
203}
204
// Serialised in lowercase (`fail`, `add`). `Fail` is the `Default` variant.
// `UpstreamOAuth2Config::validate` only accepts `Add` on a localpart whose
// `action` is `force` or `require`.
/// How to handle an existing localpart claim
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// Fails the sso login on conflict
    #[default]
    Fail,

    /// Adds the oauth identity link, regardless of whether there is an existing
    /// link or not
    Add,
}
217
218impl OnConflict {
219    #[allow(clippy::trivially_copy_pass_by_ref)]
220    const fn is_default(&self) -> bool {
221        matches!(self, OnConflict::Fail)
222    }
223}
224
// Only carries an optional template; `template` deserialises to `None` when
// the key is absent and is left out of the output when it is `None`.
/// What should be done for the subject attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct SubjectImportPreference {
    /// The Jinja2 template to use for the subject attribute
    ///
    /// If not provided, the default template is `{{ user.sub }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
234
235impl SubjectImportPreference {
236    const fn is_default(&self) -> bool {
237        self.template.is_none()
238    }
239}
240
// Each field deserialises to its default when the key is absent and is left
// out of the output when it equals that default.
/// What should be done for the localpart attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct LocalpartImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the localpart attribute
    ///
    /// If not provided, the default template is `{{ user.preferred_username }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,

    // `UpstreamOAuth2Config::validate` rejects `add` unless `action` is
    // `force` or `require`.
    /// How to handle conflicts on the claim, default value is `Fail`
    #[serde(default, skip_serializing_if = "OnConflict::is_default")]
    pub on_conflict: OnConflict,
}
258
259impl LocalpartImportPreference {
260    const fn is_default(&self) -> bool {
261        self.action.is_default() && self.template.is_none()
262    }
263}
264
// Each field deserialises to its default when the key is absent and is left
// out of the output when it equals that default.
/// What should be done for the displayname attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct DisplaynameImportPreference {
    /// How to handle the attribute
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the displayname attribute
    ///
    /// If not provided, the default template is `{{ user.name }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
278
279impl DisplaynameImportPreference {
280    const fn is_default(&self) -> bool {
281        self.action.is_default() && self.template.is_none()
282    }
283}
284
// Each field deserialises to its default when the key is absent and is left
// out of the output when it equals that default.
/// What should be done with the email attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct EmailImportPreference {
    /// How to handle the claim
    #[serde(default, skip_serializing_if = "ImportAction::is_default")]
    pub action: ImportAction,

    /// The Jinja2 template to use for the email address attribute
    ///
    /// If not provided, the default template is `{{ user.email }}`
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
298
299impl EmailImportPreference {
300    const fn is_default(&self) -> bool {
301        self.action.is_default() && self.template.is_none()
302    }
303}
304
// Only carries an optional template; `template` deserialises to `None` when
// the key is absent and is left out of the output when it is `None`.
/// What should be done for the account name attribute
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccountNameImportPreference {
    /// The Jinja2 template to use for the account name. This name is only used
    /// for display purposes.
    ///
    /// If not provided, it will be ignored.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
}
315
316impl AccountNameImportPreference {
317    const fn is_default(&self) -> bool {
318        self.template.is_none()
319    }
320}
321
// Groups the per-attribute import preferences. Each field deserialises to its
// `Default` when the key is absent, and is left out of the output when the
// `is_default` predicate named in its `skip_serializing_if` returns `true`.
/// How claims should be imported
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)]
pub struct ClaimsImports {
    /// How to determine the subject of the user
    #[serde(default, skip_serializing_if = "SubjectImportPreference::is_default")]
    pub subject: SubjectImportPreference,

    /// Import the localpart of the MXID
    #[serde(default, skip_serializing_if = "LocalpartImportPreference::is_default")]
    pub localpart: LocalpartImportPreference,

    /// Import the displayname of the user.
    #[serde(
        default,
        skip_serializing_if = "DisplaynameImportPreference::is_default"
    )]
    pub displayname: DisplaynameImportPreference,

    /// Import the email address of the user based on the `email` and
    /// `email_verified` claims
    #[serde(default, skip_serializing_if = "EmailImportPreference::is_default")]
    pub email: EmailImportPreference,

    /// Set a human-readable name for the upstream account for display purposes
    #[serde(
        default,
        skip_serializing_if = "AccountNameImportPreference::is_default"
    )]
    pub account_name: AccountNameImportPreference,
}
352
353impl ClaimsImports {
354    const fn is_default(&self) -> bool {
355        self.subject.is_default()
356            && self.localpart.is_default()
357            && self.displayname.is_default()
358            && self.email.is_default()
359    }
360}
361
// Serialised with serde's `snake_case` renaming. `Oidc` is the `Default`
// variant. `UpstreamOAuth2Config::validate` requires `Provider::issuer` for
// every variant except `Disabled`.
/// How to discover the provider's configuration
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum DiscoveryMode {
    /// Use OIDC discovery with strict metadata verification
    #[default]
    Oidc,

    /// Use OIDC discovery with relaxed metadata verification
    Insecure,

    /// Use a static configuration
    Disabled,
}
376
377impl DiscoveryMode {
378    #[allow(clippy::trivially_copy_pass_by_ref)]
379    const fn is_default(&self) -> bool {
380        matches!(self, DiscoveryMode::Oidc)
381    }
382}
383
// Serialised with serde's `snake_case` renaming. `Auto` is the `Default`
// variant.
/// Whether to use proof key for code exchange (PKCE) when requesting and
/// exchanging the token.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum PkceMethod {
    /// Use PKCE if the provider supports it
    ///
    /// Defaults to no PKCE if provider discovery is disabled
    #[default]
    Auto,

    /// Always use PKCE with the S256 challenge method
    Always,

    /// Never use PKCE
    Never,
}
401
402impl PkceMethod {
403    #[allow(clippy::trivially_copy_pass_by_ref)]
404    const fn is_default(&self) -> bool {
405        matches!(self, PkceMethod::Auto)
406    }
407}
408
/// Supplies `true` as the fallback for boolean fields that are enabled unless
/// stated otherwise (used through `#[serde(default = "default_true")]`).
fn default_true() -> bool {
    true
}
412
/// Tells whether a boolean whose default is `true` still holds that default.
///
/// Referenced by path from a `skip_serializing_if` attribute in this file,
/// hence the `&bool` parameter instead of a plain `bool`.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_true(value: &bool) -> bool {
    matches!(value, true)
}
417
418#[allow(clippy::ref_option)]
419fn is_signed_response_alg_default(signed_response_alg: &JsonWebSignatureAlg) -> bool {
420    *signed_response_alg == signed_response_alg_default()
421}
422
423#[allow(clippy::unnecessary_wraps)]
424fn signed_response_alg_default() -> JsonWebSignatureAlg {
425    JsonWebSignatureAlg::Rs256
426}
427
// Extra parameters used with `TokenAuthMethod::SignInWithApple`;
// `UpstreamOAuth2Config::validate` requires this section for that method and
// rejects it for every other one.
//
// NOTE(review): both `private_key_file` and `private_key` are optional here and
// `validate` does not check that one of them is set — presumably that is
// enforced where the key is loaded; confirm.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SignInWithApple {
    /// The private key file used to sign the `id_token`
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schemars(with = "Option<String>")]
    pub private_key_file: Option<Utf8PathBuf>,

    /// The private key used to sign the `id_token`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub private_key: Option<String>,

    /// The Team ID of the Apple Developer Portal
    pub team_id: String,

    /// The key ID of the Apple Developer Portal
    pub key_id: String,
}
445
/// Supplies the fallback for `Provider::scope`: the single `openid` scope.
fn default_scope() -> String {
    String::from("openid")
}
449
450fn is_default_scope(scope: &str) -> bool {
451    scope == default_scope()
452}
453
// Serialised with serde's `snake_case` renaming (`do_nothing`,
// `logout_browser_only`, `logout_all`). `DoNothing` is the `Default` variant.
/// What to do when receiving an OIDC Backchannel logout request.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, Default)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// Do nothing
    #[default]
    DoNothing,

    /// Only log out the MAS 'browser session' started by this OIDC session
    LogoutBrowserOnly,

    /// Log out all sessions started by this OIDC session, including MAS
    /// 'browser sessions' and client sessions
    LogoutAll,
}
469
470impl OnBackchannelLogout {
471    #[allow(clippy::trivially_copy_pass_by_ref)]
472    const fn is_default(&self) -> bool {
473        matches!(self, OnBackchannelLogout::DoNothing)
474    }
475}
476
// One entry of `UpstreamOAuth2Config::providers`. Cross-field constraints (for
// example which fields each `token_endpoint_auth_method` requires or forbids)
// are not expressed in this type; they are checked in
// `UpstreamOAuth2Config::validate`.
//
// NOTE(review): `#[skip_serializing_none]` presumably already covers the
// `Option` fields, which would make the per-field
// `skip_serializing_if = "Option::is_none"` attributes redundant — confirm
// before removing either.
//
// NOTE(review): `JsonSchema` is derived here, so the `///` text presumably ends
// up in the generated schema — keep its wording stable unless the schema is
// regenerated too.
/// Configuration for one upstream OAuth 2 provider.
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider {
    /// Whether this provider is enabled.
    ///
    /// Defaults to `true`
    #[serde(default = "default_true", skip_serializing_if = "is_default_true")]
    pub enabled: bool,

    /// An internal unique identifier for this provider
    #[schemars(
        with = "String",
        regex(pattern = r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"),
        description = "A ULID as per https://github.com/ulid/spec"
    )]
    pub id: Ulid,

    /// The ID of the provider that was used by Synapse.
    /// In order to perform a Synapse-to-MAS migration, this must be specified.
    ///
    /// ## For providers that used OAuth 2.0 or OpenID Connect in Synapse
    ///
    /// ### For `oidc_providers`:
    /// This should be specified as `oidc-` followed by the ID that was
    /// configured as `idp_id` in one of the `oidc_providers` in the Synapse
    /// configuration.
    /// For example, if Synapse's configuration contained `idp_id: wombat` for
    /// this provider, then specify `oidc-wombat` here.
    ///
    /// ### For `oidc_config` (legacy):
    /// Specify `oidc` here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub synapse_idp_id: Option<String>,

    // The requirement below is enforced in `UpstreamOAuth2Config::validate`,
    // not by the type.
    /// The OIDC issuer URL
    ///
    /// This is required if OIDC discovery is enabled (which is the default)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub issuer: Option<String>,

    /// A human-readable name for the provider, that will be shown to users
    #[serde(skip_serializing_if = "Option::is_none")]
    pub human_name: Option<String>,

    /// A brand identifier used to customise the UI, e.g. `apple`, `google`,
    /// `github`, etc.
    ///
    /// Values supported by the default template are:
    ///
    ///  - `apple`
    ///  - `google`
    ///  - `facebook`
    ///  - `github`
    ///  - `gitlab`
    ///  - `twitter`
    ///  - `discord`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub brand_name: Option<String>,

    /// The client ID to use when authenticating with the provider
    pub client_id: String,

    // `validate` requires this for the three methods listed below and rejects
    // it for the others.
    /// The client secret to use when authenticating with the provider
    ///
    /// Used by the `client_secret_basic`, `client_secret_post`, and
    /// `client_secret_jwt` methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,

    // No serde default: this key must always be present.
    /// The method to authenticate the client with the provider
    pub token_endpoint_auth_method: TokenAuthMethod,

    // `validate` requires this for `sign_in_with_apple` and rejects it for
    // every other method.
    /// Additional parameters for the `sign_in_with_apple` method
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sign_in_with_apple: Option<SignInWithApple>,

    // `validate` requires this for the two methods listed below and rejects it
    // for the others.
    /// The JWS algorithm to use when authenticating the client with the
    /// provider
    ///
    /// Used by the `client_secret_jwt` and `private_key_jwt` methods
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,

    /// Expected signature for the JWT payload returned by the token
    /// authentication endpoint.
    ///
    /// Defaults to `RS256`.
    #[serde(
        default = "signed_response_alg_default",
        skip_serializing_if = "is_signed_response_alg_default"
    )]
    pub id_token_signed_response_alg: JsonWebSignatureAlg,

    /// The scopes to request from the provider
    ///
    /// Defaults to `openid`.
    #[serde(default = "default_scope", skip_serializing_if = "is_default_scope")]
    pub scope: String,

    /// How to discover the provider's configuration
    ///
    /// Defaults to `oidc`, which uses OIDC discovery with strict metadata
    /// verification
    #[serde(default, skip_serializing_if = "DiscoveryMode::is_default")]
    pub discovery_mode: DiscoveryMode,

    /// Whether to use proof key for code exchange (PKCE) when requesting and
    /// exchanging the token.
    ///
    /// Defaults to `auto`, which uses PKCE if the provider supports it.
    #[serde(default, skip_serializing_if = "PkceMethod::is_default")]
    pub pkce_method: PkceMethod,

    // There is no `skip_serializing_if` on this field, so it is written out
    // even when it is `false`.
    /// Whether to fetch the user profile from the userinfo endpoint,
    /// or to rely on the data returned in the `id_token` from the
    /// `token_endpoint`.
    ///
    /// Defaults to `false`.
    #[serde(default)]
    pub fetch_userinfo: bool,

    /// Expected signature for the JWT payload returned by the userinfo
    /// endpoint.
    ///
    /// If not specified, the response is expected to be an unsigned JSON
    /// payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,

    /// The URL to use for the provider's authorization endpoint
    ///
    /// Defaults to the `authorization_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub authorization_endpoint: Option<Url>,

    /// The URL to use for the provider's userinfo endpoint
    ///
    /// Defaults to the `userinfo_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub userinfo_endpoint: Option<Url>,

    /// The URL to use for the provider's token endpoint
    ///
    /// Defaults to the `token_endpoint` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint: Option<Url>,

    /// The URL to use for getting the provider's public keys
    ///
    /// Defaults to the `jwks_uri` provided through discovery
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<Url>,

    /// The response mode we ask the provider to use for the callback
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_mode: Option<ResponseMode>,

    /// How claims should be imported from the `id_token` provided by the
    /// provider
    #[serde(default, skip_serializing_if = "ClaimsImports::is_default")]
    pub claims_imports: ClaimsImports,

    /// Additional parameters to include in the authorization request
    ///
    /// Orders of the keys are not preserved.
    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
    pub additional_authorization_parameters: BTreeMap<String, String>,

    // There is no `skip_serializing_if` on this field, so it is written out
    // even when it is `false`.
    /// Whether the `login_hint` should be forwarded to the provider in the
    /// authorization request.
    ///
    /// Defaults to `false`.
    #[serde(default)]
    pub forward_login_hint: bool,

    /// What to do when receiving an OIDC Backchannel logout request.
    ///
    /// Defaults to "do_nothing".
    #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
    pub on_backchannel_logout: OnBackchannelLogout,
}