mas_config/sections/
database.rs1use std::{num::NonZeroU32, time::Duration};
8
9use camino::Utf8PathBuf;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use serde_with::serde_as;
13
14use super::ConfigurationSection;
15use crate::schema;
16
/// Default connection URI: the bare `postgresql://` scheme, with no host,
/// credentials or database name in it.
// `uri` is an `Option<String>` and this function is referenced as its default,
// so it has to return the wrapped type even though it never yields `None`.
#[allow(clippy::unnecessary_wraps)]
fn default_connection_string() -> Option<String> {
    let scheme_only = String::from("postgresql://");
    Some(scheme_only)
}
21
/// Default value of `max_connections`: 10.
fn default_max_connections() -> NonZeroU32 {
    const DEFAULT_LIMIT: u32 = 10;
    NonZeroU32::new(DEFAULT_LIMIT).expect("the default connection limit is non-zero")
}
25
/// Default value of `connect_timeout`: 30 seconds.
fn default_connect_timeout() -> Duration {
    Duration::new(30, 0)
}
29
/// Default value of `idle_timeout`: 10 minutes.
// `idle_timeout` is an `Option<Duration>` and this function is its serde
// `default`, so it must return the wrapped type even though it is always `Some`.
#[allow(clippy::unnecessary_wraps)]
fn default_idle_timeout() -> Option<Duration> {
    const TEN_MINUTES_IN_SECS: u64 = 600;
    Some(Duration::from_secs(TEN_MINUTES_IN_SECS))
}
34
/// Default value of `max_lifetime`: 30 minutes.
// `max_lifetime` is an `Option<Duration>` and this function is its serde
// `default`, so it must return the wrapped type even though it is always `Some`.
#[allow(clippy::unnecessary_wraps)]
fn default_max_lifetime() -> Option<Duration> {
    const THIRTY_MINUTES_IN_SECS: u64 = 1_800;
    Some(Duration::from_secs(THIRTY_MINUTES_IN_SECS))
}
39
40impl Default for DatabaseConfig {
41 fn default() -> Self {
42 Self {
43 uri: default_connection_string(),
44 host: None,
45 port: None,
46 socket: None,
47 username: None,
48 password: None,
49 database: None,
50 ssl_mode: None,
51 ssl_ca: None,
52 ssl_ca_file: None,
53 ssl_certificate: None,
54 ssl_certificate_file: None,
55 ssl_key: None,
56 ssl_key_file: None,
57 max_connections: default_max_connections(),
58 min_connections: Default::default(),
59 connect_timeout: default_connect_timeout(),
60 idle_timeout: default_idle_timeout(),
61 max_lifetime: default_max_lifetime(),
62 }
63 }
64}
65
66#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
69#[serde(rename_all = "kebab-case")]
70pub enum PgSslMode {
71 Disable,
73
74 Allow,
76
77 Prefer,
79
80 Require,
83
84 VerifyCa,
87
88 VerifyFull,
92}
93
94#[serde_as]
96#[derive(Debug, Serialize, Deserialize, JsonSchema)]
97pub struct DatabaseConfig {
98 #[serde(skip_serializing_if = "Option::is_none")]
103 #[schemars(url, default = "default_connection_string")]
104 pub uri: Option<String>,
105
106 #[serde(skip_serializing_if = "Option::is_none")]
110 #[schemars(with = "Option::<schema::Hostname>")]
111 pub host: Option<String>,
112
113 #[serde(skip_serializing_if = "Option::is_none")]
117 #[schemars(range(min = 1, max = 65535))]
118 pub port: Option<u16>,
119
120 #[serde(skip_serializing_if = "Option::is_none")]
124 #[schemars(with = "Option<String>")]
125 pub socket: Option<Utf8PathBuf>,
126
127 #[serde(skip_serializing_if = "Option::is_none")]
131 pub username: Option<String>,
132
133 #[serde(skip_serializing_if = "Option::is_none")]
137 pub password: Option<String>,
138
139 #[serde(skip_serializing_if = "Option::is_none")]
143 pub database: Option<String>,
144
145 #[serde(skip_serializing_if = "Option::is_none")]
147 pub ssl_mode: Option<PgSslMode>,
148
149 #[serde(skip_serializing_if = "Option::is_none")]
153 pub ssl_ca: Option<String>,
154
155 #[serde(skip_serializing_if = "Option::is_none")]
159 #[schemars(with = "Option<String>")]
160 pub ssl_ca_file: Option<Utf8PathBuf>,
161
162 #[serde(skip_serializing_if = "Option::is_none")]
167 pub ssl_certificate: Option<String>,
168
169 #[serde(skip_serializing_if = "Option::is_none")]
173 #[schemars(with = "Option<String>")]
174 pub ssl_certificate_file: Option<Utf8PathBuf>,
175
176 #[serde(skip_serializing_if = "Option::is_none")]
180 pub ssl_key: Option<String>,
181
182 #[serde(skip_serializing_if = "Option::is_none")]
186 #[schemars(with = "Option<String>")]
187 pub ssl_key_file: Option<Utf8PathBuf>,
188
189 #[serde(default = "default_max_connections")]
191 pub max_connections: NonZeroU32,
192
193 #[serde(default)]
195 pub min_connections: u32,
196
197 #[schemars(with = "u64")]
199 #[serde(default = "default_connect_timeout")]
200 #[serde_as(as = "serde_with::DurationSeconds<u64>")]
201 pub connect_timeout: Duration,
202
203 #[schemars(with = "Option<u64>")]
205 #[serde(
206 default = "default_idle_timeout",
207 skip_serializing_if = "Option::is_none"
208 )]
209 #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
210 pub idle_timeout: Option<Duration>,
211
212 #[schemars(with = "u64")]
214 #[serde(
215 default = "default_max_lifetime",
216 skip_serializing_if = "Option::is_none"
217 )]
218 #[serde_as(as = "Option<serde_with::DurationSeconds<u64>>")]
219 pub max_lifetime: Option<Duration>,
220}
221
222impl ConfigurationSection for DatabaseConfig {
223 const PATH: Option<&'static str> = Some("database");
224
225 fn validate(
226 &self,
227 figment: &figment::Figment,
228 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
229 let metadata = figment.find_metadata(Self::PATH.unwrap());
230 let annotate = |mut error: figment::Error| {
231 error.metadata = metadata.cloned();
232 error.profile = Some(figment::Profile::Default);
233 error.path = vec![Self::PATH.unwrap().to_owned()];
234 error
235 };
236
237 let has_split_options = self.host.is_some()
240 || self.port.is_some()
241 || self.socket.is_some()
242 || self.username.is_some()
243 || self.password.is_some()
244 || self.database.is_some();
245
246 if self.uri.is_some() && has_split_options {
247 return Err(annotate(figment::error::Error::from(
248 "uri must not be specified if host, port, socket, username, password, or database are specified".to_owned(),
249 )).into());
250 }
251
252 if self.ssl_ca.is_some() && self.ssl_ca_file.is_some() {
253 return Err(annotate(figment::error::Error::from(
254 "ssl_ca must not be specified if ssl_ca_file is specified".to_owned(),
255 ))
256 .into());
257 }
258
259 if self.ssl_certificate.is_some() && self.ssl_certificate_file.is_some() {
260 return Err(annotate(figment::error::Error::from(
261 "ssl_certificate must not be specified if ssl_certificate_file is specified"
262 .to_owned(),
263 ))
264 .into());
265 }
266
267 if self.ssl_key.is_some() && self.ssl_key_file.is_some() {
268 return Err(annotate(figment::error::Error::from(
269 "ssl_key must not be specified if ssl_key_file is specified".to_owned(),
270 ))
271 .into());
272 }
273
274 if (self.ssl_key.is_some() || self.ssl_key_file.is_some())
275 ^ (self.ssl_certificate.is_some() || self.ssl_certificate_file.is_some())
276 {
277 return Err(annotate(figment::error::Error::from(
278 "both a ssl_certificate and a ssl_key must be set at the same time or none of them"
279 .to_owned(),
280 ))
281 .into());
282 }
283
284 Ok(())
285 }
286}
#[cfg(test)]
mod tests {
    use figment::{
        Figment, Jail,
        providers::{Format, Yaml},
    };

