diff --git a/src/tracing.rs b/src/tracing.rs index 5acae8a850..f06460d240 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -1,8 +1,11 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{ + ffi::CStr, + sync::atomic::{AtomicPtr, Ordering}, +}; -use libc::c_char; +use libc::{c_char, c_int}; -use crate::{panic, raw, util::Binding}; +use crate::{panic, raw, util::Binding, Error}; /// Available tracing levels. When tracing is set to a particular level, /// callers will be provided tracing at the given level and all lower levels. @@ -57,29 +60,82 @@ impl Binding for TraceLevel { } } -//TODO: pass raw &[u8] and leave conversion to consumer (breaking API) /// Callback type used to pass tracing events to the subscriber. /// see `trace_set` to register a subscriber. -pub type TracingCb = fn(TraceLevel, &str); +pub type TracingCb = fn(TraceLevel, &[u8]); -static CALLBACK: AtomicUsize = AtomicUsize::new(0); +/// Use an atomic pointer to store the global tracing subscriber function. +static CALLBACK: AtomicPtr<()> = AtomicPtr::new(std::ptr::null_mut()); -/// -pub fn trace_set(level: TraceLevel, cb: TracingCb) -> bool { - CALLBACK.store(cb as usize, Ordering::SeqCst); +/// Set the global subscriber called when libgit2 produces a tracing message. +pub fn trace_set(level: TraceLevel, cb: TracingCb) -> Result<(), Error> { + // Store the callback in the global atomic. + CALLBACK.store(cb as *mut (), Ordering::SeqCst); - unsafe { - raw::git_trace_set(level.raw(), Some(tracing_cb_c)); - } + // git_trace_set returns 0 if there was no error. + let return_code: c_int = unsafe { raw::git_trace_set(level.raw(), Some(tracing_cb_c)) }; - return true; + if return_code != 0 { + // Unwrap here is fine since `Error::last_error` always returns `Some`. + Err(Error::last_error(return_code).unwrap()) + } else { + Ok(()) + } } +/// The tracing callback we pass to libgit2 (C ABI compatible). extern "C" fn tracing_cb_c(level: raw::git_trace_level_t, msg: *const c_char) { - let cb = CALLBACK.load(Ordering::SeqCst); - panic::wrap(|| unsafe { - let cb: TracingCb = std::mem::transmute(cb); - let msg = std::ffi::CStr::from_ptr(msg).to_string_lossy(); - cb(Binding::from_raw(level), msg.as_ref()); + // Load the callback function pointer from the global atomic. + let cb: *mut () = CALLBACK.load(Ordering::SeqCst); + + // Transmute the callback pointer into the function pointer we know it to be. + // + // SAFETY: We only ever set the callback pointer with something cast from a TracingCb + // so transmuting back to a TracingCb is safe. This is notably not an integer-to-pointer + // transmute as described in the mem::transmute documentation and is in-line with the + // example in that documentation for casing between *const () to fn pointers. + let cb: TracingCb = unsafe { std::mem::transmute(cb) }; + + // If libgit2 passes us a message that is null, drop it and do not pass it to the callback. + // This is to avoid ever exposing rust code to a null ref, which would be Undefined Behavior. + if msg.is_null() { + return; + } + + // Convert the message from a *const c_char to a &[u8] and pass it to the callback. + // + // SAFETY: We've just checked that the pointer is not null. The other safety requirements are left to + // libgit2 to enforce -- namely that it gives us a valid, nul-terminated, C string, that that string exists + // entirely in one allocation, that the string will not be mutated once passed to us, and that the nul-terminator is + // within isize::MAX bytes from the given pointers data address. + let msg: &CStr = unsafe { CStr::from_ptr(msg) }; + + // Convert from a CStr to &[u8] to pass to the rust code callback. + let msg: &[u8] = CStr::to_bytes(msg); + + // Do the remaining part of this function in a panic wrapper, to catch any panics it produces. + panic::wrap(|| { + // Convert the raw trace level into a type we can pass to the rust callback fn. + // + // SAFETY: Currently the implementation of this function (above) may panic, but is only marked as unsafe to match + // the trait definition, thus we can consider this call safe. + let level: TraceLevel = unsafe { Binding::from_raw(level) }; + + // Call the user-supplied callback (which may panic). + (cb)(level, msg); }); } + +#[cfg(test)] +mod tests { + use super::TraceLevel; + + // Test that using the above function to set a tracing callback doesn't panic. + #[test] + fn smoke() { + super::trace_set(TraceLevel::Trace, |level, msg| { + dbg!(level, msg); + }) + .expect("libgit2 can set global trace callback"); + } +}