diff --git a/src/tracker.rs b/src/tracker.rs index c7078a0..b449b92 100644 --- a/src/tracker.rs +++ b/src/tracker.rs @@ -1,4 +1,5 @@ use std; +use serde; use binascii; use server::Events; @@ -57,6 +58,38 @@ impl serde::ser::Serialize for InfoHash { } } +struct InfoHashVisitor; + +impl<'v> serde::de::Visitor<'v> for InfoHashVisitor { + type Value = InfoHash; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "a 40 character long hash") + } + + fn visit_str(self, v: &str) -> Result { + if v.len() != 40 { + return Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"expected a 40 character long string")); + } + + let mut res = InfoHash{ + info_hash: [0u8; 20], + }; + + if let Err(_) = binascii::hex2bin(v.as_bytes(), &mut res.info_hash) { + return Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(v), &"expected a hexadecimal string")); + } else { + return Ok(res); + } + } +} + +impl<'de> serde::de::Deserialize<'de> for InfoHash { + fn deserialize>(des: D) -> Result { + des.deserialize_str(InfoHashVisitor) + } +} + pub type PeerId = [u8; 20]; #[derive(Serialize, Deserialize)] @@ -283,4 +316,17 @@ mod tests { fn tracker_sync() { is_sync::(); } + + #[test] + fn test_infohash_de() { + use serde_json; + + let ih: InfoHash = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1].into(); + + let serialized_ih = serde_json::to_string(&ih).unwrap(); + + let de_ih: InfoHash = serde_json::from_str(serialized_ih.as_str()).unwrap(); + + assert!(de_ih == ih); + } } \ No newline at end of file