1use bstr::BStr;
6use serde::de::Visitor;
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8use std::fmt;
9use std::fmt::Write;
10use std::ops::Deref;
11use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout};
12use zx_types::ZX_MAX_NAME_LEN;
13
14#[derive(
22 Default,
23 Eq,
24 Hash,
25 FromBytes,
26 IntoBytes,
27 Immutable,
28 KnownLayout,
29 PartialEq,
30 PartialOrd,
31 Ord,
32 Clone,
33)]
34pub struct ZXName([u8; ZX_MAX_NAME_LEN]);
35
/// Error returned by the fallible [`ZXName`] constructors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The input does not fit alongside a NUL terminator, or it contains a NUL byte.
    InvalidArgument,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::InvalidArgument => f.write_str("invalid argument"),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn std::error::Error>` and
// similar error-erasing types.
impl std::error::Error for Error {}
40
41impl ZXName {
42 pub fn as_bytes(&self) -> &[u8] {
43 match self.0.iter().position(|&b| b == 0) {
44 Some(index) => &self.0[..index],
45 None => &self.0[..],
46 }
47 }
48
49 pub fn as_bstr(&self) -> &BStr {
50 BStr::new(self.as_bytes())
51 }
52
53 pub fn buffer(&self) -> &[u8; ZX_MAX_NAME_LEN] {
55 &self.0
56 }
57
58 pub const fn try_from_bytes(b: &[u8]) -> Result<Self, Error> {
61 if b.len() >= ZX_MAX_NAME_LEN {
62 return Err(Error::InvalidArgument);
63 }
64
65 let mut inner = [0u8; ZX_MAX_NAME_LEN];
66 let mut i = 0;
67 while i < b.len() {
68 if b[i] == 0 {
69 return Err(Error::InvalidArgument);
70 }
71 inner[i] = b[i];
72 i += 1;
73 }
74
75 Ok(Self(inner))
76 }
77
78 pub fn from_string_lossy(s: &str) -> Self {
79 Self::from_bytes_lossy(s.as_bytes())
80 }
81
82 #[inline]
83 pub const fn from_bytes_lossy(b: &[u8]) -> Self {
84 let to_copy = if b.len() <= ZX_MAX_NAME_LEN - 1 { b.len() } else { ZX_MAX_NAME_LEN - 1 };
85
86 let mut inner = [0u8; ZX_MAX_NAME_LEN];
87 let mut source_idx = 0;
88 let mut dest_idx = 0;
89 while source_idx < to_copy {
90 if b[source_idx] != 0 {
91 inner[dest_idx] = b[source_idx];
92 dest_idx += 1;
93 }
94 source_idx += 1;
95 }
96
97 Self(inner)
98 }
99}
100
101impl std::fmt::Display for ZXName {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 std::fmt::Display::fmt(self.as_bstr(), f)
105 }
106}
107
108impl std::fmt::Debug for ZXName {
109 #[inline]
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 std::fmt::Debug::fmt(self.as_bstr(), f)
112 }
113}
114
115impl From<&ZXName> for ZXName {
116 fn from(name_ref: &ZXName) -> Self {
117 name_ref.clone()
118 }
119}
120
121impl Deref for ZXName {
122 type Target = [u8; ZX_MAX_NAME_LEN];
123 fn deref(&self) -> &Self::Target {
124 &self.0
125 }
126}
127
128impl Serialize for ZXName {
131 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
132 where
133 S: Serializer,
134 {
135 let mut res = String::new();
136 for chunk in self.as_bstr().utf8_chunks() {
137 for ch in chunk.valid().chars() {
138 match ch {
139 '\\' => {
140 write!(res, "\\\\").unwrap();
141 }
142 _ => res.push(ch),
143 }
144 }
145 for byte in chunk.invalid() {
146 write!(res, "\\x{:02X}", byte).unwrap();
147 }
148 }
149 serializer.serialize_str(&res)
150 }
151}
152
153impl<'de> Deserialize<'de> for ZXName {
155 fn deserialize<D>(deserializer: D) -> Result<ZXName, D::Error>
156 where
157 D: Deserializer<'de>,
158 {
159 struct ZXNameVisitor;
160
161 impl<'de> Visitor<'de> for ZXNameVisitor {
162 type Value = ZXName;
163
164 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
165 formatter.write_str("an string")
166 }
167 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
168 where
169 E: serde::de::Error,
170 {
171 let mut result = Vec::new();
172 let mut chars = v.as_bytes().iter();
173 loop {
174 match chars.next() {
175 None => break,
176 Some(b'\\') => match chars.next() {
177 None => return Err(E::custom("Character expected after '\\' escape")),
178 Some(b'x') => result.push(
179 u8::from_str_radix(
180 str::from_utf8(&[
181 *chars.next().ok_or_else(|| {
182 E::custom("Hex characters expected after '\\x'")
183 })?,
184 *chars.next().ok_or_else(|| {
185 E::custom("Hex characters expected after '\\x'")
186 })?,
187 ])
188 .map_err(|_| E::custom("Invalid utf-8 sequence after '\\x'"))?,
189 16,
190 )
191 .map_err(|_| E::custom("Invalid hex pair after '\\x'"))?,
192 ),
193 Some(v) => result.push(*v),
194 },
195 Some(u) => {
196 result.push(*u);
197 }
198 }
199 }
200 Ok(ZXName::from_bytes_lossy(&result))
201 }
202 }
203 deserializer.deserialize_str(ZXNameVisitor {})
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    // The longest name that still leaves room for the NUL terminator (31 bytes).
    const LONGEST: &str = "abcdefghijklmnopqrstuvwxyz01234";

    /// Every constructor agrees on the empty name, which is an all-zero buffer.
    #[test]
    fn empty_name() {
        let zeroes = [0u8; ZX_MAX_NAME_LEN];
        for name in [
            &ZXName::try_from_bytes(b"").unwrap(),
            &ZXName::from_bytes_lossy(b""),
            &ZXName::from_string_lossy(""),
            ZXName::ref_from_bytes(&zeroes).unwrap(),
        ] {
            assert_eq!(name.as_bstr().len(), 0);
            assert_eq!(name, name);
            assert_eq!(name.to_string(), "");
            assert_eq!(name.buffer(), &zeroes);
            assert_eq!(&**name, &zeroes);
        }
    }

    /// A 31-byte name fills the buffer up to, but not including, the final NUL.
    #[test]
    fn just_fit() {
        let expected_buffer = b"abcdefghijklmnopqrstuvwxyz01234\0";
        for name in [
            &ZXName::try_from_bytes(LONGEST.as_bytes()).unwrap(),
            &ZXName::from_bytes_lossy(LONGEST.as_bytes()),
            &ZXName::from_string_lossy(LONGEST),
            ZXName::ref_from_bytes(expected_buffer).unwrap(),
        ] {
            let rendered = name.to_string();
            assert_eq!(rendered, LONGEST);
            assert_eq!(rendered.len(), ZX_MAX_NAME_LEN - 1);
            assert_eq!(name.buffer(), expected_buffer);
            assert_eq!(&**name, expected_buffer);
        }
    }

    /// 32+ byte inputs are rejected by the strict constructor and truncated by the lossy ones.
    #[test]
    fn too_long() {
        let data = "abcdefghijklmnopqrstuvwxyz012345";
        assert_eq!(ZXName::try_from_bytes(data.as_bytes()), Err(Error::InvalidArgument));

        let truncated =
            [ZXName::from_bytes_lossy(data.as_bytes()), ZXName::from_string_lossy(data)];
        for name in truncated.iter() {
            let rendered = name.to_string();
            assert_eq!(rendered, LONGEST);
            assert_eq!(rendered.len(), ZX_MAX_NAME_LEN - 1);
        }
    }

    /// Embedded NULs are an error for the strict constructor and dropped by the lossy one.
    #[test]
    fn zero_inside() {
        let data = b"abc\0def\0\0\0";
        assert_eq!(ZXName::try_from_bytes(data), Err(Error::InvalidArgument));
        assert_eq!(ZXName::from_bytes_lossy(data).to_string(), "abcdef");
    }

    /// Bytes that are not valid UTF-8 display as replacement characters.
    #[test]
    fn not_utf8() {
        let data: [u8; 2] = [0xff, 0xff];
        let name = ZXName::from_bytes_lossy(&data);
        assert_eq!(name.to_string(), "\u{FFFD}\u{FFFD}");
    }

    #[test]
    fn test_serialize() {
        let to_json = |name: &ZXName| serde_json::to_string(name).unwrap();
        assert_eq!(to_json(&ZXName::from_string_lossy("abc")), "\"abc\"");
        assert_eq!(
            to_json(&ZXName::from_string_lossy("\n\t\r'\"\\")),
            "\"\\n\\t\\r'\\\"\\\\\\\\\""
        );
        assert_eq!(
            to_json(&ZXName::from_bytes_lossy(&[b'a', 0xc4, 0x80, b'(', 0xc3, b')'])),
            r#""aĀ(\\xC3)""#
        );
    }

    #[test]
    fn test_deserialize() {
        // (JSON input, substring expected in the Debug form of the resulting error)
        let cases = [
            (r#""\\""#, "Character expected after '\\\\'"),
            (r#""\\x""#, "Hex characters expected after"),
            (r#""\\x1""#, "Hex characters expected after"),
            (r#""\\x1x""#, "Invalid hex pair after"),
        ];
        for (json, expected) in cases {
            let outcome = serde_json::from_str::<ZXName>(json);
            let debug = format!("{:?}", outcome);
            assert!(debug.contains(expected), "{json}: {debug}");
        }
    }

    /// Random byte names survive a JSON round trip unchanged.
    #[test]
    fn test_fuzz_serialize() {
        let mut rng = rand::rng();
        for _ in 0..100000 {
            let len = rng.random_range(0..32);
            let raw: Vec<u8> = (0..len).map(|_| rng.random::<u8>()).collect();
            let before = ZXName::from_bytes_lossy(&raw);
            let json = serde_json::to_string(&before).unwrap();
            let after: ZXName = serde_json::from_str(&json).expect("deserialization works");
            assert_eq!(before, after);
        }
    }
}