1use async_trait::async_trait;
11use mas_data_model::User;
12use mas_storage::{
13 Clock,
14 user::{UserFilter, UserRepository},
15};
16use rand::RngCore;
17use sea_query::{Expr, PostgresQueryBuilder, Query};
18use sea_query_binder::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::Users,
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod registration_token;
36mod session;
37mod terms;
38
39#[cfg(test)]
40mod tests;
41
42pub use self::{
43 email::PgUserEmailRepository, password::PgUserPasswordRepository,
44 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
45 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
46 terms::PgUserTermsRepository,
47};
48
/// PostgreSQL-backed implementation of [`UserRepository`].
///
/// The repository holds a mutable borrow of a [`PgConnection`] for the
/// lifetime `'c`; every query it issues is executed on that connection.
pub struct PgUserRepository<'c> {
    // Connection on which all queries of this repository are executed.
    conn: &'c mut PgConnection,
}
53
54impl<'c> PgUserRepository<'c> {
55 pub fn new(conn: &'c mut PgConnection) -> Self {
57 Self { conn }
58 }
59}
60
// Private module holding the row type shared by the user queries, so that the
// `missing_docs` allowance below only applies to what is declared in here.
mod priv_ {
    // NOTE(review): presumably required because `#[enum_def]` generates the
    // `UserLookupIden` enum without doc comments — confirm against sea-query.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// A row of the `users` table, as selected by the user queries.
    ///
    /// `#[enum_def]` produces the companion `UserLookupIden` enum, whose
    /// variants are used as column aliases when the `list` query is built
    /// dynamically.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // Primary key, stored as a UUID and converted to a ULID in the model.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // `None` while the user is not locked.
        pub(super) locked_at: Option<DateTime<Utc>>,
        // `None` while the user is not deactivated.
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
    }
}
81
82use priv_::{UserLookup, UserLookupIden};
83
84impl From<UserLookup> for User {
85 fn from(value: UserLookup) -> Self {
86 let id = value.user_id.into();
87 Self {
88 id,
89 username: value.username,
90 sub: id.to_string(),
91 created_at: value.created_at,
92 locked_at: value.locked_at,
93 deactivated_at: value.deactivated_at,
94 can_request_admin: value.can_request_admin,
95 }
96 }
97}
98
99impl Filter for UserFilter<'_> {
100 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
101 sea_query::Condition::all()
102 .add_option(self.state().map(|state| {
103 match state {
104 mas_storage::user::UserState::Deactivated => {
105 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
106 }
107 mas_storage::user::UserState::Locked => {
108 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
109 }
110 mas_storage::user::UserState::Active => {
111 Expr::col((Users::Table, Users::LockedAt))
112 .is_null()
113 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
114 }
115 }
116 }))
117 .add_option(self.can_request_admin().map(|can_request_admin| {
118 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
119 }))
120 }
121}
122
123#[async_trait]
124impl UserRepository for PgUserRepository<'_> {
125 type Error = DatabaseError;
126
127 #[tracing::instrument(
128 name = "db.user.lookup",
129 skip_all,
130 fields(
131 db.query.text,
132 user.id = %id,
133 ),
134 err,
135 )]
136 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
137 let res = sqlx::query_as!(
138 UserLookup,
139 r#"
140 SELECT user_id
141 , username
142 , created_at
143 , locked_at
144 , deactivated_at
145 , can_request_admin
146 FROM users
147 WHERE user_id = $1
148 "#,
149 Uuid::from(id),
150 )
151 .traced()
152 .fetch_optional(&mut *self.conn)
153 .await?;
154
155 let Some(res) = res else { return Ok(None) };
156
157 Ok(Some(res.into()))
158 }
159
160 #[tracing::instrument(
161 name = "db.user.find_by_username",
162 skip_all,
163 fields(
164 db.query.text,
165 user.username = username,
166 ),
167 err,
168 )]
169 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
170 let res = sqlx::query_as!(
174 UserLookup,
175 r#"
176 SELECT user_id
177 , username
178 , created_at
179 , locked_at
180 , deactivated_at
181 , can_request_admin
182 FROM users
183 WHERE LOWER(username) = LOWER($1)
184 "#,
185 username,
186 )
187 .traced()
188 .fetch_all(&mut *self.conn)
189 .await?;
190
191 match &res[..] {
192 [user] => Ok(Some(user.clone().into())),
194 [] => Ok(None),
196 list => {
197 if let Some(user) = list.iter().find(|user| user.username == username) {
200 Ok(Some(user.clone().into()))
201 } else {
202 Ok(None)
204 }
205 }
206 }
207 }
208
209 #[tracing::instrument(
210 name = "db.user.add",
211 skip_all,
212 fields(
213 db.query.text,
214 user.username = username,
215 user.id,
216 ),
217 err,
218 )]
219 async fn add(
220 &mut self,
221 rng: &mut (dyn RngCore + Send),
222 clock: &dyn Clock,
223 username: String,
224 ) -> Result<User, Self::Error> {
225 let created_at = clock.now();
226 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
227 tracing::Span::current().record("user.id", tracing::field::display(id));
228
229 let res = sqlx::query!(
230 r#"
231 INSERT INTO users (user_id, username, created_at)
232 VALUES ($1, $2, $3)
233 ON CONFLICT (username) DO NOTHING
234 "#,
235 Uuid::from(id),
236 username,
237 created_at,
238 )
239 .traced()
240 .execute(&mut *self.conn)
241 .await?;
242
243 DatabaseError::ensure_affected_rows(&res, 1)?;
246
247 Ok(User {
248 id,
249 username,
250 sub: id.to_string(),
251 created_at,
252 locked_at: None,
253 deactivated_at: None,
254 can_request_admin: false,
255 })
256 }
257
258 #[tracing::instrument(
259 name = "db.user.exists",
260 skip_all,
261 fields(
262 db.query.text,
263 user.username = username,
264 ),
265 err,
266 )]
267 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
268 let exists = sqlx::query_scalar!(
269 r#"
270 SELECT EXISTS(
271 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
272 ) AS "exists!"
273 "#,
274 username
275 )
276 .traced()
277 .fetch_one(&mut *self.conn)
278 .await?;
279
280 Ok(exists)
281 }
282
283 #[tracing::instrument(
284 name = "db.user.lock",
285 skip_all,
286 fields(
287 db.query.text,
288 %user.id,
289 ),
290 err,
291 )]
292 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
293 if user.locked_at.is_some() {
294 return Ok(user);
295 }
296
297 let locked_at = clock.now();
298 let res = sqlx::query!(
299 r#"
300 UPDATE users
301 SET locked_at = $1
302 WHERE user_id = $2
303 "#,
304 locked_at,
305 Uuid::from(user.id),
306 )
307 .traced()
308 .execute(&mut *self.conn)
309 .await?;
310
311 DatabaseError::ensure_affected_rows(&res, 1)?;
312
313 user.locked_at = Some(locked_at);
314
315 Ok(user)
316 }
317
318 #[tracing::instrument(
319 name = "db.user.unlock",
320 skip_all,
321 fields(
322 db.query.text,
323 %user.id,
324 ),
325 err,
326 )]
327 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
328 if user.locked_at.is_none() {
329 return Ok(user);
330 }
331
332 let res = sqlx::query!(
333 r#"
334 UPDATE users
335 SET locked_at = NULL
336 WHERE user_id = $1
337 "#,
338 Uuid::from(user.id),
339 )
340 .traced()
341 .execute(&mut *self.conn)
342 .await?;
343
344 DatabaseError::ensure_affected_rows(&res, 1)?;
345
346 user.locked_at = None;
347
348 Ok(user)
349 }
350
351 #[tracing::instrument(
352 name = "db.user.deactivate",
353 skip_all,
354 fields(
355 db.query.text,
356 %user.id,
357 ),
358 err,
359 )]
360 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
361 if user.deactivated_at.is_some() {
362 return Ok(user);
363 }
364
365 let deactivated_at = clock.now();
366 let res = sqlx::query!(
367 r#"
368 UPDATE users
369 SET deactivated_at = $2
370 WHERE user_id = $1
371 AND deactivated_at IS NULL
372 "#,
373 Uuid::from(user.id),
374 deactivated_at,
375 )
376 .traced()
377 .execute(&mut *self.conn)
378 .await?;
379
380 DatabaseError::ensure_affected_rows(&res, 1)?;
381
382 user.deactivated_at = Some(deactivated_at);
383
384 Ok(user)
385 }
386
387 #[tracing::instrument(
388 name = "db.user.reactivate",
389 skip_all,
390 fields(
391 db.query.text,
392 %user.id,
393 ),
394 err,
395 )]
396 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
397 if user.deactivated_at.is_none() {
398 return Ok(user);
399 }
400
401 let res = sqlx::query!(
402 r#"
403 UPDATE users
404 SET deactivated_at = NULL
405 WHERE user_id = $1
406 "#,
407 Uuid::from(user.id),
408 )
409 .traced()
410 .execute(&mut *self.conn)
411 .await?;
412
413 DatabaseError::ensure_affected_rows(&res, 1)?;
414
415 user.deactivated_at = None;
416
417 Ok(user)
418 }
419
420 #[tracing::instrument(
421 name = "db.user.set_can_request_admin",
422 skip_all,
423 fields(
424 db.query.text,
425 %user.id,
426 user.can_request_admin = can_request_admin,
427 ),
428 err,
429 )]
430 async fn set_can_request_admin(
431 &mut self,
432 mut user: User,
433 can_request_admin: bool,
434 ) -> Result<User, Self::Error> {
435 let res = sqlx::query!(
436 r#"
437 UPDATE users
438 SET can_request_admin = $2
439 WHERE user_id = $1
440 "#,
441 Uuid::from(user.id),
442 can_request_admin,
443 )
444 .traced()
445 .execute(&mut *self.conn)
446 .await?;
447
448 DatabaseError::ensure_affected_rows(&res, 1)?;
449
450 user.can_request_admin = can_request_admin;
451
452 Ok(user)
453 }
454
455 #[tracing::instrument(
456 name = "db.user.list",
457 skip_all,
458 fields(
459 db.query.text,
460 ),
461 err,
462 )]
463 async fn list(
464 &mut self,
465 filter: UserFilter<'_>,
466 pagination: mas_storage::Pagination,
467 ) -> Result<mas_storage::Page<User>, Self::Error> {
468 let (sql, arguments) = Query::select()
469 .expr_as(
470 Expr::col((Users::Table, Users::UserId)),
471 UserLookupIden::UserId,
472 )
473 .expr_as(
474 Expr::col((Users::Table, Users::Username)),
475 UserLookupIden::Username,
476 )
477 .expr_as(
478 Expr::col((Users::Table, Users::CreatedAt)),
479 UserLookupIden::CreatedAt,
480 )
481 .expr_as(
482 Expr::col((Users::Table, Users::LockedAt)),
483 UserLookupIden::LockedAt,
484 )
485 .expr_as(
486 Expr::col((Users::Table, Users::DeactivatedAt)),
487 UserLookupIden::DeactivatedAt,
488 )
489 .expr_as(
490 Expr::col((Users::Table, Users::CanRequestAdmin)),
491 UserLookupIden::CanRequestAdmin,
492 )
493 .from(Users::Table)
494 .apply_filter(filter)
495 .generate_pagination((Users::Table, Users::UserId), pagination)
496 .build_sqlx(PostgresQueryBuilder);
497
498 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
499 .traced()
500 .fetch_all(&mut *self.conn)
501 .await?;
502
503 let page = pagination.process(edges).map(User::from);
504
505 Ok(page)
506 }
507
508 #[tracing::instrument(
509 name = "db.user.count",
510 skip_all,
511 fields(
512 db.query.text,
513 ),
514 err,
515 )]
516 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
517 let (sql, arguments) = Query::select()
518 .expr(Expr::col((Users::Table, Users::UserId)).count())
519 .from(Users::Table)
520 .apply_filter(filter)
521 .build_sqlx(PostgresQueryBuilder);
522
523 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
524 .traced()
525 .fetch_one(&mut *self.conn)
526 .await?;
527
528 count
529 .try_into()
530 .map_err(DatabaseError::to_invalid_operation)
531 }
532
533 #[tracing::instrument(
534 name = "db.user.acquire_lock_for_sync",
535 skip_all,
536 fields(
537 db.query.text,
538 user.id = %user.id,
539 ),
540 err,
541 )]
542 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
543 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
551
552 sqlx::query!(
555 r#"
556 SELECT pg_advisory_xact_lock($1)
557 "#,
558 lock_id,
559 )
560 .traced()
561 .execute(&mut *self.conn)
562 .await?;
563
564 Ok(())
565 }
566}