attribution_processing/
name.rs1use bstr::BStr;
6use serde::de::Visitor;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8use std::fmt;
9use std::fmt::Write;
10use zerocopy::{FromBytes, Immutable, KnownLayout};
11use zx_types::ZX_MAX_NAME_LEN;
12
13#[derive(
21 Default, Eq, Hash, FromBytes, Immutable, KnownLayout, PartialEq, PartialOrd, Ord, Clone,
22)]
23pub struct ZXName([u8; ZX_MAX_NAME_LEN]);
24
/// Errors reported by the fallible [`ZXName`] constructors.
#[derive(PartialEq, Debug)]
pub enum Error {
    /// The provided bytes cannot be stored as a name: they are too long or contain a NUL byte.
    InvalidArgument,
}
29
30impl ZXName {
31 pub fn as_bstr(&self) -> &BStr {
32 BStr::new(match self.0.iter().position(|&b| b == 0) {
33 Some(index) => &self.0[..index],
34 None => &self.0[..],
35 })
36 }
37
38 pub const fn try_from_bytes(b: &[u8]) -> Result<Self, Error> {
39 if b.len() >= ZX_MAX_NAME_LEN {
40 return Err(Error::InvalidArgument);
41 }
42
43 let mut inner = [0u8; ZX_MAX_NAME_LEN];
44 let mut i = 0;
45 while i < b.len() {
46 if b[i] == 0 {
47 return Err(Error::InvalidArgument);
48 }
49 inner[i] = b[i];
50 i += 1;
51 }
52
53 Ok(Self(inner))
54 }
55
56 pub fn from_string_lossy(s: &str) -> Self {
57 Self::from_bytes_lossy(s.as_bytes())
58 }
59
60 #[inline]
61 pub const fn from_bytes_lossy(b: &[u8]) -> Self {
62 let to_copy = if b.len() <= ZX_MAX_NAME_LEN - 1 { b.len() } else { ZX_MAX_NAME_LEN - 1 };
63
64 let mut inner = [0u8; ZX_MAX_NAME_LEN];
65 let mut source_idx = 0;
66 let mut dest_idx = 0;
67 while source_idx < to_copy {
68 if b[source_idx] != 0 {
69 inner[dest_idx] = b[source_idx];
70 dest_idx += 1;
71 }
72 source_idx += 1;
73 }
74
75 Self(inner)
76 }
77}
78
79impl std::fmt::Display for ZXName {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 std::fmt::Display::fmt(self.as_bstr(), f)
83 }
84}
85
86impl std::fmt::Debug for ZXName {
87 #[inline]
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 std::fmt::Debug::fmt(self.as_bstr(), f)
90 }
91}
92
93impl From<&ZXName> for ZXName {
94 fn from(name_ref: &ZXName) -> Self {
95 name_ref.clone()
96 }
97}
98
99impl Serialize for ZXName {
102 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
103 where
104 S: Serializer,
105 {
106 let mut res = String::new();
107 for chunk in self.as_bstr().utf8_chunks() {
108 for ch in chunk.valid().chars() {
109 match ch {
110 '\\' => {
111 write!(res, "\\\\").unwrap();
112 }
113 _ => res.push(ch),
114 }
115 }
116 for byte in chunk.invalid() {
117 write!(res, "\\x{:02X}", byte).unwrap();
118 }
119 }
120 serializer.serialize_str(&res)
121 }
122}
123
124impl<'de> Deserialize<'de> for ZXName {
126 fn deserialize<D>(deserializer: D) -> Result<ZXName, D::Error>
127 where
128 D: Deserializer<'de>,
129 {
130 struct ZXNameVisitor;
131
132 impl<'de> Visitor<'de> for ZXNameVisitor {
133 type Value = ZXName;
134
135 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
136 formatter.write_str("an string")
137 }
138 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
139 where
140 E: serde::de::Error,
141 {
142 let mut result = Vec::new();
143 let mut chars = v.as_bytes().iter();
144 loop {
145 match chars.next() {
146 None => break,
147 Some(b'\\') => match chars.next() {
148 None => return Err(E::custom("Character expected after '\\' escape")),
149 Some(b'x') => result.push(
150 u8::from_str_radix(
151 str::from_utf8(&[
152 *chars.next().ok_or_else(|| {
153 E::custom("Hex characters expected after '\\x'")
154 })?,
155 *chars.next().ok_or_else(|| {
156 E::custom("Hex characters expected after '\\x'")
157 })?,
158 ])
159 .map_err(|_| E::custom("Invalid utf-8 sequence after '\\x'"))?,
160 16,
161 )
162 .map_err(|_| E::custom("Invalid hex pair after '\\x'"))?,
163 ),
164 Some(v) => result.push(*v),
165 },
166 Some(u) => {
167 result.push(*u);
168 }
169 }
170 }
171 Ok(ZXName::from_bytes_lossy(&result))
172 }
173 }
174 deserializer.deserialize_str(ZXNameVisitor {})
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Every way of building an empty name yields the same, empty, result.
    #[test]
    fn empty_name() {
        let strict = ZXName::try_from_bytes(b"").unwrap();
        let lossy_bytes = ZXName::from_bytes_lossy(b"");
        let lossy_string = ZXName::from_string_lossy("");
        let zeroed = [0u8; ZX_MAX_NAME_LEN];
        let viewed = ZXName::ref_from_bytes(&zeroed).unwrap();

        for empty in [&strict, &lossy_bytes, &lossy_string, viewed] {
            assert_eq!(0, empty.as_bstr().len());
            assert_eq!(empty, empty);
            assert_eq!("", empty.to_string());
        }
    }

    /// The longest acceptable name is preserved in full by every constructor.
    #[test]
    fn just_fit() {
        let data = "abcdefghijklmnopqrstuvwxyz01234";
        let strict = ZXName::try_from_bytes(data.as_bytes()).unwrap();
        let lossy_bytes = ZXName::from_bytes_lossy(data.as_bytes());
        let lossy_string = ZXName::from_string_lossy(data);
        let terminated = b"abcdefghijklmnopqrstuvwxyz01234\0";
        let viewed = ZXName::ref_from_bytes(terminated).unwrap();

        for name in [&strict, &lossy_bytes, &lossy_string, viewed] {
            let rendered = name.to_string();
            assert_eq!("abcdefghijklmnopqrstuvwxyz01234", rendered);
            assert_eq!(ZX_MAX_NAME_LEN - 1, rendered.len());
        }
    }

    /// One byte too many: the strict constructor refuses, the lossy ones truncate.
    #[test]
    fn too_long() {
        let data = "abcdefghijklmnopqrstuvwxyz012345";
        assert_eq!(Result::Err(Error::InvalidArgument), ZXName::try_from_bytes(data.as_bytes()));

        let truncated =
            [ZXName::from_bytes_lossy(data.as_bytes()), ZXName::from_string_lossy(data)];
        for name in truncated {
            let rendered = name.to_string();
            assert_eq!("abcdefghijklmnopqrstuvwxyz01234", rendered);
            assert_eq!(ZX_MAX_NAME_LEN - 1, rendered.len());
        }
    }

    /// NUL bytes are rejected by the strict constructor and skipped by the lossy one.
    #[test]
    fn zero_inside() {
        let data = b"abc\0def\0\0\0";
        assert_eq!(Err(Error::InvalidArgument), ZXName::try_from_bytes(data));

        let lossy = ZXName::from_bytes_lossy(data);
        assert_eq!("abcdef", lossy.to_string());
    }

    /// Bytes that are not valid UTF-8 are displayed as replacement characters.
    #[test]
    fn not_utf8() {
        let data: [u8; 2] = [0xff, 0xff];
        let name = ZXName::from_bytes_lossy(&data);
        assert_eq!("\u{FFFD}\u{FFFD}", name.to_string());
    }

    /// Backslashes and invalid bytes are escaped before the JSON layer adds its own escapes.
    #[test]
    fn test_serialize() {
        let cases = [
            ("\"abc\"", ZXName::from_string_lossy("abc")),
            ("\"\\n\\t\\r'\\\"\\\\\\\\\"", ZXName::from_string_lossy("\n\t\r'\"\\")),
            (r#""aĀ(\\xC3)""#, ZXName::from_bytes_lossy(&[b'a', 0xc4, 0x80, b'(', 0xc3, b')'])),
        ];
        for (expected, name) in cases {
            assert_eq!(expected, serde_json::to_string(&name).unwrap());
        }
    }

    /// Malformed escape sequences are reported with a descriptive message.
    #[test]
    fn test_deserialize() {
        let cases = [
            (r#""\\""#, "Character expected after '\\\\'"),
            (r#""\\x""#, "Hex characters expected after"),
            (r#""\\x1""#, "Hex characters expected after"),
            (r#""\\x1x""#, "Invalid hex pair after"),
        ];
        for (json, fragment) in cases {
            let outcome = format!("{:?}", serde_json::from_str::<ZXName>(json));
            assert!(outcome.contains(fragment));
        }
    }

    /// Random names survive a serialization round trip unchanged.
    #[test]
    fn test_fuzz_serialize() {
        let mut rng = rand::thread_rng();
        for _ in 0..100000 {
            let len: usize = rng.gen_range(0..32);
            let raw: Vec<u8> = (0..len).map(|_| rng.gen::<u8>()).collect();
            let before = ZXName::from_bytes_lossy(&raw);
            let json = serde_json::to_string(&before).unwrap();
            let after: ZXName = serde_json::from_str(&json).expect("deserialization works");
            assert_eq!(before, after);
        }
    }
}