diff --git a/src/human_number.rs b/src/human_number.rs new file mode 100644 index 0000000..61c81bd --- /dev/null +++ b/src/human_number.rs @@ -0,0 +1,148 @@ +use std::convert::TryInto; +use std::fmt; + +use serde::de::{Deserializer, Error, Unexpected, Visitor}; + +pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserializer.deserialize_any(HumanNumberVisitor) +} + +struct HumanNumberVisitor; + +impl<'de> Visitor<'de> for HumanNumberVisitor { + type Value = Option; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + formatter.write_str("unsigned number or quoted human-readable unsigned number") + } + + fn visit_i64(self, v: i64) -> Result + where + E: Error, + { + v.try_into() + .map_err(Error::custom) + .and_then(|v| self.visit_u64(v)) + } + + fn visit_u64(self, v: u64) -> Result + where + E: Error, + { + Ok(Some(v)) + } + + fn visit_str(self, v: &str) -> Result + where + E: Error, + { + let v = v.trim(); + + if v.is_empty() { + return Err(Error::invalid_value( + Unexpected::Str(v), + &"non-empty string", + )); + } + + // Using `bytes` instead of `chars` here because `split_at` takes a byte offset + match v + .bytes() + .map(|byte| byte as char) + .position(|ch| !(ch.is_numeric() || ch == '.')) + { + Some(unit_index) => { + let (value, unit) = v.split_at(unit_index); + let value: f64 = value.parse().map_err(Error::custom)?; + + let multiple: u64 = match unit.trim_start() { + "KB" | "kB" => 1000, + "MB" => 1000 * 1000, + "GB" => 1000 * 1000 * 1000, + "TB" => 1000 * 1000 * 1000 * 1000, + + "KiB" => 1024, + "MiB" => 1024 * 1024, + "GiB" => 1024 * 1024 * 1024, + "TiB" => 1024 * 1024 * 1024 * 1024, + + unit => { + return Err(Error::invalid_value( + Unexpected::Str(unit), + &"known unit of measurement", + )) + } + }; + let value = value * (multiple as f64); + + Ok(Some(value.round() as u64)) + } + None => { + // No unit found, interpret as raw number + v.parse().map(Some).map_err(Error::custom) + } + } + } + + fn visit_none(self) -> Result + where + E: Error, + { + Ok(None) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(self) + } +} + +#[cfg(test)] +mod test { + use serde::de::value::Error; + use serde::de::IntoDeserializer; + + use super::deserialize; + + #[test] + fn test_deserialize() { + fn check_result(input: &str, value: u64) { + assert_eq!( + deserialize(input.into_deserializer()), + Ok::<_, Error>(Some(value)) + ); + } + + check_result("1234", 1234); + check_result("1234 ", 1234); + check_result(" 1234", 1234); + check_result(" 1234 ", 1234); + + check_result("1234kB", 1234 * 1000); + check_result("1234KB", 1234 * 1000); + check_result("1234MB", 1234 * 1000 * 1000); + check_result("1234GB", 1234 * 1000 * 1000 * 1000); + check_result("1234TB", 1234 * 1000 * 1000 * 1000 * 1000); + + check_result("1234KiB", 1234 * 1024); + check_result("1234MiB", 1234 * 1024 * 1024); + check_result("1234GiB", 1234 * 1024 * 1024 * 1024); + check_result("1234TiB", 1234 * 1024 * 1024 * 1024 * 1024); + + check_result("1234 kB", 1234 * 1000); + check_result("1234 KB", 1234 * 1000); + check_result("1234 MB", 1234 * 1000 * 1000); + check_result("1234 GB", 1234 * 1000 * 1000 * 1000); + check_result("1234 TB", 1234 * 1000 * 1000 * 1000 * 1000); + + check_result("1234 KiB", 1234 * 1024); + check_result("1234 MiB", 1234 * 1024 * 1024); + check_result("1234 GiB", 1234 * 1024 * 1024 * 1024); + check_result("1234 TiB", 1234 * 1024 * 1024 * 1024 * 1024); + } +} diff --git a/src/lib.rs b/src/lib.rs index 59c26d7..e606509 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,6 +26,7 @@ use serde::Deserialize; mod config; mod dump; mod format; +mod human_number; mod log; use crate::config::Config; @@ -195,8 +196,11 @@ struct VgpuProfileOverride<'a> { multi_vgpu_supported: Option, pci_id: Option, pci_device_id: Option, + #[serde(with = "human_number")] framebuffer: Option, + #[serde(with = "human_number")] mappable_video_size: Option, + #[serde(with = "human_number")] framebuffer_reservation: Option, encoder_capacity: Option, bar1_length: Option,