sqlx_postgres/types/
array.rs

1use sqlx_core::bytes::Buf;
2use sqlx_core::types::Text;
3use std::borrow::Cow;
4
5use crate::decode::Decode;
6use crate::encode::{Encode, IsNull};
7use crate::error::BoxDynError;
8use crate::type_info::PgType;
9use crate::types::Oid;
10use crate::types::Type;
11use crate::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};
12
13/// Provides information necessary to encode and decode Postgres arrays as compatible Rust types.
14///
15/// Implementing this trait for some type `T` enables relevant `Type`,`Encode` and `Decode` impls
16/// for `Vec<T>`, `&[T]` (slices), `[T; N]` (arrays), etc.
17///
18/// ### Note: `#[derive(sqlx::Type)]`
19/// If you have the `postgres` feature enabled, `#[derive(sqlx::Type)]` will also generate
20/// an impl of this trait for your type if your wrapper is marked `#[sqlx(transparent)]`:
21///
22/// ```rust,ignore
23/// #[derive(sqlx::Type)]
24/// #[sqlx(transparent)]
25/// struct UserId(i64);
26///
27/// let user_ids: Vec<UserId> = sqlx::query_scalar("select '{ 123, 456 }'::int8[]")
28///    .fetch(&mut pg_connection)
29///    .await?;
30/// ```
31///
32/// However, this may cause an error if the type being wrapped does not implement `PgHasArrayType`,
33/// e.g. `Vec` itself, because we don't currently support multidimensional arrays:
34///
35/// ```rust,ignore
36/// #[derive(sqlx::Type)] // ERROR: `Vec<i64>` does not implement `PgHasArrayType`
37/// #[sqlx(transparent)]
38/// struct UserIds(Vec<i64>);
39/// ```
40///
41/// To remedy this, add `#[sqlx(no_pg_array)]`, which disables the generation
42/// of the `PgHasArrayType` impl:
43///
44/// ```rust,ignore
45/// #[derive(sqlx::Type)]
46/// #[sqlx(transparent, no_pg_array)]
47/// struct UserIds(Vec<i64>);
48/// ```
49///
50/// See [the documentation of `Type`][Type] for more details.
51pub trait PgHasArrayType {
52    fn array_type_info() -> PgTypeInfo;
53    fn array_compatible(ty: &PgTypeInfo) -> bool {
54        *ty == Self::array_type_info()
55    }
56}
57
58impl<T> PgHasArrayType for Option<T>
59where
60    T: PgHasArrayType,
61{
62    fn array_type_info() -> PgTypeInfo {
63        T::array_type_info()
64    }
65
66    fn array_compatible(ty: &PgTypeInfo) -> bool {
67        T::array_compatible(ty)
68    }
69}
70
71impl<T> PgHasArrayType for Text<T> {
72    fn array_type_info() -> PgTypeInfo {
73        String::array_type_info()
74    }
75
76    fn array_compatible(ty: &PgTypeInfo) -> bool {
77        String::array_compatible(ty)
78    }
79}
80
81impl<T> Type<Postgres> for [T]
82where
83    T: PgHasArrayType,
84{
85    fn type_info() -> PgTypeInfo {
86        T::array_type_info()
87    }
88
89    fn compatible(ty: &PgTypeInfo) -> bool {
90        T::array_compatible(ty)
91    }
92}
93
94impl<T> Type<Postgres> for Vec<T>
95where
96    T: PgHasArrayType,
97{
98    fn type_info() -> PgTypeInfo {
99        T::array_type_info()
100    }
101
102    fn compatible(ty: &PgTypeInfo) -> bool {
103        T::array_compatible(ty)
104    }
105}
106
107impl<T, const N: usize> Type<Postgres> for [T; N]
108where
109    T: PgHasArrayType,
110{
111    fn type_info() -> PgTypeInfo {
112        T::array_type_info()
113    }
114
115    fn compatible(ty: &PgTypeInfo) -> bool {
116        T::array_compatible(ty)
117    }
118}
119
120impl<'q, T> Encode<'q, Postgres> for Vec<T>
121where
122    for<'a> &'a [T]: Encode<'q, Postgres>,
123    T: Encode<'q, Postgres>,
124{
125    #[inline]
126    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
127        self.as_slice().encode_by_ref(buf)
128    }
129}
130
131impl<'q, T, const N: usize> Encode<'q, Postgres> for [T; N]
132where
133    for<'a> &'a [T]: Encode<'q, Postgres>,
134    T: Encode<'q, Postgres>,
135{
136    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
137        self.as_slice().encode_by_ref(buf)
138    }
139}
140
141impl<'q, T> Encode<'q, Postgres> for &'_ [T]
142where
143    T: Encode<'q, Postgres> + Type<Postgres>,
144{
145    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull {
146        let type_info = if self.len() < 1 {
147            T::type_info()
148        } else {
149            self[0].produces().unwrap_or_else(T::type_info)
150        };
151
152        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
153        buf.extend(&0_i32.to_be_bytes()); // flags
154
155        // element type
156        match type_info.0 {
157            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
158
159            ty => {
160                buf.extend(&ty.oid().0.to_be_bytes());
161            }
162        }
163
164        buf.extend(&(self.len() as i32).to_be_bytes()); // len
165        buf.extend(&1_i32.to_be_bytes()); // lower bound
166
167        for element in self.iter() {
168            buf.encode(element);
169        }
170
171        IsNull::No
172    }
173}
174
175impl<'r, T, const N: usize> Decode<'r, Postgres> for [T; N]
176where
177    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
178{
179    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
180        // This could be done more efficiently by refactoring the Vec decoding below so that it can
181        // be used for arrays and Vec.
182        let vec: Vec<T> = Decode::decode(value)?;
183        let array: [T; N] = vec.try_into().map_err(|_| "wrong number of elements")?;
184        Ok(array)
185    }
186}
187
188impl<'r, T> Decode<'r, Postgres> for Vec<T>
189where
190    T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
191{
192    fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
193        let format = value.format();
194
195        match format {
196            PgValueFormat::Binary => {
197                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L1548
198
199                let mut buf = value.as_bytes()?;
200
201                // number of dimensions in the array
202                let ndim = buf.get_i32();
203
204                if ndim == 0 {
205                    // zero dimensions is an empty array
206                    return Ok(Vec::new());
207                }
208
209                if ndim != 1 {
210                    return Err(format!("encountered an array of {ndim} dimensions; only one-dimensional arrays are supported").into());
211                }
212
213                // appears to have been used in the past to communicate potential NULLS
214                // but reading source code back through our supported postgres versions (9.5+)
215                // this is never used for anything
216                let _flags = buf.get_i32();
217
218                // the OID of the element
219                let element_type_oid = Oid(buf.get_u32());
220                let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
221                    .or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
222                    .ok_or_else(|| {
223                        BoxDynError::from(format!(
224                            "failed to resolve array element type for oid {}",
225                            element_type_oid.0
226                        ))
227                    })?;
228
229                // length of the array axis
230                let len = buf.get_i32();
231
232                // the lower bound, we only support arrays starting from "1"
233                let lower = buf.get_i32();
234
235                if lower != 1 {
236                    return Err(format!("encountered an array with a lower bound of {lower} in the first dimension; only arrays starting at one are supported").into());
237                }
238
239                let mut elements = Vec::with_capacity(len as usize);
240
241                for _ in 0..len {
242                    elements.push(T::decode(PgValueRef::get(
243                        &mut buf,
244                        format,
245                        element_type_info.clone(),
246                    ))?)
247                }
248
249                Ok(elements)
250            }
251
252            PgValueFormat::Text => {
253                // no type is provided from the database for the element
254                let element_type_info = T::type_info();
255
256                let s = value.as_str()?;
257
258                // https://github.com/postgres/postgres/blob/a995b371ae29de2d38c4b7881cf414b1560e9746/src/backend/utils/adt/arrayfuncs.c#L718
259
260                // trim the wrapping braces
261                let s = &s[1..(s.len() - 1)];
262
263                if s.is_empty() {
264                    // short-circuit empty arrays up here
265                    return Ok(Vec::new());
266                }
267
268                // NOTE: Nearly *all* types use ',' as the sequence delimiter. Yes, there is one
269                //       that does not. The BOX (not PostGIS) type uses ';' as a delimiter.
270
271                // TODO: When we add support for BOX we need to figure out some way to make the
272                //       delimiter selection
273
274                let delimiter = ',';
275                let mut done = false;
276                let mut in_quotes = false;
277                let mut in_escape = false;
278                let mut value = String::with_capacity(10);
279                let mut chars = s.chars();
280                let mut elements = Vec::with_capacity(4);
281
282                while !done {
283                    loop {
284                        match chars.next() {
285                            Some(ch) => match ch {
286                                _ if in_escape => {
287                                    value.push(ch);
288                                    in_escape = false;
289                                }
290
291                                '"' => {
292                                    in_quotes = !in_quotes;
293                                }
294
295                                '\\' => {
296                                    in_escape = true;
297                                }
298
299                                _ if ch == delimiter && !in_quotes => {
300                                    break;
301                                }
302
303                                _ => {
304                                    value.push(ch);
305                                }
306                            },
307
308                            None => {
309                                done = true;
310                                break;
311                            }
312                        }
313                    }
314
315                    let value_opt = if value == "NULL" {
316                        None
317                    } else {
318                        Some(value.as_bytes())
319                    };
320
321                    elements.push(T::decode(PgValueRef {
322                        value: value_opt,
323                        row: None,
324                        type_info: element_type_info.clone(),
325                        format,
326                    })?);
327
328                    value.clear();
329                }
330
331                Ok(elements)
332            }
333        }
334    }
335}