sqlx_core/migrate/
migrator.rs

1use crate::acquire::Acquire;
2use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource};
3use std::borrow::Cow;
4use std::collections::{HashMap, HashSet};
5use std::ops::Deref;
6use std::slice;
7
/// A resolved set of migrations, ready to be run.
///
/// Can be constructed statically using `migrate!()` or at runtime using [`Migrator::new()`].
#[derive(Debug)]
// Forbids `migrate!()` from constructing this:
// #[non_exhaustive]
pub struct Migrator {
    // NOTE: these fields are semver-exempt and may be changed or removed in any future version.
    // These have to be public for `migrate!()` to be able to initialize them in an implicitly
    // const-promotable context. A `const fn` constructor isn't implicitly const-promotable.

    // The known migrations. `Cow` lets this be either a borrowed `'static` slice or an owned
    // list (e.g. one produced at runtime by resolving a `MigrationSource`).
    #[doc(hidden)]
    pub migrations: Cow<'static, [Migration]>,
    // When `true`, applied migrations whose version is absent from `migrations` are not
    // treated as an error during validation.
    #[doc(hidden)]
    pub ignore_missing: bool,
    // When `true`, the connection's `lock()` is called before migrating and `unlock()`
    // afterwards.
    #[doc(hidden)]
    pub locking: bool,
}
25
26fn validate_applied_migrations(
27    applied_migrations: &[AppliedMigration],
28    migrator: &Migrator,
29) -> Result<(), MigrateError> {
30    if migrator.ignore_missing {
31        return Ok(());
32    }
33
34    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
35
36    for applied_migration in applied_migrations {
37        if !migrations.contains(&applied_migration.version) {
38            return Err(MigrateError::VersionMissing(applied_migration.version));
39        }
40    }
41
42    Ok(())
43}
44
45impl Migrator {
46    #[doc(hidden)]
47    pub const DEFAULT: Migrator = Migrator {
48        migrations: Cow::Borrowed(&[]),
49        ignore_missing: false,
50        locking: true,
51    };
52
53    /// Creates a new instance with the given source.
54    ///
55    /// # Examples
56    ///
57    /// ```rust,no_run
58    /// # use sqlx_core::migrate::MigrateError;
59    /// # fn main() -> Result<(), MigrateError> {
60    /// # sqlx::__rt::test_block_on(async move {
61    /// # use sqlx_core::migrate::Migrator;
62    /// use std::path::Path;
63    ///
64    /// // Read migrations from a local folder: ./migrations
65    /// let m = Migrator::new(Path::new("./migrations")).await?;
66    /// # Ok(())
67    /// # })
68    /// # }
69    /// ```
70    /// See [MigrationSource] for details on structure of the `./migrations` directory.
71    pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
72    where
73        S: MigrationSource<'s>,
74    {
75        Ok(Self {
76            migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
77            ..Self::DEFAULT
78        })
79    }
80
81    /// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
82    pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
83        self.ignore_missing = ignore_missing;
84        self
85    }
86
87    /// Specify whether or not to lock the database during migration. Defaults to `true`.
88    ///
89    /// ### Warning
90    /// Disabling locking can lead to errors or data loss if multiple clients attempt to apply migrations simultaneously
91    /// without some sort of mutual exclusion.
92    ///
93    /// This should only be used if the database does not support locking, e.g. CockroachDB which talks the Postgres
94    /// protocol but does not support advisory locks used by SQLx's migrations support for Postgres.
95    pub fn set_locking(&mut self, locking: bool) -> &Self {
96        self.locking = locking;
97        self
98    }
99
100    /// Get an iterator over all known migrations.
101    pub fn iter(&self) -> slice::Iter<'_, Migration> {
102        self.migrations.iter()
103    }
104
105    /// Check if a migration version exists.
106    pub fn version_exists(&self, version: i64) -> bool {
107        self.iter().any(|m| m.version == version)
108    }
109
110    /// Run any pending migrations against the database; and, validate previously applied migrations
111    /// against the current migration source to detect accidental changes in previously-applied migrations.
112    ///
113    /// # Examples
114    ///
115    /// ```rust,no_run
116    /// # use sqlx::migrate::MigrateError;
117    /// # fn main() -> Result<(), MigrateError> {
118    /// #     sqlx::__rt::test_block_on(async move {
119    /// use sqlx::migrate::Migrator;
120    /// use sqlx::sqlite::SqlitePoolOptions;
121    ///
122    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
123    /// let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
124    /// m.run(&pool).await
125    /// #     })
126    /// # }
127    /// ```
128    pub async fn run<'a, A>(&self, migrator: A) -> Result<(), MigrateError>
129    where
130        A: Acquire<'a>,
131        <A::Connection as Deref>::Target: Migrate,
132    {
133        let mut conn = migrator.acquire().await?;
134        self.run_direct(&mut *conn).await
135    }
136
137    // Getting around the annoying "implementation of `Acquire` is not general enough" error
138    #[doc(hidden)]
139    pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
140    where
141        C: Migrate,
142    {
143        // lock the database for exclusive access by the migrator
144        if self.locking {
145            conn.lock().await?;
146        }
147
148        // creates [_migrations] table only if needed
149        // eventually this will likely migrate previous versions of the table
150        conn.ensure_migrations_table().await?;
151
152        let version = conn.dirty_version().await?;
153        if let Some(version) = version {
154            return Err(MigrateError::Dirty(version));
155        }
156
157        let applied_migrations = conn.list_applied_migrations().await?;
158        validate_applied_migrations(&applied_migrations, self)?;
159
160        let applied_migrations: HashMap<_, _> = applied_migrations
161            .into_iter()
162            .map(|m| (m.version, m))
163            .collect();
164
165        for migration in self.iter() {
166            if migration.migration_type.is_down_migration() {
167                continue;
168            }
169
170            match applied_migrations.get(&migration.version) {
171                Some(applied_migration) => {
172                    if migration.checksum != applied_migration.checksum {
173                        return Err(MigrateError::VersionMismatch(migration.version));
174                    }
175                }
176                None => {
177                    conn.apply(migration).await?;
178                }
179            }
180        }
181
182        // unlock the migrator to allow other migrators to run
183        // but do nothing as we already migrated
184        if self.locking {
185            conn.unlock().await?;
186        }
187
188        Ok(())
189    }
190
191    /// Run down migrations against the database until a specific version.
192    ///
193    /// # Examples
194    ///
195    /// ```rust,no_run
196    /// # use sqlx::migrate::MigrateError;
197    /// # fn main() -> Result<(), MigrateError> {
198    /// #     sqlx::__rt::test_block_on(async move {
199    /// use sqlx::migrate::Migrator;
200    /// use sqlx::sqlite::SqlitePoolOptions;
201    ///
202    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
203    /// let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;
204    /// m.undo(&pool, 4).await
205    /// #     })
206    /// # }
207    /// ```
208    pub async fn undo<'a, A>(&self, migrator: A, target: i64) -> Result<(), MigrateError>
209    where
210        A: Acquire<'a>,
211        <A::Connection as Deref>::Target: Migrate,
212    {
213        let mut conn = migrator.acquire().await?;
214
215        // lock the database for exclusive access by the migrator
216        if self.locking {
217            conn.lock().await?;
218        }
219
220        // creates [_migrations] table only if needed
221        // eventually this will likely migrate previous versions of the table
222        conn.ensure_migrations_table().await?;
223
224        let version = conn.dirty_version().await?;
225        if let Some(version) = version {
226            return Err(MigrateError::Dirty(version));
227        }
228
229        let applied_migrations = conn.list_applied_migrations().await?;
230        validate_applied_migrations(&applied_migrations, self)?;
231
232        let applied_migrations: HashMap<_, _> = applied_migrations
233            .into_iter()
234            .map(|m| (m.version, m))
235            .collect();
236
237        for migration in self
238            .iter()
239            .rev()
240            .filter(|m| m.migration_type.is_down_migration())
241            .filter(|m| applied_migrations.contains_key(&m.version))
242            .filter(|m| m.version > target)
243        {
244            conn.revert(migration).await?;
245        }
246
247        // unlock the migrator to allow other migrators to run
248        // but do nothing as we already migrated
249        if self.locking {
250            conn.unlock().await?;
251        }
252
253        Ok(())
254    }
255}