1use crate::database::{Database, HasArguments};
4
5use crate::query_builder::QueryBuilder;
6
7use indexmap::set::IndexSet;
8use std::cmp;
9use std::collections::{BTreeMap, HashMap};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
/// Result type used throughout the fixture code.
///
/// The error type defaults to [`FixtureError`] when it is not spelled out.
pub type Result<T, E = FixtureError> = std::result::Result<T, E>;
14
/// A set of tables, with their rows and foreign keys, from which a [`Fixture`]
/// can be generated.
///
/// `DB` is the database driver the generated SQL is built for; no value of that
/// type is stored.
pub struct FixtureSnapshot<DB> {
    /// Tables in the snapshot, looked up by table name.
    tables: BTreeMap<TableName, Table>,
    /// Zero-sized marker tying the snapshot to `DB`.
    db: PhantomData<DB>,
}
24
/// Error returned when a fixture cannot be generated.
///
/// Wraps a human-readable message, which is displayed as
/// `could not create fixture: <message>`.
#[derive(Debug, thiserror::Error)]
#[error("could not create fixture: {0}")]
pub struct FixtureError(String);
28
/// An ordered list of operations that can be rendered as SQL text with
/// `to_string()`.
///
/// `DB` is the database driver the SQL is built for; no value of that type is
/// stored.
pub struct Fixture<DB> {
    /// Operations in the order they are rendered.
    ops: Vec<FixtureOp>,
    /// Zero-sized marker tying the fixture to `DB`.
    db: PhantomData<DB>,
}
33
/// A single operation inside a [`Fixture`].
enum FixtureOp {
    /// Insert `rows` into `table`, rendered as one `INSERT INTO` statement.
    Insert {
        /// Table receiving the rows.
        table: TableName,
        /// Column names, in the order they are listed in the statement.
        columns: Vec<ColumnName>,
        /// One inner vector of values per row.
        // NOTE(review): each row presumably has one value per entry in
        // `columns`; nothing in this file enforces that.
        rows: Vec<Vec<Value>>,
    },
}
42
/// Name of a table; reference-counted because names are cloned in several places.
type TableName = Arc<str>;
/// Name of a column; reference-counted like [`TableName`].
type ColumnName = Arc<str>;
/// A column value stored as text.
///
/// The test in this file stores values already written as SQL literals
/// (for example `'asdf'`, `true`, `3.14`) and expects them to appear unchanged
/// in the generated statement.
type Value = String;
46
/// One table of a [`FixtureSnapshot`]: its columns, rows and foreign keys.
struct Table {
    /// The table's own name; used as the key when recording its foreign-key depth.
    name: TableName,
    /// Column names, kept in insertion order.
    columns: IndexSet<ColumnName>,
    /// Row data, one inner vector per row.
    rows: Vec<Vec<Value>>,
    /// Maps a column of this table to the `(table, column)` it references.
    foreign_keys: HashMap<ColumnName, (TableName, ColumnName)>,
}
53
/// Returns early from the enclosing function with a [`FixtureError`] when
/// `$cond` evaluates to `false`.
///
/// Everything after the condition is a format string literal followed by its
/// arguments, in the form accepted by `format!`; the message is only built
/// when the condition fails.
macro_rules! fixture_assert {
    ($cond:expr, $msg:literal $($arg:tt)*) => {
        if !($cond) {
            let message = format!($msg $($arg)*);
            return Err(FixtureError(message));
        }
    };
}
61
62impl<DB: Database> FixtureSnapshot<DB> {
63 pub fn additive_fixture(&self) -> Result<Fixture<DB>> {
74 let visit_order = self.calculate_visit_order()?;
75
76 let mut ops = Vec::new();
77
78 for table_name in visit_order {
79 let table = self.tables.get(&table_name).unwrap();
80
81 ops.push(FixtureOp::Insert {
82 table: table_name,
83 columns: table.columns.iter().cloned().collect(),
84 rows: table.rows.clone(),
85 });
86 }
87
88 Ok(Fixture { ops, db: self.db })
89 }
90
91 fn calculate_visit_order(&self) -> Result<Vec<TableName>> {
96 let mut table_depths = HashMap::with_capacity(self.tables.len());
97 let mut visited_set = IndexSet::with_capacity(self.tables.len());
98
99 for table in self.tables.values() {
100 foreign_key_depth(&self.tables, table, &mut table_depths, &mut visited_set)?;
101 visited_set.clear();
102 }
103
104 let mut table_names: Vec<TableName> = table_depths.keys().cloned().collect();
105 table_names.sort_by_key(|name| table_depths.get(name).unwrap());
106 Ok(table_names)
107 }
108}
109
110impl<DB: Database> ToString for Fixture<DB>
113where
114 for<'a> <DB as HasArguments<'a>>::Arguments: Default,
115{
116 fn to_string(&self) -> String {
117 let mut query = QueryBuilder::<DB>::new("");
118
119 for op in &self.ops {
120 match op {
121 FixtureOp::Insert {
122 table,
123 columns,
124 rows,
125 } => {
126 if columns.is_empty() || rows.is_empty() {
128 continue;
129 }
130
131 query.push(format_args!("INSERT INTO {table} ("));
132
133 let mut separated = query.separated(", ");
134
135 for column in columns {
136 separated.push(column);
137 }
138
139 query.push(")\n");
140
141 query.push_values(rows, |mut separated, row| {
142 for value in row {
143 separated.push(value);
144 }
145 });
146
147 query.push(";\n");
148 }
149 }
150 }
151
152 query.into_sql()
153 }
154}
155
156fn foreign_key_depth(
157 tables: &BTreeMap<TableName, Table>,
158 table: &Table,
159 depths: &mut HashMap<TableName, usize>,
160 visited_set: &mut IndexSet<TableName>,
161) -> Result<usize> {
162 if let Some(&depth) = depths.get(&table.name) {
163 return Ok(depth);
164 }
165
166 fixture_assert!(
168 visited_set.insert(table.name.clone()),
169 "foreign key cycle detected: {:?} -> {:?}",
170 visited_set,
171 table.name
172 );
173
174 let mut refdepth = 0;
175
176 for (colname, (refname, refcol)) in &table.foreign_keys {
177 let referenced = tables.get(refname).ok_or_else(|| {
178 FixtureError(format!(
179 "table {:?} in foreign key `{}.{} references {}.{}` does not exist",
180 refname, table.name, colname, refname, refcol
181 ))
182 })?;
183
184 refdepth = cmp::max(
185 refdepth,
186 foreign_key_depth(tables, referenced, depths, visited_set)?,
187 );
188 }
189
190 let depth = refdepth + 1;
191
192 depths.insert(table.name.clone(), depth);
193
194 Ok(depth)
195}
196
/// Checks that a three-table snapshot (`foo` <- `bar` <- `baz`, with `baz` also
/// referencing `foo`) is rendered as inserts in foreign-key order.
#[test]
#[cfg(feature = "postgres")]
fn test_additive_fixture() -> Result<()> {
    use crate::postgres::Postgres;

    /// Builds a single-row table and pairs it with the key it is stored under.
    fn table(
        name: &str,
        columns: &[&str],
        row: &[&str],
        foreign_keys: &[(&str, (&str, &str))],
    ) -> (TableName, Table) {
        let table = Table {
            name: Arc::<str>::from(name),
            columns: columns.iter().copied().map(Arc::<str>::from).collect(),
            rows: vec![row.iter().map(|value| String::from(*value)).collect()],
            foreign_keys: foreign_keys
                .iter()
                .map(|&(column, (ref_table, ref_column))| {
                    (
                        Arc::<str>::from(column),
                        (Arc::<str>::from(ref_table), Arc::<str>::from(ref_column)),
                    )
                })
                .collect(),
        };

        (Arc::<str>::from(name), table)
    }

    let timestamp = "'2022-07-22 23:27:48.775113301+00:00'";

    let snapshot = FixtureSnapshot {
        tables: [
            // No foreign keys.
            table(
                "foo",
                &["foo_id", "foo_a", "foo_b"],
                &["1", "'asdf'", "true"],
                &[],
            ),
            // References `foo`.
            table(
                "bar",
                &["bar_id", "foo_id", "bar_a", "bar_b"],
                &["1234", "1", timestamp, "3.14"],
                &[("foo_id", ("foo", "foo_id"))],
            ),
            // References both `foo` and `bar`.
            table(
                "baz",
                &["baz_id", "bar_id", "foo_id", "baz_a", "baz_b"],
                &["5678", "1234", "1", timestamp, "3.14"],
                &[
                    ("foo_id", ("foo", "foo_id")),
                    ("bar_id", ("bar", "bar_id")),
                ],
            ),
        ]
        .into_iter()
        .collect(),
        db: PhantomData::<Postgres>,
    };

    let fixture = snapshot.additive_fixture()?;

    let expected = concat!(
        "INSERT INTO foo (foo_id, foo_a, foo_b)\n",
        "VALUES (1, 'asdf', true);\n",
        "INSERT INTO bar (bar_id, foo_id, bar_a, bar_b)\n",
        "VALUES (1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n",
        "INSERT INTO baz (baz_id, bar_id, foo_id, baz_a, baz_b)\n",
        "VALUES (5678, 1234, 1, '2022-07-22 23:27:48.775113301+00:00', 3.14);\n",
    );

    assert_eq!(fixture.to_string(), expected);

    Ok(())
}