Skip to content

Commit

Permalink
🚧 fix reinstall of profiles
Browse files Browse the repository at this point in the history
  • Loading branch information
vnepogodin committed Oct 3, 2024
1 parent 7c06e60 commit 2729c20
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 27 deletions.
112 changes: 85 additions & 27 deletions src/profile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use comfy_table::modifiers::UTF8_ROUND_CORNERS;
use comfy_table::presets::UTF8_FULL;
use comfy_table::*;

use std::collections::BTreeMap;
use std::fs;

#[derive(Debug, Default, Clone, PartialEq)]
Expand All @@ -34,7 +33,7 @@ pub struct HardwareID {
pub blacklisted_device_ids: Vec<String>,
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq)]
pub struct Profile {
pub is_ai_sdk: bool,

Expand Down Expand Up @@ -284,50 +283,68 @@ fn merge_table_left(lhs: &mut toml::Table, rhs: &toml::Table) {
}

pub fn write_profile_to_file(file_path: &str, profile: &Profile) -> bool {
// lets check manually if it does exist already in the profiles map
// NOTE: instead of trying to overwrite profile, we return error
if std::path::Path::new(file_path).exists() {
let profiles = parse_profiles(file_path).expect("Failed to parse profiles");

// Check if profile exists in file and remove it
if profiles.iter().any(|x| x.name == profile.name) {
return false;
}
}

let mut profiles = if std::path::Path::new(file_path).exists() {
parse_profiles_merged(file_path).expect("Failed to parse profiles")
fs::read_to_string(file_path)
.expect("Failed to read profiles")
.parse::<toml::Table>()
.expect("Failed to parse profiles")
} else {
vec![]
toml::Table::new()
};

if let Some(index) = profiles.iter().position(|x| x.name == profile.name) {
profiles[index] = profile.clone();
} else {
profiles.push(profile.clone());
}
let table_item = toml::Value::Table(profile_into_toml(&profile));

// convert Vec into expected toml structure
let profiles =
profiles.iter().map(|x| (x.name.clone(), profile_into_toml(x))).collect::<BTreeMap<_, _>>();
profiles.insert(profile.name.clone(), table_item);

let toml_string = replace_escaping_toml(&profiles);
fs::write(file_path, toml_string).is_ok()
}

pub fn remove_profile_from_file(file_path: &str, profile_name: &str) -> bool {
let mut profiles = parse_profiles_merged(file_path).expect("Failed to parse profiles");
// we cannot remove profile from file, if the file doesn't exist and therefore nothing to be
// removed
if !std::path::Path::new(file_path).exists() {
return false;
}

let mut profiles = parse_profiles(file_path).expect("Failed to parse profiles");

// Check if profile exists in file and remove it
if let Some(index) = profiles.iter().position(|x| x.name == profile_name) {
profiles.remove(index);
if let Some(found_idx) = profiles.iter().position(|x| x.name == profile_name) {
// remove
profiles.remove(found_idx);

// convert Vec into expected toml structure
let profiles = profiles
.iter()
.map(|x| (x.name.clone(), profile_into_toml(x)))
.collect::<BTreeMap<_, _>>();
let mut profiles_doc = toml::Table::new();

let toml_string = replace_escaping_toml(&profiles);
// insert all profiles back to the map
for profile in profiles {
let table_item = toml::Value::Table(profile_into_toml(&profile));
profiles_doc.insert(profile.name, table_item);
}

let toml_string = replace_escaping_toml(&profiles_doc);
fs::write(file_path, toml_string).is_ok()
} else {
log::error!("Profile '{profile_name}' was not found");
false
}
}

fn replace_escaping_toml(profiles: &BTreeMap<String, toml::Table>) -> String {
let mut toml_string = toml::to_string_pretty(profiles).unwrap();
fn replace_escaping_toml(profiles: &toml::Table) -> String {
let mut toml_string = profiles.to_string();

for profile_name in profiles.keys() {
for (profile_name, _) in profiles.iter() {
// Find escaped table name and replace with unescaped table name
toml_string =
toml_string.replace(&format!("[\"{profile_name}\"]"), &format!("[{profile_name}]"));
Expand Down Expand Up @@ -450,6 +467,43 @@ mod tests {
assert!(!parsed_profiles[1].post_remove.is_empty());
}

#[test]
fn profile_extra_check_parse_test() {
let prof_path = "tests/profiles/extra-check-root-profile.toml";
let parsed_profiles = parse_profiles(prof_path);
assert!(parsed_profiles.is_ok());
let parsed_profiles = parsed_profiles.unwrap();

let hwd_ids = vec![HardwareID {
class_ids: vec!["0300".to_owned(), "0302".to_owned(), "0380".to_owned()],
vendor_ids: vec!["10de".to_owned()],
device_ids: vec!["*".to_owned()],
blacklisted_class_ids: vec![],
blacklisted_vendor_ids: vec![],
blacklisted_device_ids: vec![],
}];

assert_eq!(parsed_profiles.len(), 1);
assert_eq!(parsed_profiles[0].name, "nvidia-dkms");
assert_eq!(parsed_profiles[0].desc, "Closed source NVIDIA drivers for Linux (Latest)");
assert_eq!(parsed_profiles[0].priority, 12);
assert_eq!(parsed_profiles[0].is_ai_sdk, false);
assert_eq!(
parsed_profiles[0].packages,
"nvidia-utils egl-wayland nvidia-settings opencl-nvidia lib32-opencl-nvidia \
lib32-nvidia-utils libva-nvidia-driver vulkan-icd-loader lib32-vulkan-icd-loader"
);
assert_eq!(
parsed_profiles[0].device_name_pattern,
Some("((GM|GP)+[0-9]+[^M]*\\s.*)".to_owned())
);
assert_eq!(parsed_profiles[0].hwd_product_name_pattern, None);
assert_eq!(parsed_profiles[0].hwd_ids, hwd_ids);
assert_eq!(parsed_profiles[0].gc_versions, None);
assert!(!parsed_profiles[0].post_install.is_empty());
assert!(!parsed_profiles[0].post_remove.is_empty());
}

#[test]
fn graphics_profiles_invalid() {
let prof_path = "tests/profiles/graphic_drivers-invalid-profiles-test.toml";
Expand Down Expand Up @@ -539,19 +593,23 @@ mod tests {
assert!(!crate::profile::remove_profile_from_file(&filepath, &parsed_profiles[0].name));
assert!(!crate::profile::remove_profile_from_file(&filepath, &parsed_profiles[1].name));

// clean this up
assert!(crate::profile::remove_profile_from_file(&filepath, &parsed_profiles[2].name));

// insert same profiles again
assert!(crate::profile::write_profile_to_file(&filepath, &parsed_profiles[0]));
assert!(crate::profile::write_profile_to_file(&filepath, &parsed_profiles[1]));

// insert same profiles again
assert!(crate::profile::write_profile_to_file(&filepath, &parsed_profiles[0]));
assert!(crate::profile::write_profile_to_file(&filepath, &parsed_profiles[1]));
assert!(!crate::profile::write_profile_to_file(&filepath, &parsed_profiles[0]));
assert!(!crate::profile::write_profile_to_file(&filepath, &parsed_profiles[1]));

let orig_content = fs::read_to_string(&filepath).unwrap();
let expected_output = fs::read_to_string(prof_parsed_path).unwrap();

// cleanup
assert!(fs::remove_file(&filepath).is_ok());

assert_eq!(orig_content, fs::read_to_string(prof_parsed_path).unwrap());
assert_eq!(orig_content, expected_output);
}
}
24 changes: 24 additions & 0 deletions tests/profiles/extra-check-root-profile.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[nvidia-dkms]
ai_sdk = false
class_ids = "0300 0302 0380"
desc = "Closed source NVIDIA drivers for Linux (Latest)"
device_ids = "*"
device_name_pattern = '((GM|GP)+[0-9]+[^M]*\s.*)'
packages = "nvidia-utils egl-wayland nvidia-settings opencl-nvidia lib32-opencl-nvidia lib32-nvidia-utils libva-nvidia-driver vulkan-icd-loader lib32-vulkan-icd-loader"
post_install = '''
cat <<EOF >/etc/mkinitcpio.conf.d/10-chwd.conf
# This file is automatically generated by chwd. PLEASE DO NOT EDIT IT.
MODULES+=(nvidia nvidia_modeset nvidia_uvm nvidia_drm)
EOF
mkinitcpio -P
# Add libva-nvidia-driver to profile
echo "export LIBVA_DRIVER_NAME=nvidia" > /etc/profile.d/nvidia-vaapi.sh
'''
post_remove = """
rm -f /etc/mkinitcpio.conf.d/10-chwd.conf
rm -f /etc/profile.d/nvidia-vaapi.sh
mkinitcpio -P
"""
priority = 12
vendor_ids = "10de"

0 comments on commit 2729c20

Please sign in to comment.