Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to list connected gpu(s) #140

Merged
merged 8 commits into from
Mar 20, 2023
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ build = "build.rs"
cfg-if = "1.0.0"
libc = "0.2.131"
home = "0.5.3"
pciid-parser = "0.6.2"

[build-dependencies.vergen]
version = "7.3.2"
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cfg_if! {
pub type ProductReadout = linux::LinuxProductReadout;
pub type PackageReadout = linux::LinuxPackageReadout;
pub type NetworkReadout = linux::LinuxNetworkReadout;
pub type GpuReadout = linux::LinuxGpuReadout;
Rolv-Apneseth marked this conversation as resolved.
Show resolved Hide resolved
} else if #[cfg(target_os = "macos")] {
mod extra;
mod macos;
Expand Down Expand Up @@ -94,6 +95,7 @@ pub struct Readouts {
pub product: ProductReadout,
pub packages: PackageReadout,
pub network: PackageReadout,
pub gpu: GpuReadout,
}

#[cfg(feature = "version")]
Expand Down
39 changes: 32 additions & 7 deletions src/linux/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#![allow(clippy::unnecessary_cast)]
mod pci_devices;
mod sysinfo_ffi;

use self::pci_devices::get_pci_devices;
use crate::extra;
use crate::extra::get_entries;
use crate::extra::path_extension;
use crate::shared;
use crate::traits::*;
use itertools::Itertools;
use pciid_parser::Database;
use regex::Regex;
use std::fs;
use std::fs::read_dir;
Expand Down Expand Up @@ -41,6 +44,7 @@ pub struct LinuxBatteryReadout;
pub struct LinuxProductReadout;
pub struct LinuxPackageReadout;
pub struct LinuxNetworkReadout;
pub struct LinuxGpuReadout;

impl BatteryReadout for LinuxBatteryReadout {
fn new() -> Self {
Expand Down Expand Up @@ -530,11 +534,7 @@ impl GeneralReadout for LinuxGeneralReadout {
if family == product && family == version {
return Ok(family);
} else if version.is_empty() || version.len() <= 22 {
return Ok(new_product
.split_whitespace()
.into_iter()
.unique()
.join(" "));
return Ok(new_product.split_whitespace().unique().join(" "));
}

Ok(version)
Expand Down Expand Up @@ -784,7 +784,6 @@ impl LinuxPackageReadout {
entries
.iter()
.filter(|x| extra::path_extension(x).unwrap_or_default() == "list")
.into_iter()
.count()
})
}
Expand Down Expand Up @@ -893,11 +892,37 @@ impl LinuxPackageReadout {
entries
.iter()
.filter(|&x| path_extension(x).unwrap_or_default() == "snap")
.into_iter()
.count(),
);
}

None
}
}

impl GpuReadout for LinuxGpuReadout {
fn new() -> Self {
LinuxGpuReadout
}

fn list_gpus(&self) -> Result<Vec<String>, ReadoutError> {
let db = match Database::read() {
Ok(db) => db,
_ => panic!("Could not read pci.ids file"),
};
let devices_path = Path::new("/sys/bus/pci/devices/");

let devices = get_pci_devices(devices_path)?;
let mut gpus = vec![];

for device in devices {
if !device.is_gpu(&db) {
continue;
};

gpus.push(device.get_sub_device_name(&db));
}

Ok(gpus)
}
}
112 changes: 112 additions & 0 deletions src/linux/pci_devices.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use std::{
fs::{read_dir, read_to_string},
io,
path::{Path, PathBuf},
};

use pciid_parser::{schema::SubDeviceId, Database};

use crate::extra::pop_newline;

fn parse_device_hex(hex_str: &str) -> String {
pop_newline(hex_str).chars().skip(2).collect::<String>()
}

pub enum PciDeviceReadableValues {
Class,
Vendor,
Device,
SubVendor,
SubDevice,
}

impl PciDeviceReadableValues {
fn as_str(&self) -> &'static str {
match self {
PciDeviceReadableValues::Class => "class",
PciDeviceReadableValues::Vendor => "vendor",
PciDeviceReadableValues::Device => "device",
PciDeviceReadableValues::SubVendor => "subsystem_vendor",
PciDeviceReadableValues::SubDevice => "subsystem_device",
}
}
}

#[derive(Debug)]
pub struct PciDevice {
base_path: PathBuf,
}

