mas_data_model/upstream_oauth2/
provider.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use chrono::{DateTime, Utc};
8use mas_iana::jose::JsonWebSignatureAlg;
9use oauth2_types::scope::Scope;
10use serde::{Deserialize, Serialize};
11use thiserror::Error;
12use ulid::Ulid;
13use url::Url;
14
/// How provider metadata discovery is carried out for an upstream provider.
///
/// Serialized by serde in lowercase (`oidc`, `insecure`, `disabled`); the
/// default is [`DiscoveryMode::Oidc`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum DiscoveryMode {
    /// Use OIDC discovery to fetch and verify the provider metadata
    #[default]
    Oidc,

    /// Use OIDC discovery to fetch the provider metadata, but don't verify it
    Insecure,

    /// Don't fetch the provider metadata
    Disabled,
}
28
29impl DiscoveryMode {
30    /// Returns `true` if discovery is disabled
31    #[must_use]
32    pub fn is_disabled(&self) -> bool {
33        matches!(self, DiscoveryMode::Disabled)
34    }
35}
36
/// Error returned when a string cannot be parsed into a [`DiscoveryMode`].
///
/// Wraps the rejected input, which is rendered with `{:?}` in the error
/// message.
#[derive(Debug, Clone, Error)]
#[error("Invalid discovery mode {0:?}")]
pub struct InvalidDiscoveryModeError(String);
40
41impl std::str::FromStr for DiscoveryMode {
42    type Err = InvalidDiscoveryModeError;
43
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        match s {
46            "oidc" => Ok(Self::Oidc),
47            "insecure" => Ok(Self::Insecure),
48            "disabled" => Ok(Self::Disabled),
49            s => Err(InvalidDiscoveryModeError(s.to_owned())),
50        }
51    }
52}
53
54impl DiscoveryMode {
55    #[must_use]
56    pub fn as_str(self) -> &'static str {
57        match self {
58            Self::Oidc => "oidc",
59            Self::Insecure => "insecure",
60            Self::Disabled => "disabled",
61        }
62    }
63}
64
65impl std::fmt::Display for DiscoveryMode {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        f.write_str(self.as_str())
68    }
69}
70
/// Whether and how PKCE is used with an upstream provider.
///
/// Serialized by serde in lowercase (`auto`, `s256`, `disabled`); the default
/// is [`PkceMode::Auto`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PkceMode {
    /// Use PKCE if the provider supports it
    #[default]
    Auto,

    /// Always use PKCE with the S256 method
    S256,

    /// Don't use PKCE
    Disabled,
}
84
/// Error returned when a string cannot be parsed into a [`PkceMode`].
///
/// Wraps the rejected input, which is rendered with `{:?}` in the error
/// message.
#[derive(Debug, Clone, Error)]
#[error("Invalid PKCE mode {0:?}")]
pub struct InvalidPkceModeError(String);
88
89impl std::str::FromStr for PkceMode {
90    type Err = InvalidPkceModeError;
91
92    fn from_str(s: &str) -> Result<Self, Self::Err> {
93        match s {
94            "auto" => Ok(Self::Auto),
95            "s256" => Ok(Self::S256),
96            "disabled" => Ok(Self::Disabled),
97            s => Err(InvalidPkceModeError(s.to_owned())),
98        }
99    }
100}
101
102impl PkceMode {
103    #[must_use]
104    pub fn as_str(self) -> &'static str {
105        match self {
106            Self::Auto => "auto",
107            Self::S256 => "s256",
108            Self::Disabled => "disabled",
109        }
110    }
111}
112
113impl std::fmt::Display for PkceMode {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.write_str(self.as_str())
116    }
117}
118
/// Error returned when a string cannot be parsed into a [`ResponseMode`].
///
/// Wraps the rejected input, which is rendered with `{:?}` in the error
/// message.
#[derive(Debug, Clone, Error)]
#[error("Invalid response mode {0:?}")]
pub struct InvalidResponseModeError(String);
122
/// Response mode that can be configured on an upstream provider.
///
/// Serialized by serde in snake case (`query`, `form_post`); the default is
/// [`ResponseMode::Query`]. Convertible into
/// `oauth2_types::requests::ResponseMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ResponseMode {
    /// Serialized as `query`
    #[default]
    Query,
    /// Serialized as `form_post`
    FormPost,
}
130
131impl From<ResponseMode> for oauth2_types::requests::ResponseMode {
132    fn from(value: ResponseMode) -> Self {
133        match value {
134            ResponseMode::Query => oauth2_types::requests::ResponseMode::Query,
135            ResponseMode::FormPost => oauth2_types::requests::ResponseMode::FormPost,
136        }
137    }
138}
139
140impl ResponseMode {
141    #[must_use]
142    pub fn as_str(self) -> &'static str {
143        match self {
144            Self::Query => "query",
145            Self::FormPost => "form_post",
146        }
147    }
148}
149
150impl std::fmt::Display for ResponseMode {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        f.write_str(self.as_str())
153    }
154}
155
156impl std::str::FromStr for ResponseMode {
157    type Err = InvalidResponseModeError;
158
159    fn from_str(s: &str) -> Result<Self, Self::Err> {
160        match s {
161            "query" => Ok(ResponseMode::Query),
162            "form_post" => Ok(ResponseMode::FormPost),
163            s => Err(InvalidResponseModeError(s.to_owned())),
164        }
165    }
166}
167
/// Token endpoint authentication method configured on an upstream provider
/// (see `UpstreamOAuthProvider::token_endpoint_auth_method`).
///
/// Serialized by serde in snake case; the same names are produced by
/// `as_str` and accepted by the `FromStr` implementation.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TokenAuthMethod {
    /// Serialized as `none`
    None,
    /// Serialized as `client_secret_basic`
    ClientSecretBasic,
    /// Serialized as `client_secret_post`
    ClientSecretPost,
    /// Serialized as `client_secret_jwt`
    ClientSecretJwt,
    /// Serialized as `private_key_jwt`
    PrivateKeyJwt,
    /// Serialized as `sign_in_with_apple`
    SignInWithApple,
}
178
179impl TokenAuthMethod {
180    #[must_use]
181    pub fn as_str(self) -> &'static str {
182        match self {
183            Self::None => "none",
184            Self::ClientSecretBasic => "client_secret_basic",
185            Self::ClientSecretPost => "client_secret_post",
186            Self::ClientSecretJwt => "client_secret_jwt",
187            Self::PrivateKeyJwt => "private_key_jwt",
188            Self::SignInWithApple => "sign_in_with_apple",
189        }
190    }
191}
192
193impl std::fmt::Display for TokenAuthMethod {
194    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
195        f.write_str(self.as_str())
196    }
197}
198
199impl std::str::FromStr for TokenAuthMethod {
200    type Err = InvalidUpstreamOAuth2TokenAuthMethod;
201
202    fn from_str(s: &str) -> Result<Self, Self::Err> {
203        match s {
204            "none" => Ok(Self::None),
205            "client_secret_post" => Ok(Self::ClientSecretPost),
206            "client_secret_basic" => Ok(Self::ClientSecretBasic),
207            "client_secret_jwt" => Ok(Self::ClientSecretJwt),
208            "private_key_jwt" => Ok(Self::PrivateKeyJwt),
209            "sign_in_with_apple" => Ok(Self::SignInWithApple),
210            s => Err(InvalidUpstreamOAuth2TokenAuthMethod(s.to_owned())),
211        }
212    }
213}
214
/// Error returned when a string cannot be parsed into a [`TokenAuthMethod`].
///
/// Wraps the rejected input, which is rendered with `{}` (not `{:?}`) in the
/// error message.
#[derive(Debug, Clone, Error)]
#[error("Invalid upstream OAuth 2.0 token auth method: {0}")]
pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
218
/// Setting stored in `UpstreamOAuthProvider::on_backchannel_logout`.
///
/// Serialized by serde in snake case; the same names are produced by
/// `as_str` and accepted by the `FromStr` implementation.
// NOTE(review): presumably selects what happens to local sessions when the
// upstream provider sends a backchannel logout; confirm against the handler.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OnBackchannelLogout {
    /// Serialized as `do_nothing`
    DoNothing,
    /// Serialized as `logout_browser_only`
    LogoutBrowserOnly,
    /// Serialized as `logout_all`
    LogoutAll,
}
226
227impl OnBackchannelLogout {
228    #[must_use]
229    pub fn as_str(self) -> &'static str {
230        match self {
231            Self::DoNothing => "do_nothing",
232            Self::LogoutBrowserOnly => "logout_browser_only",
233            Self::LogoutAll => "logout_all",
234        }
235    }
236}
237
238impl std::fmt::Display for OnBackchannelLogout {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        f.write_str(self.as_str())
241    }
242}
243
244impl std::str::FromStr for OnBackchannelLogout {
245    type Err = InvalidUpstreamOAuth2OnBackchannelLogout;
246
247    fn from_str(s: &str) -> Result<Self, Self::Err> {
248        match s {
249            "do_nothing" => Ok(Self::DoNothing),
250            "logout_browser_only" => Ok(Self::LogoutBrowserOnly),
251            "logout_all" => Ok(Self::LogoutAll),
252            s => Err(InvalidUpstreamOAuth2OnBackchannelLogout(s.to_owned())),
253        }
254    }
255}
256
/// Error returned when a string cannot be parsed into an
/// [`OnBackchannelLogout`].
///
/// Wraps the rejected input, which is rendered with `{}` (not `{:?}`) in the
/// error message.
#[derive(Debug, Clone, Error)]
#[error("Invalid upstream OAuth 2.0 'on backchannel logout': {0}")]
pub struct InvalidUpstreamOAuth2OnBackchannelLogout(String);
260
/// An upstream OAuth 2.0 provider together with its configuration.
///
/// Only `Serialize` is derived (there is no `Deserialize`). The derived
/// equality compares every field, whereas the `PartialOrd`/`Ord`
/// implementations of this type compare the `id` only.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
    /// Identifier of the provider; the only field used for ordering
    pub id: Ulid,
    pub issuer: Option<String>,
    pub human_name: Option<String>,
    pub brand_name: Option<String>,
    /// How provider metadata discovery is performed (see [`DiscoveryMode`])
    pub discovery_mode: DiscoveryMode,
    /// Whether and how PKCE is used (see [`PkceMode`])
    pub pkce_mode: PkceMode,
    pub jwks_uri_override: Option<Url>,
    pub authorization_endpoint_override: Option<Url>,
    pub scope: Scope,
    pub token_endpoint_override: Option<Url>,
    pub userinfo_endpoint_override: Option<Url>,
    pub fetch_userinfo: bool,
    pub userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
    pub client_id: String,
    pub encrypted_client_secret: Option<String>,
    pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
    pub token_endpoint_auth_method: TokenAuthMethod,
    pub id_token_signed_response_alg: JsonWebSignatureAlg,
    pub response_mode: Option<ResponseMode>,
    pub created_at: DateTime<Utc>,
    /// `None` while the provider is enabled; `enabled()` checks exactly this
    pub disabled_at: Option<DateTime<Utc>>,
    /// Per-claim import preferences (see [`ClaimsImports`])
    pub claims_imports: ClaimsImports,
    pub additional_authorization_parameters: Vec<(String, String)>,
    pub forward_login_hint: bool,
    pub on_backchannel_logout: OnBackchannelLogout,
}
289
290impl PartialOrd for UpstreamOAuthProvider {
291    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
292        Some(self.id.cmp(&other.id))
293    }
294}
295
296impl Ord for UpstreamOAuthProvider {
297    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
298        self.id.cmp(&other.id)
299    }
300}
301
302impl UpstreamOAuthProvider {
303    /// Returns `true` if the provider is enabled
304    #[must_use]
305    pub const fn enabled(&self) -> bool {
306        self.disabled_at.is_none()
307    }
308}
309
/// Per-claim import preferences of an upstream provider.
///
/// Every field is marked `#[serde(default)]`, so a field that is missing
/// from the serialized form deserializes to that field type's `Default`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ClaimsImports {
    /// Preference for the subject (template only)
    #[serde(default)]
    pub subject: SubjectPreference,

    /// Preference for the localpart (action, template and conflict handling)
    #[serde(default)]
    pub localpart: LocalpartPreference,

    /// Preference for the display name (action and template)
    #[serde(default)]
    pub displayname: ImportPreference,

    /// Preference for the email (action and template)
    #[serde(default)]
    pub email: ImportPreference,

    /// Preference for the account name (template only)
    #[serde(default)]
    pub account_name: SubjectPreference,
}
327
// XXX: this should have another name
/// Preference consisting solely of an optional template string.
///
/// Used by [`ClaimsImports`] for both its `subject` and `account_name`
/// fields.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SubjectPreference {
    /// Optional template; `None` when absent from the serialized form
    #[serde(default)]
    pub template: Option<String>,
}
334
/// Import preference for the localpart: an [`ImportAction`], an optional
/// template and an [`OnConflict`] setting.
///
/// Dereferences to its `action`, so [`ImportAction`] methods can be called
/// directly on a value of this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct LocalpartPreference {
    /// What to do with the claim; defaults to `ImportAction::Ignore`
    #[serde(default)]
    pub action: ImportAction,

    /// Optional template; `None` when absent from the serialized form
    #[serde(default)]
    pub template: Option<String>,

    /// Conflict handling; defaults to `OnConflict::Fail`
    #[serde(default)]
    pub on_conflict: OnConflict,
}
346
347impl std::ops::Deref for LocalpartPreference {
348    type Target = ImportAction;
349
350    fn deref(&self) -> &Self::Target {
351        &self.action
352    }
353}
354
/// Import preference made of an [`ImportAction`] and an optional template.
///
/// Dereferences to its `action`, so [`ImportAction`] methods can be called
/// directly on a value of this type.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct ImportPreference {
    /// What to do with the claim; defaults to `ImportAction::Ignore`
    #[serde(default)]
    pub action: ImportAction,

    /// Optional template; `None` when absent from the serialized form
    #[serde(default)]
    pub template: Option<String>,
}
363
364impl std::ops::Deref for ImportPreference {
365    type Target = ImportAction;
366
367    fn deref(&self) -> &Self::Target {
368        &self.action
369    }
370}
371
/// What to do with a claim when importing it.
///
/// Serialized by serde in lowercase (`ignore`, `suggest`, `force`,
/// `require`); the default is [`ImportAction::Ignore`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum ImportAction {
    /// Ignore the claim
    #[default]
    Ignore,

    /// Suggest the claim value, but allow the user to change it
    Suggest,

    /// Force the claim value, but don't fail if it is missing
    Force,

    /// Force the claim value, and fail if it is missing
    Require,
}
388
389impl ImportAction {
390    #[must_use]
391    pub fn is_forced_or_required(&self) -> bool {
392        matches!(self, Self::Force | Self::Require)
393    }
394
395    #[must_use]
396    pub fn ignore(&self) -> bool {
397        matches!(self, Self::Ignore)
398    }
399
400    #[must_use]
401    pub fn is_required(&self) -> bool {
402        matches!(self, Self::Require)
403    }
404
405    #[must_use]
406    pub fn should_import(&self, user_preference: bool) -> bool {
407        match self {
408            Self::Ignore => false,
409            Self::Suggest => user_preference,
410            Self::Force | Self::Require => true,
411        }
412    }
413}
414
/// Conflict handling setting used by `LocalpartPreference::on_conflict`.
///
/// Serialized by serde in lowercase (`fail`, `add`); the default is
/// [`OnConflict::Fail`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum OnConflict {
    /// Fails the upstream OAuth 2.0 login
    #[default]
    Fail,

    /// Adds the upstream account link, regardless of whether there is an
    /// existing link or not
    Add,
}