diff --git a/src/format.rs b/src/format.rs index a9ebe24..53b3ae6 100644 --- a/src/format.rs +++ b/src/format.rs @@ -5,6 +5,9 @@ use std::char; use std::fmt::{self, Write}; +use crate::to_bytes::ToBytes; +use crate::utils; + pub struct CStrFormat<'a>(pub &'a [u8]); impl<'a> fmt::Debug for CStrFormat<'a> { @@ -16,7 +19,7 @@ impl<'a> fmt::Debug for CStrFormat<'a> { impl<'a> fmt::Display for CStrFormat<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let s = crate::from_c_str(self.0); + let s = utils::from_c_str(self.0); fmt::Debug::fmt(&s, f) } @@ -37,16 +40,16 @@ impl fmt::Display for HexFormat { } } -pub struct HexFormatSlice<'a>(pub &'a [u8]); +pub struct HexFormatSlice<'a, T>(pub &'a [T]); -impl<'a> fmt::Debug for HexFormatSlice<'a> { +impl<'a, T: Copy + fmt::LowerHex + ToBytes> fmt::Debug for HexFormatSlice<'a, T> { #[inline] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fmt::Display::fmt(self, f) } } -impl<'a> fmt::Display for HexFormatSlice<'a> { +impl<'a, T: Copy + fmt::LowerHex + ToBytes> fmt::Display for HexFormatSlice<'a, T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.0.is_empty() { f.write_str("[]") @@ -54,7 +57,9 @@ impl<'a> fmt::Display for HexFormatSlice<'a> { f.write_str("0x")?; for v in self.0.iter() { - write!(f, "{:02x}", v)?; + for b in v.to_ne_bytes() { + write!(f, "{b:02x}")?; + } } Ok(()) diff --git a/src/lib.rs b/src/lib.rs index ff3b575..0a09a38 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,7 +8,6 @@ //! - Arc Compute for their work on Mdev-GPU and GVM documenting more field names in the vGPU //! configuration structure -use std::borrow::Cow; use std::cmp; use std::collections::HashMap; use std::env; @@ -33,6 +32,7 @@ mod human_number; mod ioctl; mod log; mod nvidia; +mod to_bytes; mod utils; mod uuid; @@ -489,12 +489,6 @@ pub unsafe extern "C" fn ioctl(fd: RawFd, request: c_ulong, argp: *mut c_void) - ret } -pub fn from_c_str(value: &[u8]) -> Cow<'_, str> { - let len = value.iter().position(|&c| c == 0).unwrap_or(value.len()); - - String::from_utf8_lossy(&value[..len]) -} - fn load_overrides() -> Result { let config_path = match env::var_os("VGPU_UNLOCK_PROFILE_OVERRIDE_CONFIG_PATH") { Some(path) => PathBuf::from(path), @@ -581,7 +575,7 @@ fn apply_profile_override( $value ); }; - ($target_field:ident, $preprocess:ident, $value:expr) => { + ($target_field:ident, $preprocess:expr, $value:expr) => { info!( "Patching {}/{}: {} -> {}", vgpu_type, @@ -670,7 +664,7 @@ fn apply_profile_override( if value_bytes.len() > config.$target_field().len() - 1 { error_too_long!($target_field, $value); } else { - patch_msg!($target_field, from_c_str, $value); + patch_msg!($target_field, utils::from_c_str, $value); // Zero out the field first. // (`fill` was stabilized in Rust 1.50, but Debian Bullseye ships with 1.48) diff --git a/src/to_bytes.rs b/src/to_bytes.rs new file mode 100644 index 0000000..0041597 --- /dev/null +++ b/src/to_bytes.rs @@ -0,0 +1,49 @@ +pub trait ToBytes { + type Bytes: Copy + AsRef<[u8]> + AsMut<[u8]> + IntoIterator + 'static; + + fn to_ne_bytes(self) -> Self::Bytes; +} + +macro_rules! impl_to_bytes { + ($ty:tt, $len:expr) => { + impl ToBytes for $ty { + type Bytes = [u8; $len]; + + fn to_ne_bytes(self) -> Self::Bytes { + $ty::to_ne_bytes(self) + } + } + }; +} + +impl ToBytes for i8 { + type Bytes = [u8; 1]; + + fn to_ne_bytes(self) -> Self::Bytes { + [self as u8] + } +} + +impl_to_bytes!(i16, 2); +impl_to_bytes!(i32, 4); +impl_to_bytes!(i64, 8); +#[cfg(target_pointer_width = "32")] +impl_to_bytes!(isize, 4); +#[cfg(target_pointer_width = "64")] +impl_to_bytes!(isize, 8); + +impl ToBytes for u8 { + type Bytes = [u8; 1]; + + fn to_ne_bytes(self) -> Self::Bytes { + [self] + } +} + +impl_to_bytes!(u16, 2); +impl_to_bytes!(u32, 4); +impl_to_bytes!(u64, 8); +#[cfg(target_pointer_width = "32")] +impl_to_bytes!(usize, 4); +#[cfg(target_pointer_width = "64")] +impl_to_bytes!(usize, 8); diff --git a/src/utils.rs b/src/utils.rs index 00d5dd8..1395e4a 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::fmt; #[cfg(feature = "proxmox")] @@ -21,6 +22,12 @@ impl fmt::LowerHex for AlignedU64 { } } +pub fn from_c_str(value: &[u8]) -> Cow<'_, str> { + let len = value.iter().position(|&c| c == 0).unwrap_or(value.len()); + + String::from_utf8_lossy(&value[..len]) +} + /// Extracts the VMID from the last segment of a mdev uuid /// /// For example, for this uuid 00000000-0000-0000-0000-000000000100