// sqlx_postgres/types/lquery.rs

1use crate::decode::Decode;
2use crate::encode::{Encode, IsNull};
3use crate::error::BoxDynError;
4use crate::types::Type;
5use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
6use bitflags::bitflags;
7use std::fmt::{self, Display, Formatter};
8use std::io::Write;
9use std::ops::Deref;
10use std::str::FromStr;
11
12use crate::types::ltree::{PgLTreeLabel, PgLTreeParseError};
13
14/// Represents lquery specific errors
15#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum PgLQueryParseError {
18    #[error("lquery cannot be empty")]
19    EmptyString,
20    #[error("unexpected character in lquery")]
21    UnexpectedCharacter,
22    #[error("error parsing integer: {0}")]
23    ParseIntError(#[from] std::num::ParseIntError),
24    #[error("error parsing integer: {0}")]
25    LTreeParrseError(#[from] PgLTreeParseError),
26    /// LQuery version not supported
27    #[error("lquery version not supported")]
28    InvalidLqueryVersion,
29}
30
31/// Container for a Label Tree Query (`lquery`) in Postgres.
32///
33/// See https://www.postgresql.org/docs/current/ltree.html
34///
35/// ### Note: Requires Postgres 13+
36///
37/// This integration requires that the `lquery` type support the binary format in the Postgres
38/// wire protocol, which only became available in Postgres 13.
39/// ([Postgres 13.0 Release Notes, Additional Modules][https://www.postgresql.org/docs/13/release-13.html#id-1.11.6.11.5.14])
40///
41/// Ideally, SQLx's Postgres driver should support falling back to text format for types
42/// which don't have `typsend` and `typrecv` entries in `pg_type`, but that work still needs
43/// to be done.
44///
45/// ### Note: Extension Required
46/// The `ltree` extension is not enabled by default in Postgres. You will need to do so explicitly:
47///
48/// ```ignore
49/// CREATE EXTENSION IF NOT EXISTS "ltree";
50/// ```
51#[derive(Clone, Debug, Default, PartialEq)]
52pub struct PgLQuery {
53    levels: Vec<PgLQueryLevel>,
54}
55
56// TODO: maybe a QueryBuilder pattern would be nice here
57impl PgLQuery {
58    /// creates default/empty lquery
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    pub fn from(levels: Vec<PgLQueryLevel>) -> Self {
64        Self { levels }
65    }
66
67    /// push a query level
68    pub fn push(&mut self, level: PgLQueryLevel) {
69        self.levels.push(level);
70    }
71
72    /// pop a query level
73    pub fn pop(&mut self) -> Option<PgLQueryLevel> {
74        self.levels.pop()
75    }
76
77    /// creates lquery from an iterator with checking labels
78    pub fn from_iter<I, S>(levels: I) -> Result<Self, PgLQueryParseError>
79    where
80        S: Into<String>,
81        I: IntoIterator<Item = S>,
82    {
83        let mut lquery = Self::default();
84        for level in levels {
85            lquery.push(PgLQueryLevel::from_str(&level.into())?);
86        }
87        Ok(lquery)
88    }
89}
90
91impl IntoIterator for PgLQuery {
92    type Item = PgLQueryLevel;
93    type IntoIter = std::vec::IntoIter<Self::Item>;
94
95    fn into_iter(self) -> Self::IntoIter {
96        self.levels.into_iter()
97    }
98}
99
100impl FromStr for PgLQuery {
101    type Err = PgLQueryParseError;
102
103    fn from_str(s: &str) -> Result<Self, Self::Err> {
104        Ok(Self {
105            levels: s
106                .split('.')
107                .map(|s| PgLQueryLevel::from_str(s))
108                .collect::<Result<_, Self::Err>>()?,
109        })
110    }
111}
112
113impl Display for PgLQuery {
114    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
115        let mut iter = self.levels.iter();
116        if let Some(label) = iter.next() {
117            write!(f, "{label}")?;
118            for label in iter {
119                write!(f, ".{label}")?;
120            }
121        }
122        Ok(())
123    }
124}
125
126impl Deref for PgLQuery {
127    type Target = [PgLQueryLevel];
128
129    fn deref(&self) -> &Self::Target {
130        &self.levels
131    }
132}
133
134impl Type<Postgres> for PgLQuery {
135    fn type_info() -> PgTypeInfo {
136        // Since `ltree` is enabled by an extension, it does not have a stable OID.
137        PgTypeInfo::with_name("lquery")
138    }
139}
140
141impl Encode<'_, Postgres> for PgLQuery {
142    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
143        buf.extend(1i8.to_le_bytes());
144        write!(buf, "{self}")
145            .expect("Display implementation panicked while writing to PgArgumentBuffer");
146
147        IsNull::No
148    }
149}
150
151impl<'r> Decode<'r, Postgres> for PgLQuery {
152    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
153        match value.format() {
154            PgValueFormat::Binary => {
155                let bytes = value.as_bytes()?;
156                let version = i8::from_le_bytes([bytes[0]; 1]);
157                if version != 1 {
158                    return Err(Box::new(PgLQueryParseError::InvalidLqueryVersion));
159                }
160                Ok(Self::from_str(std::str::from_utf8(&bytes[1..])?)?)
161            }
162            PgValueFormat::Text => Ok(Self::from_str(value.as_str()?)?),
163        }
164    }
165}
166
167bitflags! {
168    /// Modifiers that can be set to non-star labels
169    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
170    pub struct PgLQueryVariantFlag: u16 {
171        /// * - Match any label with this prefix, for example foo* matches foobar
172        const ANY_END = 0x01;
173        /// @ - Match case-insensitively, for example a@ matches A
174        const IN_CASE = 0x02;
175        /// % - Match initial underscore-separated words
176        const SUBLEXEME = 0x04;
177    }
178}
179
180impl Display for PgLQueryVariantFlag {
181    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
182        if self.contains(PgLQueryVariantFlag::ANY_END) {
183            write!(f, "*")?;
184        }
185        if self.contains(PgLQueryVariantFlag::IN_CASE) {
186            write!(f, "@")?;
187        }
188        if self.contains(PgLQueryVariantFlag::SUBLEXEME) {
189            write!(f, "%")?;
190        }
191
192        Ok(())
193    }
194}
195
196#[derive(Clone, Debug, PartialEq)]
197pub struct PgLQueryVariant {
198    label: PgLTreeLabel,
199    modifiers: PgLQueryVariantFlag,
200}
201
202impl Display for PgLQueryVariant {
203    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
204        write!(f, "{}{}", self.label, self.modifiers)
205    }
206}
207
208#[derive(Clone, Debug, PartialEq)]
209pub enum PgLQueryLevel {
210    /// match any label (*) with optional at least / at most numbers
211    Star(Option<u16>, Option<u16>),
212    /// match any of specified labels with optional flags
213    NonStar(Vec<PgLQueryVariant>),
214    /// match none of specified labels with optional flags
215    NotNonStar(Vec<PgLQueryVariant>),
216}
217
218impl FromStr for PgLQueryLevel {
219    type Err = PgLQueryParseError;
220
221    fn from_str(s: &str) -> Result<Self, Self::Err> {
222        let bytes = s.as_bytes();
223        if bytes.is_empty() {
224            Err(PgLQueryParseError::EmptyString)
225        } else {
226            match bytes[0] {
227                b'*' => {
228                    if bytes.len() > 1 {
229                        let parts = s[2..s.len() - 1].split(',').collect::<Vec<_>>();
230                        match parts.len() {
231                            1 => {
232                                let number = parts[0].parse()?;
233                                Ok(PgLQueryLevel::Star(Some(number), Some(number)))
234                            }
235                            2 => Ok(PgLQueryLevel::Star(
236                                Some(parts[0].parse()?),
237                                Some(parts[1].parse()?),
238                            )),
239                            _ => Err(PgLQueryParseError::UnexpectedCharacter),
240                        }
241                    } else {
242                        Ok(PgLQueryLevel::Star(None, None))
243                    }
244                }
245                b'!' => Ok(PgLQueryLevel::NotNonStar(
246                    s[1..]
247                        .split('|')
248                        .map(|s| PgLQueryVariant::from_str(s))
249                        .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
250                )),
251                _ => Ok(PgLQueryLevel::NonStar(
252                    s.split('|')
253                        .map(|s| PgLQueryVariant::from_str(s))
254                        .collect::<Result<Vec<_>, PgLQueryParseError>>()?,
255                )),
256            }
257        }
258    }
259}
260
261impl FromStr for PgLQueryVariant {
262    type Err = PgLQueryParseError;
263
264    fn from_str(s: &str) -> Result<Self, Self::Err> {
265        let mut label_length = s.len();
266        let mut rev_iter = s.bytes().rev();
267        let mut modifiers = PgLQueryVariantFlag::empty();
268
269        while let Some(b) = rev_iter.next() {
270            match b {
271                b'@' => modifiers.insert(PgLQueryVariantFlag::IN_CASE),
272                b'*' => modifiers.insert(PgLQueryVariantFlag::ANY_END),
273                b'%' => modifiers.insert(PgLQueryVariantFlag::SUBLEXEME),
274                _ => break,
275            }
276            label_length -= 1;
277        }
278
279        Ok(PgLQueryVariant {
280            label: PgLTreeLabel::new(&s[0..label_length])?,
281            modifiers,
282        })
283    }
284}
285
286fn write_variants(f: &mut Formatter<'_>, variants: &[PgLQueryVariant], not: bool) -> fmt::Result {
287    let mut iter = variants.iter();
288    if let Some(variant) = iter.next() {
289        write!(f, "{}{}", if not { "!" } else { "" }, variant)?;
290        for variant in iter {
291            write!(f, ".{variant}")?;
292        }
293    }
294    Ok(())
295}
296
297impl Display for PgLQueryLevel {
298    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
299        match self {
300            PgLQueryLevel::Star(Some(at_least), Some(at_most)) => {
301                if at_least == at_most {
302                    write!(f, "*{{{at_least}}}")
303                } else {
304                    write!(f, "*{{{at_least},{at_most}}}")
305                }
306            }
307            PgLQueryLevel::Star(Some(at_least), _) => write!(f, "*{{{at_least},}}"),
308            PgLQueryLevel::Star(_, Some(at_most)) => write!(f, "*{{,{at_most}}}"),
309            PgLQueryLevel::Star(_, _) => write!(f, "*"),
310            PgLQueryLevel::NonStar(variants) => write_variants(f, &variants, false),
311            PgLQueryLevel::NotNonStar(variants) => write_variants(f, &variants, true),
312        }
313    }
314}