1use std::collections::{HashMap, HashSet};
8
9use anyhow::Context;
10use async_trait::async_trait;
11use tokio::sync::RwLock;
12
13use crate::{MatrixUser, ProvisionRequest};
14
/// Per-user state tracked by the mock homeserver.
struct MockUser {
    /// Subject the user was provisioned with; later provisioning requests for the
    /// same localpart must carry the same value.
    sub: String,
    /// Whether the user is currently deactivated.
    deactivated: bool,

    /// Current display name, if any.
    displayname: Option<String>,
    /// Current avatar URL, if any.
    avatar_url: Option<String>,

    /// IDs of the devices currently registered for the user.
    devices: HashSet<String>,
    /// Email addresses recorded from provisioning requests.
    emails: Option<Vec<String>>,
    /// Set once a cross-signing reset has been allowed for the user.
    cross_signing_reset_allowed: bool,
}
24
25pub struct HomeserverConnection {
28 homeserver: String,
29 users: RwLock<HashMap<String, MockUser>>,
30 reserved_localparts: RwLock<HashSet<&'static str>>,
31}
32
33impl HomeserverConnection {
34 pub fn new<H>(homeserver: H) -> Self
36 where
37 H: Into<String>,
38 {
39 Self {
40 homeserver: homeserver.into(),
41 users: RwLock::new(HashMap::new()),
42 reserved_localparts: RwLock::new(HashSet::new()),
43 }
44 }
45
46 pub async fn reserve_localpart(&self, localpart: &'static str) {
47 self.reserved_localparts.write().await.insert(localpart);
48 }
49}
50
51#[async_trait]
52impl crate::HomeserverConnection for HomeserverConnection {
53 fn homeserver(&self) -> &str {
54 &self.homeserver
55 }
56
57 async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
58 let mxid = self.mxid(localpart);
59 let users = self.users.read().await;
60 let user = users.get(&mxid).context("User not found")?;
61 Ok(MatrixUser {
62 displayname: user.displayname.clone(),
63 avatar_url: user.avatar_url.clone(),
64 deactivated: user.deactivated,
65 })
66 }
67
68 async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
69 let mut users = self.users.write().await;
70 let mxid = self.mxid(request.localpart());
71 let inserted = !users.contains_key(&mxid);
72 let user = users.entry(mxid).or_insert(MockUser {
73 sub: request.sub().to_owned(),
74 avatar_url: None,
75 displayname: None,
76 devices: HashSet::new(),
77 emails: None,
78 cross_signing_reset_allowed: false,
79 deactivated: false,
80 });
81
82 anyhow::ensure!(
83 user.sub == request.sub(),
84 "User already provisioned with different sub"
85 );
86
87 request.on_emails(|emails| {
88 user.emails = emails.map(ToOwned::to_owned);
89 });
90
91 request.on_displayname(|displayname| {
92 user.displayname = displayname.map(ToOwned::to_owned);
93 });
94
95 request.on_avatar_url(|avatar_url| {
96 user.avatar_url = avatar_url.map(ToOwned::to_owned);
97 });
98
99 Ok(inserted)
100 }
101
102 async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
103 if self.reserved_localparts.read().await.contains(localpart) {
104 return Ok(false);
105 }
106
107 let mxid = self.mxid(localpart);
108 let users = self.users.read().await;
109 Ok(!users.contains_key(&mxid))
110 }
111
112 async fn upsert_device(
113 &self,
114 localpart: &str,
115 device_id: &str,
116 _initial_display_name: Option<&str>,
117 ) -> Result<(), anyhow::Error> {
118 let mxid = self.mxid(localpart);
119 let mut users = self.users.write().await;
120 let user = users.get_mut(&mxid).context("User not found")?;
121 user.devices.insert(device_id.to_owned());
122 Ok(())
123 }
124
125 async fn update_device_display_name(
126 &self,
127 localpart: &str,
128 device_id: &str,
129 _display_name: &str,
130 ) -> Result<(), anyhow::Error> {
131 let mxid = self.mxid(localpart);
132 let mut users = self.users.write().await;
133 let user = users.get_mut(&mxid).context("User not found")?;
134 user.devices.get(device_id).context("Device not found")?;
135 Ok(())
136 }
137
138 async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
139 let mxid = self.mxid(localpart);
140 let mut users = self.users.write().await;
141 let user = users.get_mut(&mxid).context("User not found")?;
142 user.devices.remove(device_id);
143 Ok(())
144 }
145
146 async fn sync_devices(
147 &self,
148 localpart: &str,
149 devices: HashSet<String>,
150 ) -> Result<(), anyhow::Error> {
151 let mxid = self.mxid(localpart);
152 let mut users = self.users.write().await;
153 let user = users.get_mut(&mxid).context("User not found")?;
154 user.devices = devices;
155 Ok(())
156 }
157
158 async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
159 let mxid = self.mxid(localpart);
160 let mut users = self.users.write().await;
161 let user = users.get_mut(&mxid).context("User not found")?;
162 user.devices.clear();
163 user.emails = None;
164 user.deactivated = true;
165 if erase {
166 user.avatar_url = None;
167 user.displayname = None;
168 }
169
170 Ok(())
171 }
172
173 async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
174 let mxid = self.mxid(localpart);
175 let mut users = self.users.write().await;
176 let user = users.get_mut(&mxid).context("User not found")?;
177 user.deactivated = false;
178
179 Ok(())
180 }
181
182 async fn set_displayname(
183 &self,
184 localpart: &str,
185 displayname: &str,
186 ) -> Result<(), anyhow::Error> {
187 let mxid = self.mxid(localpart);
188 let mut users = self.users.write().await;
189 let user = users.get_mut(&mxid).context("User not found")?;
190 user.displayname = Some(displayname.to_owned());
191 Ok(())
192 }
193
194 async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
195 let mxid = self.mxid(localpart);
196 let mut users = self.users.write().await;
197 let user = users.get_mut(&mxid).context("User not found")?;
198 user.displayname = None;
199 Ok(())
200 }
201
202 async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
203 let mxid = self.mxid(localpart);
204 let mut users = self.users.write().await;
205 let user = users.get_mut(&mxid).context("User not found")?;
206 user.cross_signing_reset_allowed = true;
207 Ok(())
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HomeserverConnection as _;

    #[tokio::test]
    async fn test_mock_connection() {
        let conn = HomeserverConnection::new("example.org");

        let mxid = "@test:example.org";
        let device = "test";
        assert_eq!(conn.homeserver(), "example.org");
        assert_eq!(conn.mxid("test"), mxid);

        assert!(conn.query_user("test").await.is_err());
        assert!(conn.upsert_device("test", device, None).await.is_err());
        assert!(conn.delete_device("test", device).await.is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into())
            .set_emails(vec!["test@example.org".to_owned()]);

        let inserted = conn.provision_user(&request).await.unwrap();
        assert!(inserted);

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));

        assert!(conn.set_displayname("test", "John").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, Some("John".into()));

        assert!(conn.unset_displayname("test").await.is_ok());

        let user = conn.query_user("test").await.unwrap();
        assert_eq!(user.displayname, None);

        // Deleting a device that was never registered still succeeds.
        assert!(conn.delete_device("test", device).await.is_ok());

        // Upserting the same device twice is idempotent.
        assert!(conn.upsert_device("test", device, None).await.is_ok());
        assert!(conn.upsert_device("test", device, None).await.is_ok());

        assert!(conn.delete_device("test", device).await.is_ok());

        assert!(!conn.is_localpart_available("test").await.unwrap());
        assert!(conn.is_localpart_available("alice").await.unwrap());

        conn.reserve_localpart("alice").await;

        assert!(!conn.is_localpart_available("alice").await.unwrap());
    }

    /// Exercises behaviour the basic test does not reach: re-provisioning,
    /// sub mismatches, device syncing and renaming, deactivation with and
    /// without erasure, reactivation and cross-signing reset.
    #[tokio::test]
    async fn test_mock_user_lifecycle() {
        let conn = HomeserverConnection::new("example.org");

        // Every per-user operation fails for an unknown user.
        assert!(conn.delete_user("test", false).await.is_err());
        assert!(conn.reactivate_user("test").await.is_err());
        assert!(conn.sync_devices("test", HashSet::new()).await.is_err());
        assert!(conn.allow_cross_signing_reset("test").await.is_err());
        assert!(conn
            .update_device_display_name("test", "dev1", "Phone")
            .await
            .is_err());

        let request = ProvisionRequest::new("test", "test")
            .set_displayname("Test User".into())
            .set_avatar_url("mxc://example.org/1234567890".into());

        // First provisioning creates the user; repeating it only updates.
        assert!(conn.provision_user(&request).await.unwrap());
        assert!(!conn.provision_user(&request).await.unwrap());

        // The same localpart with a different sub is rejected.
        let conflicting = ProvisionRequest::new("test", "someone-else");
        assert!(conn.provision_user(&conflicting).await.is_err());

        // Renaming requires the device to exist.
        assert!(conn
            .update_device_display_name("test", "dev1", "Phone")
            .await
            .is_err());

        conn.sync_devices("test", HashSet::from(["dev1".to_owned(), "dev2".to_owned()]))
            .await
            .unwrap();
        assert!(conn
            .update_device_display_name("test", "dev1", "Phone")
            .await
            .is_ok());
        assert!(conn
            .update_device_display_name("test", "dev2", "Laptop")
            .await
            .is_ok());

        // Syncing replaces the whole device set.
        conn.sync_devices("test", HashSet::from(["dev2".to_owned()]))
            .await
            .unwrap();
        assert!(conn
            .update_device_display_name("test", "dev1", "Phone")
            .await
            .is_err());

        // Deactivation without erasure keeps the profile but drops devices.
        conn.delete_user("test", false).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, Some("Test User".into()));
        assert_eq!(user.avatar_url, Some("mxc://example.org/1234567890".into()));
        assert!(conn
            .update_device_display_name("test", "dev2", "Laptop")
            .await
            .is_err());
        // A deactivated user still holds its localpart.
        assert!(!conn.is_localpart_available("test").await.unwrap());

        conn.reactivate_user("test").await.unwrap();
        assert!(!conn.query_user("test").await.unwrap().deactivated);

        // Deactivation with erasure also clears the profile.
        conn.delete_user("test", true).await.unwrap();
        let user = conn.query_user("test").await.unwrap();
        assert!(user.deactivated);
        assert_eq!(user.displayname, None);
        assert_eq!(user.avatar_url, None);

        assert!(conn.allow_cross_signing_reset("test").await.is_ok());
    }
}