diff --git a/src/lib.rs b/src/lib.rs index 0e76566..31960a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,8 +9,8 @@ //! configuration structure use std::borrow::Cow; -use std::collections::HashMap; use std::cmp; +use std::collections::HashMap; use std::env; use std::fmt; use std::fs; @@ -540,196 +540,215 @@ fn apply_profile_override( vgpu_type: &str, config_override: &VgpuProfileOverride, ) -> bool { - macro_rules! handle_copy_overrides { - ($source_field:ident => $target_field:ident) => { - if let Some(value) = config_override.$source_field { - info!( - "Patching {}/{}: {} -> {}", - vgpu_type, - stringify!($target_field), - config.$target_field, - value - ); - - config.$target_field = value; - } + macro_rules! patch_msg { + ($target_field:ident, $value:expr) => { + info!( + "Patching {}/{}: {} -> {}", + vgpu_type, + stringify!($target_field), + config.$target_field, + $value + ); }; - ($field:ident) => { - handle_copy_overrides!($field => $field); - }; - ($($source_field:ident $(=> $target_field:ident)?),*$(,)?) => { - $( - handle_copy_overrides!($source_field $(=> $target_field)?); - )* + ($target_field:ident, $preprocess:ident, $value:expr) => { + info!( + "Patching {}/{}: {} -> {}", + vgpu_type, + stringify!($target_field), + $preprocess(&config.$target_field), + $value + ); }; } - macro_rules! handle_bool_overrides { - ($source_field:ident => $target_field:ident) => { - if let Some(value) = config_override.$source_field { - let value = cmp::max(cmp::min(value, 0), 1); + macro_rules! error_too_long { + ($target_field:ident, $value:expr) => { + error!( + "Patching {}/{}: value '{}' is too long", + vgpu_type, + stringify!($target_field), + $value + ); - info!( - "Patching {}/{}: {} -> {}", - vgpu_type, - stringify!($target_field), - config.$target_field, - value - ); - - config.$target_field = value; - } - }; - ($field:ident) => { - handle_bool_overrides!($field => $field); - }; - ($($source_field:ident $(=> $target_field:ident)?),*$(,)?) => { - $( - handle_bool_overrides!($source_field $(=> $target_field)?); - )* + return false; }; } - macro_rules! handle_str_overrides { - ($source_field:ident => $target_field:ident) => { + + macro_rules! handle_override { + // Override entrypoint when the same field name is used as the source and target without + // an explicit `=>`. + ( + class: $class:ident, + source_field: $field:ident, + ) => { + handle_override! { + class: $class, + source_field: $field, + target_field: $field, + } + }; + + // Override entrypoint when both the source and target field names are defined explicitly. + ( + class: $class:ident, + source_field: $source_field:ident, + target_field: $target_field:ident, + ) => { if let Some(value) = config_override.$source_field { - let value_bytes = value.as_bytes(); - - // Use `len - 1` to account for the required NULL terminator. - if value_bytes.len() > config.$target_field.len() - 1 { - error!( - "Patching {}/{}: value '{}' is too long", - vgpu_type, - stringify!($target_field), - value - ); - - return false; - } else { - info!( - "Patching {}/{}: '{}' -> '{}'", - vgpu_type, - stringify!($target_field), - from_c_str(&config.$target_field), - value - ); - - // Zero out the field first. - // (`fill` was stabilized in Rust 1.50, but Debian Bullseye ships with 1.48) - for v in config.$target_field.iter_mut() { - *v = 0; - } - - // Write the string bytes. - let _ = config.$target_field[..].as_mut().write_all(value_bytes); + handle_override! { + class: $class, + value: value, + source_field: $source_field, + target_field: $target_field, } } }; - ($field:ident) => { - handle_str_overrides!($field => $field); + + // The following are override handlers for each field class type (`bool`, `copy`, `str`, + // and `wide_str`). + ( + class: bool, + value: $value:ident, + source_field: $source_field:ident, + target_field: $target_field:ident, + ) => { + let $value = cmp::max(cmp::min($value, 0), 1); + + patch_msg!($target_field, $value); + + config.$target_field = $value; }; - ($($source_field:ident $(=> $target_field:ident)?),*$(,)?) => { - $( - handle_str_overrides!($source_field $(=> $target_field)?); - )* + ( + class: copy, + value: $value:ident, + source_field: $source_field:ident, + target_field: $target_field:ident, + ) => { + patch_msg!($target_field, $value); + + config.$target_field = $value; }; - } - macro_rules! handle_wide_str_overrides { - ($source_field:ident => $target_field:ident) => { - if let Some(value) = config_override.$source_field { - // Use `len - 1` to account for the required NULL terminator. - if value.encode_utf16().count() > config.$target_field.len() - 1 { - error!( - "Patching {}/{}: value '{}' is too long", - vgpu_type, - stringify!($target_field), - value - ); + ( + class: str, + value: $value:ident, + source_field: $source_field:ident, + target_field: $target_field:ident, + ) => { + let value_bytes = $value.as_bytes(); - return false; - } else { - info!( - "Patching {}/{}: '{}' -> '{}'", - vgpu_type, - stringify!($target_field), - WideCharFormat(&config.$target_field), - value - ); + // Use `len - 1` to account for the required NULL terminator. + if value_bytes.len() > config.$target_field.len() - 1 { + error_too_long!($target_field, $value); + } else { + patch_msg!($target_field, from_c_str, $value); - // Zero out the field first. - // (`fill` was stabilized in Rust 1.50, but Debian Bullseye ships with 1.48) - for v in config.$target_field.iter_mut() { - *v = 0; - } + // Zero out the field first. + // (`fill` was stabilized in Rust 1.50, but Debian Bullseye ships with 1.48) + for v in config.$target_field.iter_mut() { + *v = 0; + } - // Write the string bytes. - for (v, ch) in config.$target_field[..] - .iter_mut() - .zip(value.encode_utf16().chain(Some(0))) - { - *v = ch; - } + // Write the string bytes. + let _ = config.$target_field[..].as_mut().write_all(value_bytes); + } + }; + ( + class: wide_str, + value: $value:ident, + source_field: $source_field:ident, + target_field: $target_field:ident, + ) => { + // Use `len - 1` to account for the required NULL terminator. + if $value.encode_utf16().count() > config.$target_field.len() - 1 { + error_too_long!($target_field, $value); + } else { + patch_msg!($target_field, WideCharFormat, $value); + + // Zero out the field first. + // (`fill` was stabilized in Rust 1.50, but Debian Bullseye ships with 1.48) + for v in config.$target_field.iter_mut() { + *v = 0; + } + + // Write the string bytes. + for (v, ch) in config.$target_field[..] + .iter_mut() + .zip($value.encode_utf16().chain(Some(0))) + { + *v = ch; } } }; - ($field:ident) => { - handle_wide_str_overrides!($field => $field); - }; - ($($source_field:ident $(=> $target_field:ident)?),*$(,)?) => { + } + macro_rules! handle_overrides { + ( + $($class:ident: [ + $($source_field:ident $(=> $target_field:ident)?),*$(,)? + ]),*$(,)? + ) => { $( - handle_wide_str_overrides!($source_field $(=> $target_field)?); + $( + handle_override! { + class: $class, + source_field: $source_field, + $(target_field: $target_field,)? + } + )* )* }; } - // While the following could be done with two statements. I wanted the log statements to be in + // While the following could be done with fewer branches, I wanted the log statements to be in // field order. - handle_copy_overrides! { - gpu_type => vgpu_type, - } - handle_str_overrides! { - card_name => vgpu_name, - vgpu_type => vgpu_class, - features, - } - handle_copy_overrides! { - max_instances, - num_displays => num_heads, - display_width => max_resolution_x, - display_height => max_resolution_y, - max_pixels, - frl_config, - } - handle_bool_overrides! { - cuda_enabled, - ecc_supported, - } - handle_copy_overrides! { - mig_instance_size, - } - handle_bool_overrides! { - multi_vgpu_supported, - } - handle_copy_overrides! { - pci_id => vdev_id, - pci_device_id => pdev_id, - framebuffer => fb_length, - mappable_video_size, - framebuffer_reservation => fb_reservation, - encoder_capacity, - bar1_length, - } - handle_bool_overrides! { - frl_enabled => frl_enable, - } - handle_str_overrides! { - adapter_name, - } - handle_wide_str_overrides! { - adapter_name => adapter_name_unicode, - } - handle_str_overrides! { - short_gpu_name => short_gpu_name_string, - license_type => licensed_product_name, + handle_overrides! { + copy: [ + gpu_type => vgpu_type, + ], + str: [ + card_name => vgpu_name, + vgpu_type => vgpu_class, + features, + ], + copy: [ + max_instances, + num_displays => num_heads, + display_width => max_resolution_x, + display_height => max_resolution_y, + max_pixels, + frl_config, + ], + bool: [ + cuda_enabled, + ecc_supported, + ], + copy: [ + mig_instance_size, + ], + bool: [ + multi_vgpu_supported, + ], + copy: [ + pci_id => vdev_id, + pci_device_id => pdev_id, + framebuffer => fb_length, + mappable_video_size, + framebuffer_reservation => fb_reservation, + encoder_capacity, + bar1_length, + ], + bool: [ + frl_enabled => frl_enable, + ], + str: [ + adapter_name, + ], + wide_str: [ + adapter_name => adapter_name_unicode, + ], + str: [ + short_gpu_name => short_gpu_name_string, + license_type => licensed_product_name, + ], } true