1use bstr::BStr;
6use serde::de::Visitor;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8use std::fmt;
9use std::fmt::Write;
10use std::ops::Deref;
11use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
12use zx_types::ZX_MAX_NAME_LEN;
13
14#[derive(
22 Default,
23 Eq,
24 Hash,
25 FromBytes,
26 IntoBytes,
27 Immutable,
28 KnownLayout,
29 PartialEq,
30 PartialOrd,
31 Ord,
32 Clone,
33)]
34pub struct ZXName([u8; ZX_MAX_NAME_LEN]);
35
/// Error returned by the fallible `ZXName` constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input cannot be stored exactly: it is too long to leave room for
    /// the trailing NUL, or it contains a NUL byte.
    InvalidArgument,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::InvalidArgument => f.write_str("invalid argument"),
        }
    }
}

// Allows `Error` to flow into `Box<dyn std::error::Error>` and similar types.
impl std::error::Error for Error {}
40
41impl ZXName {
42 pub fn as_bstr(&self) -> &BStr {
43 BStr::new(match self.0.iter().position(|&b| b == 0) {
44 Some(index) => &self.0[..index],
45 None => &self.0[..],
46 })
47 }
48
49 pub fn buffer(&self) -> &[u8; ZX_MAX_NAME_LEN] {
50 &self.0
51 }
52
53 pub const fn try_from_bytes(b: &[u8]) -> Result<Self, Error> {
54 if b.len() >= ZX_MAX_NAME_LEN {
55 return Err(Error::InvalidArgument);
56 }
57
58 let mut inner = [0u8; ZX_MAX_NAME_LEN];
59 let mut i = 0;
60 while i < b.len() {
61 if b[i] == 0 {
62 return Err(Error::InvalidArgument);
63 }
64 inner[i] = b[i];
65 i += 1;
66 }
67
68 Ok(Self(inner))
69 }
70
71 pub fn from_string_lossy(s: &str) -> Self {
72 Self::from_bytes_lossy(s.as_bytes())
73 }
74
75 #[inline]
76 pub const fn from_bytes_lossy(b: &[u8]) -> Self {
77 let to_copy = if b.len() <= ZX_MAX_NAME_LEN - 1 { b.len() } else { ZX_MAX_NAME_LEN - 1 };
78
79 let mut inner = [0u8; ZX_MAX_NAME_LEN];
80 let mut source_idx = 0;
81 let mut dest_idx = 0;
82 while source_idx < to_copy {
83 if b[source_idx] != 0 {
84 inner[dest_idx] = b[source_idx];
85 dest_idx += 1;
86 }
87 source_idx += 1;
88 }
89
90 Self(inner)
91 }
92}
93
94impl std::fmt::Display for ZXName {
95 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96 std::fmt::Display::fmt(self.as_bstr(), f)
98 }
99}
100
101impl std::fmt::Debug for ZXName {
102 #[inline]
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 std::fmt::Debug::fmt(self.as_bstr(), f)
105 }
106}
107
108impl From<&ZXName> for ZXName {
109 fn from(name_ref: &ZXName) -> Self {
110 name_ref.clone()
111 }
112}
113
114impl Deref for ZXName {
115 type Target = [u8; ZX_MAX_NAME_LEN];
116 fn deref(&self) -> &Self::Target {
117 &self.0
118 }
119}
120
121impl Serialize for ZXName {
124 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
125 where
126 S: Serializer,
127 {
128 let mut res = String::new();
129 for chunk in self.as_bstr().utf8_chunks() {
130 for ch in chunk.valid().chars() {
131 match ch {
132 '\\' => {
133 write!(res, "\\\\").unwrap();
134 }
135 _ => res.push(ch),
136 }
137 }
138 for byte in chunk.invalid() {
139 write!(res, "\\x{:02X}", byte).unwrap();
140 }
141 }
142 serializer.serialize_str(&res)
143 }
144}
145
146impl<'de> Deserialize<'de> for ZXName {
148 fn deserialize<D>(deserializer: D) -> Result<ZXName, D::Error>
149 where
150 D: Deserializer<'de>,
151 {
152 struct ZXNameVisitor;
153
154 impl<'de> Visitor<'de> for ZXNameVisitor {
155 type Value = ZXName;
156
157 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
158 formatter.write_str("an string")
159 }
160 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
161 where
162 E: serde::de::Error,
163 {
164 let mut result = Vec::new();
165 let mut chars = v.as_bytes().iter();
166 loop {
167 match chars.next() {
168 None => break,
169 Some(b'\\') => match chars.next() {
170 None => return Err(E::custom("Character expected after '\\' escape")),
171 Some(b'x') => result.push(
172 u8::from_str_radix(
173 str::from_utf8(&[
174 *chars.next().ok_or_else(|| {
175 E::custom("Hex characters expected after '\\x'")
176 })?,
177 *chars.next().ok_or_else(|| {
178 E::custom("Hex characters expected after '\\x'")
179 })?,
180 ])
181 .map_err(|_| E::custom("Invalid utf-8 sequence after '\\x'"))?,
182 16,
183 )
184 .map_err(|_| E::custom("Invalid hex pair after '\\x'"))?,
185 ),
186 Some(v) => result.push(*v),
187 },
188 Some(u) => {
189 result.push(*u);
190 }
191 }
192 }
193 Ok(ZXName::from_bytes_lossy(&result))
194 }
195 }
196 deserializer.deserialize_str(ZXNameVisitor {})
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    // Every construction path, including a zerocopy view over an all-zero
    // buffer, must produce the same empty name backed by an all-zero buffer.
    #[test]
    fn empty_name() {
        for empty in [
            &ZXName::try_from_bytes(b"").unwrap(),
            &ZXName::from_bytes_lossy(b""),
            &ZXName::from_string_lossy(""),
            ZXName::ref_from_bytes(&[0u8; ZX_MAX_NAME_LEN]).unwrap(),
        ] {
            assert_eq!(0, empty.as_bstr().len());
            assert_eq!(empty, empty);
            assert_eq!("", empty.to_string());
            assert_eq!(&[0u8; ZX_MAX_NAME_LEN], empty.buffer());
            // `&**empty` reaches the raw buffer through the `Deref` impl.
            assert_eq!(&[0u8; ZX_MAX_NAME_LEN], &**empty);
        }
    }

    // A name of ZX_MAX_NAME_LEN - 1 bytes fills every slot except the trailing
    // NUL. The literals assume ZX_MAX_NAME_LEN == 32.
    #[test]
    fn just_fit() {
        let data = "abcdefghijklmnopqrstuvwxyz01234";
        for name in [
            &ZXName::try_from_bytes(data.as_bytes()).unwrap(),
            &ZXName::from_bytes_lossy(data.as_bytes()),
            &ZXName::from_string_lossy(data),
            ZXName::ref_from_bytes(b"abcdefghijklmnopqrstuvwxyz01234\0").unwrap(),
        ] {
            assert_eq!("abcdefghijklmnopqrstuvwxyz01234", name.to_string());
            assert_eq!(ZX_MAX_NAME_LEN - 1, name.to_string().len());
            assert_eq!(b"abcdefghijklmnopqrstuvwxyz01234\0", name.buffer());
            assert_eq!(b"abcdefghijklmnopqrstuvwxyz01234\0", &**name);
        }
    }

    // One byte too many: the strict constructor rejects the input, while the
    // lossy constructors drop the final byte.
    #[test]
    fn too_long() {
        let data = "abcdefghijklmnopqrstuvwxyz012345";
        assert_eq!(Result::Err(Error::InvalidArgument), ZXName::try_from_bytes(data.as_bytes()));

        for name in [ZXName::from_bytes_lossy(data.as_bytes()), ZXName::from_string_lossy(data)] {
            assert_eq!("abcdefghijklmnopqrstuvwxyz01234", name.to_string());
            assert_eq!(ZX_MAX_NAME_LEN - 1, name.to_string().len());
        }
    }

    // Interior NULs are rejected by `try_from_bytes` and silently removed by
    // `from_bytes_lossy`.
    #[test]
    fn zero_inside() {
        let data = b"abc\0def\0\0\0";
        assert_eq!(Err(Error::InvalidArgument), ZXName::try_from_bytes(data));
        assert_eq!("abcdef", ZXName::from_bytes_lossy(data).to_string());
    }

    // `Display` renders each invalid 0xFF byte as U+FFFD.
    #[test]
    fn not_utf8() {
        let data: [u8; 2] = [0xff, 0xff];
        assert_eq!("\u{FFFD}\u{FFFD}", ZXName::from_bytes_lossy(&data).to_string());
    }

    // The expected values are JSON text, so JSON's own escaping is layered on
    // top of the name's `\\` and `\xHH` escaping.
    #[test]
    fn test_serialize() {
        assert_eq!("\"abc\"", serde_json::to_string(&ZXName::from_string_lossy("abc")).unwrap());
        assert_eq!(
            "\"\\n\\t\\r'\\\"\\\\\\\\\"",
            serde_json::to_string(&ZXName::from_string_lossy("\n\t\r'\"\\")).unwrap()
        );
        // 0xc4 0x80 is valid UTF-8 ('Ā'); the lone 0xc3 is not and becomes \xC3.
        assert_eq!(
            r#""aĀ(\\xC3)""#,
            serde_json::to_string(&ZXName::from_bytes_lossy(&[b'a', 0xc4, 0x80, b'(', 0xc3, b')']))
                .unwrap()
        );
    }

    // Malformed escapes must fail with descriptive messages. The needles are
    // matched against the `{:?}` rendering of the result, which escapes the
    // backslash in the first message, hence `\\\\` in that needle.
    #[test]
    fn test_deserialize() {
        assert!(format!("{:?}", serde_json::from_str::<ZXName>(r#""\\""#))
            .contains("Character expected after '\\\\'"));
        assert!(format!("{:?}", serde_json::from_str::<ZXName>(r#""\\x""#))
            .contains("Hex characters expected after"));
        assert!(format!("{:?}", serde_json::from_str::<ZXName>(r#""\\x1""#))
            .contains("Hex characters expected after"));
        assert!(format!("{:?}", serde_json::from_str::<ZXName>(r#""\\x1x""#))
            .contains("Invalid hex pair after"));
    }

    // Round-trip property: serializing and then deserializing returns the same
    // name for random byte strings of length 0 through 31.
    // NOTE(review): the RNG is unseeded, so a failure is not directly reproducible.
    #[test]
    fn test_fuzz_serialize() {
        let mut rng = rand::thread_rng();
        for _ in 0..100000 {
            let byte_vec: Vec<u8> = (0..rng.gen_range(0..32)).map(|_| rng.gen::<u8>()).collect();
            let before = ZXName::from_bytes_lossy(&byte_vec);
            let json = serde_json::to_string(&before).unwrap();
            let after: ZXName = serde_json::from_str(&json).expect("deserialization works");
            assert_eq!(before, after);
        }
    }
}