sqlx_postgres/testing/
mod.rs

1use std::fmt::Write;
2use std::ops::Deref;
3use std::str::FromStr;
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::time::{Duration, SystemTime};
6
7use futures_core::future::BoxFuture;
8
9use once_cell::sync::OnceCell;
10
11use crate::connection::Connection;
12
13use crate::error::Error;
14use crate::executor::Executor;
15use crate::pool::{Pool, PoolOptions};
16use crate::query::query;
17use crate::query_scalar::query_scalar;
18use crate::{PgConnectOptions, PgConnection, Postgres};
19
20pub(crate) use sqlx_core::testing::*;
21
22// Using a blocking `OnceCell` here because the critical sections are short.
23static MASTER_POOL: OnceCell<Pool<Postgres>> = OnceCell::new();
24// Automatically delete any databases created before the start of the test binary.
25static DO_CLEANUP: AtomicBool = AtomicBool::new(true);
26
27impl TestSupport for Postgres {
28    fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
29        Box::pin(async move {
30            let res = test_context(args).await;
31            res
32        })
33    }
34
35    fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
36        Box::pin(async move {
37            let mut conn = MASTER_POOL
38                .get()
39                .expect("cleanup_test() invoked outside `#[sqlx::test]")
40                .acquire()
41                .await?;
42
43            conn.execute(&format!("drop database if exists {db_name:?};")[..])
44                .await?;
45
46            query("delete from _sqlx_test.databases where db_name = $1")
47                .bind(&db_name)
48                .execute(&mut *conn)
49                .await?;
50
51            Ok(())
52        })
53    }
54
55    fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
56        Box::pin(async move {
57            let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
58
59            let mut conn = PgConnection::connect(&url).await?;
60
61            let now = SystemTime::now()
62                .duration_since(SystemTime::UNIX_EPOCH)
63                .unwrap();
64
65            let num_deleted = do_cleanup(&mut conn, now).await?;
66            let _ = conn.close().await;
67            Ok(Some(num_deleted))
68        })
69    }
70
71    fn snapshot(
72        _conn: &mut Self::Connection,
73    ) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
74        // TODO: I want to get the testing feature out the door so this will have to wait,
75        // but I'm keeping the code around for now because I plan to come back to it.
76        todo!()
77    }
78}
79
80async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
81    let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set");
82
83    let master_opts = PgConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
84
85    let pool = PoolOptions::new()
86        // Postgres' normal connection limit is 100 plus 3 superuser connections
87        // We don't want to use the whole cap and there may be fuzziness here due to
88        // concurrently running tests anyway.
89        .max_connections(20)
90        // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
91        .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
92        .connect_lazy_with(master_opts);
93
94    let master_pool = match MASTER_POOL.try_insert(pool) {
95        Ok(inserted) => inserted,
96        Err((existing, pool)) => {
97            // Sanity checks.
98            assert_eq!(
99                existing.connect_options().host,
100                pool.connect_options().host,
101                "DATABASE_URL changed at runtime, host differs"
102            );
103
104            assert_eq!(
105                existing.connect_options().database,
106                pool.connect_options().database,
107                "DATABASE_URL changed at runtime, database differs"
108            );
109
110            existing
111        }
112    };
113
114    let mut conn = master_pool.acquire().await?;
115
116    // language=PostgreSQL
117    conn.execute(
118        // Explicit lock avoids this latent bug: https://stackoverflow.com/a/29908840
119        // I couldn't find a bug on the mailing list for `CREATE SCHEMA` specifically,
120        // but a clearly related bug with `CREATE TABLE` has been known since 2007:
121        // https://www.postgresql.org/message-id/200710222037.l9MKbCJZ098744%40wwwmaster.postgresql.org
122        r#"
123        lock table pg_catalog.pg_namespace in share row exclusive mode;
124
125        create schema if not exists _sqlx_test;
126
127        create table if not exists _sqlx_test.databases (
128            db_name text primary key,
129            test_path text not null,
130            created_at timestamptz not null default now()
131        );
132
133        create index if not exists databases_created_at 
134            on _sqlx_test.databases(created_at);
135
136        create sequence if not exists _sqlx_test.database_ids;
137    "#,
138    )
139    .await?;
140
141    // Record the current time _before_ we acquire the `DO_CLEANUP` permit. This
142    // prevents the first test thread from accidentally deleting new test dbs
143    // created by other test threads if we're a bit slow.
144    let now = SystemTime::now()
145        .duration_since(SystemTime::UNIX_EPOCH)
146        .unwrap();
147
148    // Only run cleanup if the test binary just started.
149    if DO_CLEANUP.swap(false, Ordering::SeqCst) {
150        do_cleanup(&mut conn, now).await?;
151    }
152
153    let new_db_name: String = query_scalar(
154        r#"
155            insert into _sqlx_test.databases(db_name, test_path)
156            select '_sqlx_test_' || nextval('_sqlx_test.database_ids'), $1
157            returning db_name
158        "#,
159    )
160    .bind(&args.test_path)
161    .fetch_one(&mut *conn)
162    .await?;
163
164    conn.execute(&format!("create database {new_db_name:?}")[..])
165        .await?;
166
167    Ok(TestContext {
168        pool_opts: PoolOptions::new()
169            // Don't allow a single test to take all the connections.
170            // Most tests shouldn't require more than 5 connections concurrently,
171            // or else they're likely doing too much in one test.
172            .max_connections(5)
173            // Close connections ASAP if left in the idle queue.
174            .idle_timeout(Some(Duration::from_secs(1)))
175            .parent(master_pool.clone()),
176        connect_opts: master_pool
177            .connect_options()
178            .deref()
179            .clone()
180            .database(&new_db_name),
181        db_name: new_db_name,
182    })
183}
184
185async fn do_cleanup(conn: &mut PgConnection, created_before: Duration) -> Result<usize, Error> {
186    // since SystemTime is not monotonic we added a little margin here to avoid race conditions with other threads
187    let created_before = i64::try_from(created_before.as_secs()).unwrap() - 2;
188
189    let delete_db_names: Vec<String> = query_scalar(
190        "select db_name from _sqlx_test.databases \
191            where created_at < (to_timestamp($1) at time zone 'UTC')",
192    )
193    .bind(&created_before)
194    .fetch_all(&mut *conn)
195    .await?;
196
197    if delete_db_names.is_empty() {
198        return Ok(0);
199    }
200
201    let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
202    let delete_db_names = delete_db_names.into_iter();
203
204    let mut command = String::new();
205
206    for db_name in delete_db_names {
207        command.clear();
208        writeln!(command, "drop database if exists {db_name:?};").ok();
209        match conn.execute(&*command).await {
210            Ok(_deleted) => {
211                deleted_db_names.push(db_name);
212            }
213            // Assume a database error just means the DB is still in use.
214            Err(Error::Database(dbe)) => {
215                eprintln!("could not clean test database {db_name:?}: {dbe}")
216            }
217            // Bubble up other errors
218            Err(e) => return Err(e),
219        }
220    }
221
222    query("delete from _sqlx_test.databases where db_name = any($1::text[])")
223        .bind(&deleted_db_names)
224        .execute(&mut *conn)
225        .await?;
226
227    Ok(deleted_db_names.len())
228}