From 8ce70c50953cf8c99d3a256dfaf23f0e370c25f6 Mon Sep 17 00:00:00 2001 From: Violet Date: Tue, 17 Jun 2025 15:00:10 -0700 Subject: [PATCH] Add `integrity_check` implementation to ZLUDA (#387) --- zluda/src/impl/driver.rs | 72 +++++++++++++++++++++++++++++++++++++-- zluda/src/impl/os_unix.rs | 10 ++++++ zluda/src/impl/os_win.rs | 10 ++++++ 3 files changed, 90 insertions(+), 2 deletions(-) diff --git a/zluda/src/impl/driver.rs b/zluda/src/impl/driver.rs index c413f82..662be4e 100644 --- a/zluda/src/impl/driver.rs +++ b/zluda/src/impl/driver.rs @@ -10,6 +10,10 @@ use std::{ usize, }; +#[cfg_attr(windows, path = "os_win.rs")] +#[cfg_attr(not(windows), path = "os_unix.rs")] +mod os; + pub(crate) struct GlobalState { pub devices: Vec, pub comgr: Comgr, @@ -232,7 +236,31 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { unix_seconds: u64, result: *mut [u64; 2], ) -> cuda_types::cuda::CUresult { - todo!() + let current_process = std::process::id(); + let current_thread = os::current_thread(); + + let integrity_check_table = EXPORT_TABLE.INTEGRITY_CHECK.as_ptr().cast(); + let cudart_table = EXPORT_TABLE.CUDART_INTERFACE.as_ptr().cast(); + let fn_address = EXPORT_TABLE.INTEGRITY_CHECK[1]; + + let devices = get_device_hash_info()?; + let device_count = devices.len() as u32; + let get_device = |dev| devices[dev as usize]; + + let hash = ::dark_api::integrity_check( + version, + unix_seconds, + cuda_types::cuda::CUDA_VERSION, + current_process, + current_thread, + integrity_check_table, + cudart_table, + fn_address, + device_count, + get_device, + ); + *result = hash; + Ok(()) } unsafe extern "system" fn context_check( @@ -244,10 +272,50 @@ impl ::dark_api::cuda::CudaDarkApi for DarkApi { } unsafe extern "system" fn check_fn3() -> u32 { - todo!() + 0 } } +fn get_device_hash_info() -> Result, CUerror> { + let mut device_count = 0; + device::get_count(&mut device_count)?; + + (0..device_count) + .map(|dev| { + let mut guid = CUuuid_st { bytes: [0; 16] }; + unsafe { crate::cuDeviceGetUuid(&mut guid, dev)? }; + + let mut pci_domain = 0; + device::get_attribute( + &mut pci_domain, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, + dev, + )?; + + let mut pci_bus = 0; + device::get_attribute( + &mut pci_bus, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, + dev, + )?; + + let mut pci_device = 0; + device::get_attribute( + &mut pci_device, + CUdevice_attribute::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, + dev, + )?; + + Ok(::dark_api::DeviceHashinfo { + guid, + pci_domain, + pci_bus, + pci_device, + }) + }) + .collect() +} + static EXPORT_TABLE: ::dark_api::cuda::CudaDarkApiGlobalTable = ::dark_api::cuda::CudaDarkApiGlobalTable::new::(); diff --git a/zluda/src/impl/os_unix.rs b/zluda/src/impl/os_unix.rs index 90b7011..1edeada 100644 --- a/zluda/src/impl/os_unix.rs +++ b/zluda/src/impl/os_unix.rs @@ -11,3 +11,13 @@ pub unsafe fn heap_alloc(_heap: *mut c_void, _bytes: usize) -> *mut c_void { pub unsafe fn heap_free(_heap: *mut c_void, _alloc: *mut c_void) { todo!() } + +// TODO: remove duplication with zluda_dump +#[link(name = "pthread")] +unsafe extern "C" { + fn pthread_self() -> std::os::unix::thread::RawPthread; +} + +pub(crate) fn current_thread() -> u32 { + (unsafe { pthread_self() }) as u32 +} diff --git a/zluda/src/impl/os_win.rs b/zluda/src/impl/os_win.rs index 427920e..84d46ed 100644 --- a/zluda/src/impl/os_win.rs +++ b/zluda/src/impl/os_win.rs @@ -14,3 +14,13 @@ pub unsafe fn heap_alloc(heap: *mut c_void, bytes: usize) -> *mut c_void { pub unsafe fn heap_free(heap: *mut c_void, alloc: *mut c_void) { HeapFree(heap, 0, alloc); } + +// TODO: remove duplication with zluda_dump +#[link(name = "kernel32")] +unsafe extern "system" { + fn GetCurrentThreadId() -> u32; +} + +pub(crate) fn current_thread() -> u32 { + unsafe { GetCurrentThreadId() } +}