    use super::*;

    #[test]
    fn load_config() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                ",
            )?;

            let config = Figment::new()
                .merge(Yaml::file("config.yaml"))
                .extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(
                config.uri.as_deref(),
                Some("postgresql://user:password@host/database")
            );

            Ok(())
        });
    }

    /// Split connection options without a `uri` load and validate cleanly.
    #[test]
    fn split_options_are_accepted() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      host: localhost
                      port: 5432
                      username: mas
                      database: mas
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert_eq!(config.uri, None);
            assert_eq!(config.host.as_deref(), Some("localhost"));
            assert_eq!(config.port, Some(5432));
            assert!(config.validate(&figment).is_ok());

            Ok(())
        });
    }

    /// `uri` and split connection options are mutually exclusive.
    #[test]
    fn uri_conflicts_with_split_options() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      host: localhost
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }

    /// A client certificate without a client key is rejected.
    #[test]
    fn certificate_requires_key() {
        Jail::expect_with(|jail| {
            jail.create_file(
                "config.yaml",
                r"
                    database:
                      uri: postgresql://user:password@host/database
                      ssl_certificate_file: /path/to/client.crt
                ",
            )?;

            let figment = Figment::new().merge(Yaml::file("config.yaml"));
            let config = figment.extract_inner::<DatabaseConfig>("database")?;

            assert!(config.validate(&figment).is_err());

            Ok(())
        });
    }
}