mas_matrix/
mock.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

7use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// In-memory record of a user provisioned on the mock homeserver.
struct MockUser {
    /// Internal ID (`sub`) the user was provisioned with. Re-provisioning the
    /// same localpart with a different `sub` is rejected.
    sub: String,

    /// Current display name, if any.
    displayname: Option<String>,

    /// Current avatar URL, if any.
    avatar_url: Option<String>,

    /// E-mail addresses last set through provisioning, if any.
    emails: Option<Vec<String>>,

    /// IDs of the devices currently attached to the user.
    devices: HashSet<String>,

    /// Whether a cross-signing reset has been allowed for this user.
    cross_signing_reset_allowed: bool,

    /// Whether the user is currently deactivated.
    deactivated: bool,
}
24
25/// A mock implementation of a [`HomeserverConnection`], which never fails and
26/// doesn't do anything.
27pub struct HomeserverConnection {
28    homeserver: String,
29    users: RwLock<HashMap<String, MockUser>>,
30    reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34    /// Create a new mock connection.
35    pub fn new<H>(homeserver: H) -> Self
36    where
37        H: Into<String>,
38    {
39        Self {
40            homeserver: homeserver.into(),
41            users: RwLock::new(HashMap::new()),
42            reserved_localparts: RwLock::new(HashSet::new()),
43        }
44    }
45
46    pub async fn reserve_localpart(&self, localpart: &'static str) {
47        self.reserved_localparts.write().await.insert(localpart);
48    }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53    fn homeserver(&self) -> &str {
54        &self.homeserver
55    }
56
57    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
58        let mxid = self.mxid(localpart);
59        let users = self.users.read().await;
60        let user = users.get(&mxid).context("User not found")?;
61        Ok(MatrixUser {
62            displayname: user.displayname.clone(),
63            avatar_url: user.avatar_url.clone(),
64            deactivated: user.deactivated,
65        })
66    }
67
68    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
69        let mut users = self.users.write().await;
70        let mxid = self.mxid(request.localpart());
71        let inserted = !users.contains_key(&mxid);
72        let user = users.entry(mxid).or_insert(MockUser {
73            sub: request.sub().to_owned(),
74            avatar_url: None,
75            displayname: None,
76            devices: HashSet::new(),
77            emails: None,
78            cross_signing_reset_allowed: false,
79            deactivated: false,
80        });
81
82        anyhow::ensure!(
83            user.sub == request.sub(),
84            "User already provisioned with different sub"
85        );
86
87        request.on_emails(|emails| {
88            user.emails = emails.map(ToOwned::to_owned);
89        });
90
91        request.on_displayname(|displayname| {
92            user.displayname = displayname.map(ToOwned::to_owned);
93        });
94
95        request.on_avatar_url(|avatar_url| {
96            user.avatar_url = avatar_url.map(ToOwned::to_owned);
97        });
98
99        Ok(inserted)
100    }
101
102    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
103        if self.reserved_localparts.read().await.contains(localpart) {
104            return Ok(false);
105        }
106
107        let mxid = self.mxid(localpart);
108        let users = self.users.read().await;
109        Ok(!users.contains_key(&mxid))
110    }
111
112    async fn upsert_device(
113        &self,
114        localpart: &str,
115        device_id: &str,
116        _initial_display_name: Option<&str>,
117    ) -> Result<(), anyhow::Error> {
118        let mxid = self.mxid(localpart);
119        let mut users = self.users.write().await;
120        let user = users.get_mut(&mxid).context("User not found")?;
121        user.devices.insert(device_id.to_owned());
122        Ok(())
123    }
124
125    async fn update_device_display_name(
126        &self,
127        localpart: &str,
128        device_id: &str,
129        _display_name: &str,
130    ) -> Result<(), anyhow::Error> {
131        let mxid = self.mxid(localpart);
132        let mut users = self.users.write().await;
133        let user = users.get_mut(&mxid).context("User not found")?;
134        user.devices.get(device_id).context("Device not found")?;
135        Ok(())
136    }
137
138    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
139        let mxid = self.mxid(localpart);
140        let mut users = self.users.write().await;
141        let user = users.get_mut(&mxid).context("User not found")?;
142        user.devices.remove(device_id);
143        Ok(())
144    }
145
146    async fn sync_devices(
147        &self,
148        localpart: &str,
149        devices: HashSet<String>,
150    ) -> Result<(), anyhow::Error> {
151        let mxid = self.mxid(localpart);
152        let mut users = self.users.write().await;
153        let user = users.get_mut(&mxid).context("User not found")?;
154        user.devices = devices;
155        Ok(())
156    }
157
158    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
159        let mxid = self.mxid(localpart);
160        let mut users = self.users.write().await;
161        let user = users.get_mut(&mxid).context("User not found")?;
162        user.devices.clear();
163        user.emails = None;
164        user.deactivated = true;
165        if erase {
166            user.avatar_url = None;
167            user.displayname = None;
168        }
169
170        Ok(())
171    }
172
173    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
174        let mxid = self.mxid(localpart);
175        let mut users = self.users.write().await;
176        let user = users.get_mut(&mxid).context("User not found")?;
177        user.deactivated = false;
178
179        Ok(())
180    }
181
182    async fn set_displayname(
183        &self,
184        localpart: &str,
185        displayname: &str,
186    ) -> Result<(), anyhow::Error> {
187        let mxid = self.mxid(localpart);
188        let mut users = self.users.write().await;
189        let user = users.get_mut(&mxid).context("User not found")?;
190        user.displayname = Some(displayname.to_owned());
191        Ok(())
192    }
193
194    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
195        let mxid = self.mxid(localpart);
196        let mut users = self.users.write().await;
197        let user = users.get_mut(&mxid).context("User not found")?;
198        user.displayname = None;
199        Ok(())
200    }
201
202    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
203        let mxid = self.mxid(localpart);
204        let mut users = self.users.write().await;
205        let user = users.get_mut(&mxid).context("User not found")?;
206        user.cross_signing_reset_allowed = true;
207        Ok(())
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        assert!(conn.query_user("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        // Set the displayname again
        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        // Unset the displayname
        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a non-existent device should not fail
        assert!(conn.delete_device("test", device).await.is_ok());

        // Create the device
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        // Create the same device again
        assert!(conn.upsert_device("test", device, None).await.is_ok());

        // XXX: there is no API to query devices yet in the trait
        // Delete the device
        assert!(conn.delete_device("test", device).await.is_ok());

        // The user we just created should not be available
        assert!(!conn.is_localpart_available("test").await.unwrap());
        // But another user should be
        assert!(conn.is_localpart_available("alice").await.unwrap());

        // Reserve the localpart, it should not be available anymore
        conn.reserve_localpart("alice").await;
        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }

    #[tokio::test]
    async fn test_mock_provisioning_and_lifecycle() {
        let conn = HomeserverConnection::new("example.org");

        let request = ProvisionRequest::new("test", "test").set_displayname("Test User".into());
        assert!(conn.provision_user(&request).await.unwrap());

        // Provisioning again with the same sub updates rather than inserts
        let request = ProvisionRequest::new("test", "test");
        assert!(!conn.provision_user(&request).await.unwrap());

        // Provisioning the same localpart with a different sub is rejected
        let request = ProvisionRequest::new("test", "another-sub");
        assert!(conn.provision_user(&request).await.is_err());

        // Updating a device's display name requires the device to exist
        assert!(conn
            .update_device_display_name("test", "dev", "Phone")
            .await
            .is_err());
        conn.upsert_device("test", "dev", Some("Phone")).await.unwrap();
        assert!(conn
            .update_device_display_name("test", "dev", "Tablet")
            .await
            .is_ok());

        // Syncing replaces the whole device set
        conn.sync_devices("test", HashSet::from(["other".to_owned()]))
            .await
            .unwrap();
        assert!(conn
            .update_device_display_name("test", "dev", "Tablet")
            .await
            .is_err());
        assert!(conn
            .update_device_display_name("test", "other", "Laptop")
            .await
            .is_ok());

        conn.set_displayname("test", "John").await.unwrap();

        // Deactivating without erasure keeps the profile but drops devices
        conn.delete_user("test", false).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("John".into()));
        assert!(conn
            .update_device_display_name("test", "other", "Laptop")
            .await
            .is_err());

        // Reactivating clears the flag
        conn.reactivate_user("test").await.unwrap();
        assert!(!conn.query_user("test").await.unwrap().deactivated);

        // Deactivating with erasure also clears the profile
        conn.delete_user("test", true).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);

        // A deactivated user still holds its localpart
        assert!(!conn.is_localpart_available("test").await.unwrap());

        assert!(conn.allow_cross_signing_reset("test").await.is_ok());

        // Operations on unknown users fail
        assert!(conn.delete_user("nobody", false).await.is_err());
        assert!(conn.reactivate_user("nobody").await.is_err());
        assert!(conn.set_displayname("nobody", "X").await.is_err());
        assert!(conn.unset_displayname("nobody").await.is_err());
        assert!(conn.allow_cross_signing_reset("nobody").await.is_err());
        assert!(conn
            .sync_devices("nobody", HashSet::new())
            .await
            .is_err());
    }
}