impl PciDevice {
fn new(base_path: PathBuf) -> PciDevice {
PciDevice { base_path }
}

fn _read_value(&self, readable_value: PciDeviceReadableValues) -> String {
let value_path = self.base_path.join(readable_value.as_str());

match read_to_string(&value_path) {
Ok(hex_string) => parse_device_hex(&hex_string),
_ => panic!("Could not find value: {:?}", value_path),
}
}

pub fn is_gpu(&self, db: &Database) -> bool {
let class_value = self._read_value(PciDeviceReadableValues::Class);
let first_pair = class_value.chars().take(2).collect::<String>();

match db.classes.get(&first_pair) {
Some(class) => class.name == "Display controller",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is a fragile and expensive comparison. Can we instead match the classes based on their hexadecimal value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would I do that? db.classes appears to be a hash map of String to Class I think

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I wasn't aware. I'll have to read up on the documentation of this new library. It might take me some time to really digest the information and provide a decent review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a look at the other library I linked originally as it may be better anyway since it bundles the pci.ids

Copy link
Member

@grtcdr grtcdr Mar 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't.

A bundled database means that it could fall out of date at any point in time and that'll affect the information we provide to the caller.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok makes sense. Any suggestions on what I could try then since the hashmap is built using strings as keys? Would you prefer I tried making a PR for that library to change it before we use it here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really mind it, but I'd be interetested to know why the author opted for the corresponding value rather than the short hexadecimal values assigned to them.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just used hyperfine and reading the GPU is taking quite a while.

Without showing GPU:

Time (mean ± σ):       4.2 ms ±   0.1 ms    [User: 3.0 ms, System: 1.2 ms]
Range (min … max):     4.0 ms …   4.7 ms    500 runs

With:

Time (mean ± σ):      17.8 ms ±   0.3 ms    [User: 14.6 ms, System: 3.1 ms]
Range (min … max):    17.0 ms …  19.0 ms    500 runs

Is this acceptable as is?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's reasonable and expected; it's parsing a 36000-line database after all.

We'll have to implement a caching mechanism at some point, whether that goes in libmacchina or macchina is another story, but I assume it's the client or consumer, generally speaking, that needs to cache the output and not the library.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright cool. So then is there anything else you would like from this PR? I can have a look at making a PR for that library to use maybe u8 for the keys if you think that would help? Also are there any changes to documentation / testing I should make here

_ => false,
}
}

pub fn get_sub_device_name(&self, db: &Database) -> String {
let vendor_value = self._read_value(PciDeviceReadableValues::Vendor);
let sub_vendor_value = self._read_value(PciDeviceReadableValues::SubVendor);
let device_value = self._read_value(PciDeviceReadableValues::Device);
let sub_device_value = self._read_value(PciDeviceReadableValues::SubDevice);

let vendor = match db.vendors.get(&vendor_value) {
Some(vendor) => vendor,
_ => panic!("Could not find vendor for value: {}", vendor_value),
};

let device = match vendor.devices.get(&device_value) {
Some(device) => device,
_ => panic!("Could not find device for value: {}", device_value),
};

let sub_device_id = SubDeviceId {
subvendor: sub_vendor_value,
subdevice: sub_device_value,
};

match device.subdevices.get(&sub_device_id) {
Some(sub_device) => {
let start = match sub_device.find('[') {
Some(i) => i + 1,
_ => panic!(
"Could not find opening square bracket for sub device: {}",
sub_device
),
};
let end = sub_device.len() - 1;

sub_device.chars().take(end).skip(start).collect::<String>()
}
_ => panic!("Could not find sub device for id: {:?}", sub_device_id),
}
}
}

pub fn get_pci_devices(devices_path: &Path) -> Result<Vec<PciDevice>, io::Error> {
Rolv-Apneseth marked this conversation as resolved.
Show resolved Hide resolved
let devices_dir = read_dir(devices_path)?;

let mut devices = vec![];
for device_entry in devices_dir.flatten() {
devices.push(PciDevice::new(device_entry.path()));
}

Ok(devices)
}
32 changes: 32 additions & 0 deletions src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,38 @@ pub trait ProductReadout {
fn product(&self) -> Result<String, ReadoutError>;
}

/**
This trait provides the interface for getting information about the _GPU(s)_ connected
to the host machine.

# Example

```
use libmacchina::traits::GpuReadout;
use libmacchina::traits::ReadoutError;

pub struct LinuxGpuReadout;

impl GpuReadout for LinuxGpuReadout {
fn new() -> Self {
LinuxGpuReadout
}

fn list_gpus(&self) -> Result<Vec<String>, ReadoutError> {
// Get gpu(s) from list of connected pci devices
Ok(vec!(String::from("gpu1"), String::from("gpu2"))) // Return gpu sub-device names
}
}
```
*/
pub trait GpuReadout {
/// Creates a new instance of the structure which implements this trait.
fn new() -> Self;

/// This function is used for querying the currently connected gpu devices
fn list_gpus(&self) -> Result<Vec<String>, ReadoutError>;
}

/**
This trait provides the interface for implementing functionality used for querying general
information about the running operating system and current user.
Expand Down