diff --git a/src/config.rs b/src/config.rs index 75c59b6..8aebb2c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -4,6 +4,8 @@ use std::collections::HashMap; use serde::Deserialize; +use crate::string_number::U32; + struct Defaults; impl Defaults { @@ -25,7 +27,7 @@ pub struct Config { #[serde(default = "Defaults::unlock_migration")] pub unlock_migration: bool, #[serde(default)] - pub pci_info_map: Option>, + pub pci_info_map: Option>, } #[derive(Deserialize)] diff --git a/src/lib.rs b/src/lib.rs index f61a0b7..e05f74d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -32,6 +32,7 @@ mod human_number; mod ioctl; mod log; mod nvidia; +mod string_number; mod to_bytes; mod utils; mod uuid; @@ -64,6 +65,7 @@ use crate::nvidia::error::{ NV_ERR_BUSY_RETRY, NV_ERR_NOT_SUPPORTED, NV_ERR_OBJECT_NOT_FOUND, NV_OK, }; use crate::nvidia::nvos::{Nvos54Parameters, NV_ESC_RM_CONTROL}; +use crate::string_number::U32; #[cfg(feature = "proxmox")] use crate::utils::uuid_to_vmid; use crate::uuid::Uuid; @@ -485,7 +487,7 @@ pub unsafe extern "C" fn ioctl(fd: RawFd, request: c_ulong, argp: *mut c_void) - let mapped_id = CONFIG .pci_info_map .as_ref() - .and_then(|pci_info_map| pci_info_map.get(&orig_device_id)); + .and_then(|pci_info_map| pci_info_map.get(&U32(orig_device_id))); let actual_device_id = (orig_device_id & 0xffff0000) >> 16; let actual_sub_system_id = (orig_sub_system_id & 0xffff0000) >> 16; diff --git a/src/string_number.rs b/src/string_number.rs new file mode 100644 index 0000000..ce21bd4 --- /dev/null +++ b/src/string_number.rs @@ -0,0 +1,100 @@ +use std::borrow::Cow; +use std::cmp::{Eq, PartialEq}; +use std::fmt; +use std::hash::{Hash, Hasher}; + +use serde::de::{Deserializer, Error}; +use serde::Deserialize; + +#[repr(transparent)] +pub struct U32(pub u32); + +impl<'de> Deserialize<'de> for U32 { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match NumberString::deserialize(deserializer)? { + NumberString::Number(n) => Ok(Self(n)), + NumberString::String(s) => { + let s = s.trim(); + + // Try to maintain compatibility with older Rust versions + let (v, radix) = match (s.get(0..2), s.get(2..)) { + (Some(prefix), Some(suffix)) if prefix.eq_ignore_ascii_case("0b") => { + (suffix, 2) + } + (Some(prefix), Some(suffix)) if prefix.eq_ignore_ascii_case("0x") => { + (suffix, 16) + } + (_, _) => (s, 10), + }; + + match u32::from_str_radix(v, radix) { + Ok(n) => Ok(Self(n)), + Err(e) => Err(D::Error::custom(format!( + "Failed to parse string as base-{radix} integer: {e}" + ))), + } + } + } + } +} + +#[derive(Deserialize)] +#[serde(untagged)] +enum NumberString<'data> { + Number(u32), + #[serde(borrow)] + String(Cow<'data, str>), +} + +impl fmt::Display for U32 { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +impl fmt::Debug for U32 { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, f) + } +} + +impl Eq for U32 {} + +impl Hash for U32 { + #[inline] + fn hash(&self, state: &mut H) + where + H: Hasher, + { + Hash::hash(&self.0, state) + } +} + +impl PartialEq for U32 { + #[inline] + fn eq(&self, other: &u32) -> bool { + PartialEq::eq(&self.0, other) + } + + #[inline] + fn ne(&self, other: &u32) -> bool { + PartialEq::ne(&self.0, other) + } +} + +impl PartialEq for U32 { + #[inline] + fn eq(&self, other: &Self) -> bool { + PartialEq::eq(&self.0, &other.0) + } + + #[inline] + fn ne(&self, other: &Self) -> bool { + PartialEq::ne(&self.0, &other.0) + } +}