Skip to main content

csv_legacy2/
serde.rs

1#![cfg(feature = "serde")]
2
3extern crate alloc;
4
5use alloc::{
6    borrow::ToOwned,
7    collections::BTreeMap,
8    string::{String, ToString},
9    vec::Vec,
10};
11use core::fmt;
12
13use serde::de::{self, DeserializeSeed, Deserializer, Error as _, SeqAccess, Visitor};
14
15use crate::{error::ReadError, reader::Row};
16
17impl Row<'_> {
18    /// Deserialize this row into a `T`.
19    ///
20    /// If [`parse_headers`](crate::Reader::parse_headers) was called
21    /// before iterating, struct fields are matched by column name.
22    /// Otherwise, fields are mapped positionally.
23    ///
24    /// # Example
25    ///
26    /// ```no_run
27    /// use csv::Reader;
28    /// use serde::Deserialize;
29    ///
30    /// #[derive(Deserialize)]
31    /// struct Record { name: String, age: u32 }
32    ///
33    /// let mut reader = Reader::from_reader(std::io::Cursor::new(b"name,age\nAlice,30\n"));
34    /// reader.parse_headers()?;
35    /// for row in reader.rows() {
36    ///     let rec: Record = row.deserialize()?;
37    ///     println!("{} is {}", rec.name, rec.age);
38    /// }
39    /// # Ok::<_, Box<dyn std::error::Error>>(())
40    /// ```
41    pub fn deserialize<T>(&self) -> Result<T, ReadError>
42    where
43        T: serde::de::DeserializeOwned,
44    {
45        let owned: Vec<String> = self.fields()?.map(|s| s.to_string()).collect();
46
47        if let Some(header_map) = self.header_map {
48            let mut deser = HeaderRow {
49                fields: &owned,
50                header_map,
51                struct_fields: &[],
52                index: 0,
53            };
54            T::deserialize(&mut deser)
55                .map_err(|e| ReadError::new(crate::error::ReadErrorKind::Deserialize(e.msg), 0, 0))
56        } else {
57            let mut deser = PositionalRow {
58                fields: &owned,
59                index: 0,
60            };
61            T::deserialize(&mut deser)
62                .map_err(|e| ReadError::new(crate::error::ReadErrorKind::Deserialize(e.msg), 0, 0))
63        }
64    }
65}
66
67struct HeaderRow<'de> {
68    fields: &'de [String],
69    header_map: &'de BTreeMap<String, usize>,
70    struct_fields: &'static [&'static str],
71    index: usize,
72}
73
74impl<'de> Deserializer<'de> for &mut HeaderRow<'de> {
75    type Error = CsvError;
76
77    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
78    where
79        V: Visitor<'de>,
80    {
81        self.deserialize_struct("", &[], visitor)
82    }
83
84    fn deserialize_struct<V>(
85        mut self,
86        _name: &'static str,
87        struct_fields: &'static [&'static str],
88        visitor: V,
89    ) -> Result<V::Value, Self::Error>
90    where
91        V: Visitor<'de>,
92    {
93        self.struct_fields = struct_fields;
94        self.index = 0;
95        visitor.visit_seq(&mut self)
96    }
97
98    serde::forward_to_deserialize_any! {
99        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
100        bytes byte_buf option unit unit_struct newtype_struct seq tuple
101        tuple_struct map enum identifier ignored_any
102    }
103}
104
105impl<'de> SeqAccess<'de> for HeaderRow<'de> {
106    type Error = CsvError;
107
108    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
109    where
110        T: DeserializeSeed<'de>,
111    {
112        if self.index >= self.struct_fields.len() {
113            return Ok(None);
114        }
115        let field_name = self.struct_fields[self.index];
116        self.index += 1;
117        let val = self
118            .header_map
119            .get(field_name)
120            .and_then(|&idx| self.fields.get(idx).map(|s| s.as_str()))
121            .unwrap_or("");
122        seed.deserialize(FieldDeserializer(val)).map(Some)
123    }
124}
125
126struct PositionalRow<'de> {
127    fields: &'de [String],
128    index: usize,
129}
130
131impl<'de> Deserializer<'de> for &mut PositionalRow<'de> {
132    type Error = CsvError;
133
134    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
135    where
136        V: Visitor<'de>,
137    {
138        self.deserialize_struct("", &[], visitor)
139    }
140
141    fn deserialize_struct<V>(
142        mut self,
143        _name: &'static str,
144        _struct_fields: &'static [&'static str],
145        visitor: V,
146    ) -> Result<V::Value, Self::Error>
147    where
148        V: Visitor<'de>,
149    {
150        self.index = 0;
151        visitor.visit_seq(&mut self)
152    }
153
154    serde::forward_to_deserialize_any! {
155        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
156        bytes byte_buf option unit unit_struct newtype_struct seq tuple
157        tuple_struct map enum identifier ignored_any
158    }
159}
160
161impl<'de> SeqAccess<'de> for PositionalRow<'de> {
162    type Error = CsvError;
163
164    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
165    where
166        T: DeserializeSeed<'de>,
167    {
168        if self.index >= self.fields.len() {
169            return Ok(None);
170        }
171        let val = self.fields[self.index].as_str();
172        self.index += 1;
173        seed.deserialize(FieldDeserializer(val)).map(Some)
174    }
175}
176
177/// Deserializes a single CSV field value with proper type coercion.
178///
179/// Strings are parsed into the requested type so that fields like
180/// `"30"` can be deserialized into `u32`.
181struct FieldDeserializer<'a>(&'a str);
182
183macro_rules! forward_parse {
184    ($($method:ident => $visit:ident :: $ty:ty),*) => {
185        $(
186            fn $method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
187            where V: Visitor<'de>,
188            {
189                let v: $ty = self.0.parse().map_err(|_| {
190                    CsvError::custom(concat!("invalid ", stringify!($ty)))
191                })?;
192                visitor.$visit(v)
193            }
194        )*
195    };
196}
197
198impl<'de, 'a: 'de> Deserializer<'de> for FieldDeserializer<'a> {
199    type Error = CsvError;
200
201    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202    where
203        V: Visitor<'de>,
204    {
205        visitor.visit_borrowed_str(self.0)
206    }
207
208    forward_parse! {
209        deserialize_bool   => visit_bool   :: bool,
210        deserialize_i8     => visit_i8     :: i8,
211        deserialize_i16    => visit_i16    :: i16,
212        deserialize_i32    => visit_i32    :: i32,
213        deserialize_i64    => visit_i64    :: i64,
214        deserialize_u8     => visit_u8     :: u8,
215        deserialize_u16    => visit_u16    :: u16,
216        deserialize_u32    => visit_u32    :: u32,
217        deserialize_u64    => visit_u64    :: u64,
218        deserialize_f32    => visit_f32    :: f32,
219        deserialize_f64    => visit_f64    :: f64
220    }
221
222    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223    where
224        V: Visitor<'de>,
225    {
226        let ch = self.0.chars().next().ok_or_else(|| CsvError::custom("empty char"))?;
227        visitor.visit_char(ch)
228    }
229
230    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: Visitor<'de>,
233    {
234        visitor.visit_borrowed_str(self.0)
235    }
236
237    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
238    where
239        V: Visitor<'de>,
240    {
241        visitor.visit_string(self.0.to_owned())
242    }
243
244    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
245    where
246        V: Visitor<'de>,
247    {
248        visitor.visit_borrowed_bytes(self.0.as_bytes())
249    }
250
251    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
252    where
253        V: Visitor<'de>,
254    {
255        visitor.visit_byte_buf(self.0.as_bytes().to_vec())
256    }
257
258    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
259    where
260        V: Visitor<'de>,
261    {
262        if self.0.is_empty() {
263            visitor.visit_none()
264        } else {
265            visitor.visit_some(self)
266        }
267    }
268
269    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
270    where
271        V: Visitor<'de>,
272    {
273        visitor.visit_unit()
274    }
275
276    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
277    where
278        V: Visitor<'de>,
279    {
280        visitor.visit_unit()
281    }
282
283    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
284    where
285        V: Visitor<'de>,
286    {
287        visitor.visit_newtype_struct(self)
288    }
289
290    fn deserialize_seq<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
291    where
292        V: Visitor<'de>,
293    {
294        Err(CsvError::custom("cannot deserialize sequence from a single field"))
295    }
296
297    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
298    where
299        V: Visitor<'de>,
300    {
301        Err(CsvError::custom("cannot deserialize tuple from a single field"))
302    }
303
304    fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
305    where
306        V: Visitor<'de>,
307    {
308        Err(CsvError::custom("cannot deserialize tuple struct from a single field"))
309    }
310
311    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
312    where
313        V: Visitor<'de>,
314    {
315        Err(CsvError::custom("cannot deserialize map from a single field"))
316    }
317
318    fn deserialize_struct<V>(
319        self,
320        _name: &'static str,
321        _fields: &'static [&'static str],
322        _visitor: V,
323    ) -> Result<V::Value, Self::Error>
324    where
325        V: Visitor<'de>,
326    {
327        Err(CsvError::custom("cannot deserialize struct from a single field"))
328    }
329
330    fn deserialize_enum<V>(
331        self,
332        _name: &'static str,
333        _variants: &'static [&'static str],
334        _visitor: V,
335    ) -> Result<V::Value, Self::Error>
336    where
337        V: Visitor<'de>,
338    {
339        Err(CsvError::custom("cannot deserialize enum from a single field"))
340    }
341
342    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
343    where
344        V: Visitor<'de>,
345    {
346        visitor.visit_borrowed_str(self.0)
347    }
348
349    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
350    where
351        V: Visitor<'de>,
352    {
353        visitor.visit_unit()
354    }
355}
356
357#[derive(Debug)]
358pub struct CsvError {
359    pub(crate) msg: String,
360}
361
362impl de::Error for CsvError {
363    fn custom<T: fmt::Display>(msg: T) -> Self {
364        CsvError {
365            msg: msg.to_string(),
366        }
367    }
368}
369
370impl fmt::Display for CsvError {
371    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
372        write!(f, "{}", self.msg)
373    }
374}
375
376#[cfg(feature = "std")]
377impl std::error::Error for CsvError {}
378
379#[cfg(all(not(feature = "std"), feature = "serde"))]
380impl core::error::Error for CsvError {}