From 178bd567b8b82921dc2cc1c48d053181af85f3b6 Mon Sep 17 00:00:00 2001 From: Ahmed Charles Date: Thu, 12 Oct 2023 13:58:46 -0700 Subject: [PATCH 1/2] Use const generics. --- examples/gravity.rs | 31 +++--- src/kalman.rs | 86 ++++++++-------- src/matrix.rs | 240 ++++++++++++++------------------------------ src/matrix_owned.rs | 1 + src/measurement.rs | 74 +++++++------- 5 files changed, 180 insertions(+), 252 deletions(-) diff --git a/examples/gravity.rs b/examples/gravity.rs index 7937857..2ca1929 100644 --- a/examples/gravity.rs +++ b/examples/gravity.rs @@ -14,7 +14,6 @@ use minikalman::{ create_buffer_temp_S_inv, create_buffer_temp_x, create_buffer_u, create_buffer_x, create_buffer_y, create_buffer_z, matrix_data_t, Kalman, Measurement, }; -use stdint::uint_fast8_t; /// Measurements. /// @@ -39,12 +38,12 @@ const MEASUREMENT_ERROR: [matrix_data_t; 15] = [ -0.33747, 0.75873, 0.18135, -0.015764, 0.17869, ]; +const NUM_STATES: usize = 3; +const NUM_INPUTS: usize = 0; +const NUM_MEASUREMENTS: usize = 1; + #[allow(non_snake_case)] fn main() { - const NUM_STATES: uint_fast8_t = 3; - const NUM_INPUTS: uint_fast8_t = 0; - const NUM_MEASUREMENTS: uint_fast8_t = 1; - // System buffers. let mut gravity_x = create_buffer_x!(NUM_STATES); let mut gravity_A = create_buffer_A!(NUM_STATES); @@ -75,8 +74,8 @@ fn main() { let mut gravity_temp_KHP = create_buffer_temp_KHP!(NUM_STATES); let mut filter = Kalman::new_from_buffers( - NUM_STATES, - NUM_INPUTS, + NUM_STATES as _, + NUM_INPUTS as _, &mut gravity_A, &mut gravity_x, &mut gravity_B, @@ -89,8 +88,8 @@ fn main() { ); let mut measurement = Measurement::new_direct( - NUM_STATES, - NUM_MEASUREMENTS, + NUM_STATES as _, + NUM_MEASUREMENTS as _, &mut gravity_H, &mut gravity_z, &mut gravity_R, @@ -132,7 +131,7 @@ fn main() { } /// Initializes the state vector with initial assumptions. -fn initialize_state_vector(filter: &mut Kalman) { +fn initialize_state_vector(filter: &mut Kalman<'_, NUM_STATES, NUM_INPUTS>) { filter.state_vector_apply(|state| { state[0] = 0 as _; // position state[1] = 0 as _; // velocity @@ -148,7 +147,7 @@ fn initialize_state_vector(filter: &mut Kalman) { /// v₁ = 1×v₀ + T×a₀ /// a₁ = 1×a₀ /// ``` -fn initialize_state_transition_matrix(filter: &mut Kalman) { +fn initialize_state_transition_matrix(filter: &mut Kalman<'_, NUM_STATES, NUM_INPUTS>) { filter.state_transition_apply(|a| { // Time constant. const T: matrix_data_t = 1 as _; @@ -175,7 +174,7 @@ fn initialize_state_transition_matrix(filter: &mut Kalman) { /// This defines how different states (linearly) influence each other /// over time. In this setup we claim that position, velocity and acceleration /// linearly are linearly independent. -fn initialize_state_covariance_matrix(filter: &mut Kalman) { +fn initialize_state_covariance_matrix(filter: &mut Kalman<'_, NUM_STATES, NUM_INPUTS>) { filter.system_covariance_apply(|p| { p.set(0, 0, 0.1 as _); // var(s) p.set(0, 1, 0 as _); // cov(s, v) @@ -195,7 +194,9 @@ fn initialize_state_covariance_matrix(filter: &mut Kalman) { /// ```math /// z = 1×s + 0×v + 0×a /// ``` -fn initialize_position_measurement_transformation_matrix(measurement: &mut Measurement) { +fn initialize_position_measurement_transformation_matrix( + measurement: &mut Measurement<'_, NUM_STATES, NUM_MEASUREMENTS>, +) { measurement.measurement_transformation_apply(|h| { h.set(0, 0, 1 as _); // z = 1*s h.set(0, 1, 0 as _); // + 0*v @@ -208,7 +209,9 @@ fn initialize_position_measurement_transformation_matrix(measurement: &mut Measu /// This matrix describes the measurement covariances as well as the /// individual variation components. It is the measurement counterpart /// of the state covariance matrix. -fn initialize_position_measurement_process_noise_matrix(measurement: &mut Measurement) { +fn initialize_position_measurement_process_noise_matrix( + measurement: &mut Measurement<'_, NUM_STATES, NUM_MEASUREMENTS>, +) { measurement.process_noise_apply(|r| { r.set(0, 0, 0.5 as _); // var(s) }); diff --git a/src/kalman.rs b/src/kalman.rs index e7ebd96..3b5f3fd 100644 --- a/src/kalman.rs +++ b/src/kalman.rs @@ -4,51 +4,51 @@ use stdint::uint_fast8_t; /// Kalman Filter structure. #[allow(non_snake_case, unused)] -pub struct Kalman<'a> { +pub struct Kalman<'a, const STATES: usize, const INPUTS: usize> { /// The number of states. num_states: uint_fast8_t, /// The number of inputs. num_inputs: uint_fast8_t, /// State vector. - x: Matrix<'a>, + x: Matrix<'a, STATES, 1>, /// System matrix. /// /// See also [`P`]. - A: Matrix<'a>, + A: Matrix<'a, STATES, STATES>, /// System covariance matrix. /// /// See also [`A`]. - P: Matrix<'a>, + P: Matrix<'a, STATES, STATES>, /// Input vector. - u: Matrix<'a>, + u: Matrix<'a, INPUTS, 1>, /// Input matrix. /// /// See also [`Q`]. - B: Matrix<'a>, + B: Matrix<'a, STATES, INPUTS>, /// Input covariance matrix. /// /// See also [`B`]. - Q: Matrix<'a>, + Q: Matrix<'a, INPUTS, INPUTS>, /// Temporary storage. - temporary: KalmanTemporary<'a>, + temporary: KalmanTemporary<'a, STATES, INPUTS>, } #[allow(non_snake_case)] -struct KalmanTemporary<'a> { +struct KalmanTemporary<'a, const STATES: usize, const INPUTS: usize> { /// x-sized temporary vector. - predicted_x: Matrix<'a>, + predicted_x: Matrix<'a, STATES, 1>, /// P-Sized temporary matrix (number of states × number of states). /// /// The backing field for this temporary MAY be aliased with temporary BQ. - P: Matrix<'a>, + P: Matrix<'a, STATES, STATES>, /// B×Q-sized temporary matrix (number of states × number of inputs). /// /// The backing field for this temporary MAY be aliased with temporary P. - BQ: Matrix<'a>, + BQ: Matrix<'a, STATES, INPUTS>, } -impl<'a> Kalman<'a> { +impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { /// Initializes a Kalman filter instance. /// /// ## Arguments @@ -78,6 +78,8 @@ impl<'a> Kalman<'a> { temp_P: &'a mut [matrix_data_t], temp_BQ: &'a mut [matrix_data_t], ) -> Self { + debug_assert_eq!(STATES, num_states.into()); + debug_assert_eq!(INPUTS, num_inputs.into()); Self { num_states, num_inputs, @@ -113,16 +115,18 @@ impl<'a> Kalman<'a> { pub fn new( num_states: uint_fast8_t, num_inputs: uint_fast8_t, - A: Matrix<'a>, - x: Matrix<'a>, - B: Matrix<'a>, - u: Matrix<'a>, - P: Matrix<'a>, - Q: Matrix<'a>, - predictedX: Matrix<'a>, - temp_P: Matrix<'a>, - temp_BQ: Matrix<'a>, + A: Matrix<'a, STATES, STATES>, + x: Matrix<'a, STATES, 1>, + B: Matrix<'a, STATES, INPUTS>, + u: Matrix<'a, INPUTS, 1>, + P: Matrix<'a, STATES, STATES>, + Q: Matrix<'a, INPUTS, INPUTS>, + predictedX: Matrix<'a, STATES, 1>, + temp_P: Matrix<'a, STATES, STATES>, + temp_BQ: Matrix<'a, STATES, INPUTS>, ) -> Self { + debug_assert_eq!(STATES, num_states.into()); + debug_assert_eq!(INPUTS, num_inputs.into()); debug_assert_eq!( A.rows, num_states, "The state transition matrix A requires {} rows and {} columns (i.e. states × states)", @@ -241,14 +245,14 @@ impl<'a> Kalman<'a> { /// Gets a reference to the state vector x. #[inline(always)] - pub fn state_vector_ref(&self) -> &Matrix { + pub fn state_vector_ref(&self) -> &Matrix<'_, STATES, 1> { &self.x } /// Gets a reference to the state vector x. #[inline(always)] #[doc(alias = "kalman_get_state_vector")] - pub fn state_vector_mut<'b: 'a>(&'b mut self) -> &'b mut Matrix<'a> { + pub fn state_vector_mut<'b: 'a>(&'b mut self) -> &'b mut Matrix<'a, STATES, 1> { &mut self.x } @@ -256,21 +260,21 @@ impl<'a> Kalman<'a> { #[inline(always)] pub fn state_vector_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, STATES, 1>) -> (), { f(&mut self.x) } /// Gets a reference to the state transition matrix A. #[inline(always)] - pub fn state_transition_ref(&self) -> &Matrix { + pub fn state_transition_ref(&self) -> &Matrix<'_, STATES, STATES> { &self.A } /// Gets a reference to the state transition matrix A. #[inline(always)] #[doc(alias = "kalman_get_state_transition")] - pub fn state_transition_mut(&'a mut self) -> &mut Matrix { + pub fn state_transition_mut(&'a mut self) -> &mut Matrix<'_, STATES, STATES> { &mut self.A } @@ -278,21 +282,21 @@ impl<'a> Kalman<'a> { #[inline(always)] pub fn state_transition_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, STATES, STATES>) -> (), { f(&mut self.A) } /// Gets a reference to the system covariance matrix P. #[inline(always)] - pub fn system_covariance_ref(&self) -> &Matrix { + pub fn system_covariance_ref(&self) -> &Matrix<'_, STATES, STATES> { &self.P } /// Gets a mutable reference to the system covariance matrix P. #[inline(always)] #[doc(alias = "kalman_get_system_covariance")] - pub fn system_covariance_mut(&'a mut self) -> &'a mut Matrix { + pub fn system_covariance_mut(&'a mut self) -> &'a mut Matrix<'_, STATES, STATES> { &mut self.P } @@ -300,21 +304,21 @@ impl<'a> Kalman<'a> { #[inline(always)] pub fn system_covariance_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, STATES, STATES>) -> (), { f(&mut self.P) } /// Gets a reference to the input vector u. #[inline(always)] - pub fn input_vector_ref(&self) -> &Matrix { + pub fn input_vector_ref(&self) -> &Matrix<'_, INPUTS, 1> { &self.u } /// Gets a mutable reference to the input vector u. #[inline(always)] #[doc(alias = "kalman_get_input_vector")] - pub fn input_vector_mut(&'a mut self) -> &'a mut Matrix { + pub fn input_vector_mut(&'a mut self) -> &'a mut Matrix<'_, INPUTS, 1> { &mut self.u } @@ -322,21 +326,21 @@ impl<'a> Kalman<'a> { #[inline(always)] pub fn input_vector_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, INPUTS, 1>) -> (), { f(&mut self.u) } /// Gets a reference to the input transition matrix B. #[inline(always)] - pub fn input_transition_ref(&self) -> &Matrix { + pub fn input_transition_ref(&self) -> &Matrix<'a, STATES, INPUTS> { &self.B } /// Gets a mutable reference to the input transition matrix B. #[inline(always)] #[doc(alias = "kalman_get_input_transition")] - pub fn input_transition_mut(&'a mut self) -> &'a mut Matrix { + pub fn input_transition_mut(&'a mut self) -> &'a mut Matrix<'_, STATES, INPUTS> { &mut self.B } @@ -344,21 +348,21 @@ impl<'a> Kalman<'a> { #[inline(always)] pub fn input_transition_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, STATES, INPUTS>) -> (), { f(&mut self.B) } /// Gets a reference to the input covariance matrix Q. #[inline(always)] - pub fn input_covariance_ref(&self) -> &Matrix { + pub fn input_covariance_ref(&self) -> &Matrix<'_, INPUTS, INPUTS> { &self.Q } /// Gets a mutable reference to the input covariance matrix Q. #[inline(always)] #[doc(alias = "kalman_get_input_covariance")] - pub fn input_covariance_mut(&'a mut self) -> &'a mut Matrix { + pub fn input_covariance_mut(&'a mut self) -> &'a mut Matrix<'_, INPUTS, INPUTS> { &mut self.Q } @@ -367,7 +371,7 @@ impl<'a> Kalman<'a> { #[doc(alias = "kalman_get_input_covariance")] pub fn input_covariance_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, INPUTS, INPUTS>) -> (), { f(&mut self.Q) } @@ -489,7 +493,7 @@ impl<'a> Kalman<'a> { /// * `kfm` - The measurement. #[allow(non_snake_case)] #[doc(alias = "kalman_predict_Q")] - pub fn correct(&mut self, kfm: &mut Measurement<'a>) { + pub fn correct(&mut self, kfm: &mut Measurement<'a, STATES, M>) { // matrices and vectors let P = &mut self.P; let x = &mut self.x; diff --git a/src/matrix.rs b/src/matrix.rs index 5ecb1b7..009ce7f 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -13,7 +13,7 @@ use micromath::F32Ext; pub type matrix_data_t = f32; /// A matrix wrapping a data buffer. -pub struct Matrix<'a> { +pub struct Matrix<'a, const R: usize, const C: usize> { pub rows: uint_fast8_t, pub cols: uint_fast8_t, pub data: &'a mut [matrix_data_t], @@ -26,7 +26,7 @@ macro_rules! idx { }; } -impl<'a> Matrix<'a> { +impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// Initializes a matrix structure. /// /// ## Arguments @@ -42,6 +42,8 @@ impl<'a> Matrix<'a> { cols, rows * cols ); + debug_assert_eq!(R, rows.into()); + debug_assert_eq!(C, cols.into()); Self { rows, cols, @@ -84,6 +86,8 @@ impl<'a> Matrix<'a> { cols, rows * cols ); + debug_assert_eq!(R, rows.into()); + debug_assert_eq!(C, cols.into()); Self { rows, cols, @@ -100,7 +104,9 @@ impl<'a> Matrix<'a> { pub const fn is_empty(&self) -> bool { self.len() == 0 } +} +impl<'a, const N: usize> Matrix<'a, N, N> { /// Inverts a square lower triangular matrix. Meant to be used with /// [`Matrix::cholesky_decompose_lower`]. /// @@ -124,11 +130,11 @@ impl<'a> Matrix<'a> { /// 1.0, 0.5, 0.0, /// 0.5, 1.0, 0.0, /// 0.0, 0.0, 1.0]; - /// let mut m = Matrix::new(3, 3, &mut d); + /// let mut m = Matrix::<3, 3>::new(3, 3, &mut d); /// /// // data buffer for the inverted matrix /// let mut di = [0.0; 3 * 3]; - /// let mut mi = Matrix::new(3, 3, &mut di); + /// let mut mi = Matrix::<3, 3>::new(3, 3, &mut di); /// /// // Decompose matrix to lower triangular. /// m.cholesky_decompose_lower(); @@ -203,7 +209,9 @@ impl<'a> Matrix<'a> { } } } +} +impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// Performs a matrix multiplication such that `C = A * B`. This method /// uses an auxiliary buffer for keeping one row of `B` cached. This might /// improve performance on very wide matrices but is generally slower than @@ -222,16 +230,16 @@ impl<'a> Matrix<'a> { /// let mut a_buf = [ /// 1.0, 2.0, 3.0, /// 4.0, 5.0, 6.0]; - /// let a = Matrix::new(2, 3, &mut a_buf); + /// let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); /// /// let mut b_buf = [ /// 10.0, 11.0, /// 20.0, 21.0, /// 30.0, 31.0]; - /// let b = Matrix::new(3, 2, &mut b_buf); + /// let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); /// /// let mut c_buf = [0f32; 2 * 2]; - /// let mut c = Matrix::new(2, 2, &mut c_buf); + /// let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); /// /// let mut aux = [0f32; 3 * 1]; /// a.mult_buffered(&b, &mut c, &mut aux); @@ -244,7 +252,12 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult_buffered")] - pub fn mult_buffered(&self, b: &Self, c: &mut Self, baux: &mut [matrix_data_t]) { + pub fn mult_buffered( + &self, + b: &Matrix<'_, C, U>, + c: &mut Matrix<'_, R, U>, + baux: &mut [matrix_data_t], + ) { let bcols = b.cols; let ccols = c.cols; let brows = b.rows; @@ -295,16 +308,16 @@ impl<'a> Matrix<'a> { /// let mut a_buf = [ /// 1.0, 2.0, 3.0, /// 4.0, 5.0, 6.0]; - /// let a = Matrix::new(2, 3, &mut a_buf); + /// let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); /// /// let mut b_buf = [ /// 10.0, 11.0, /// 20.0, 21.0, /// 30.0, 31.0]; - /// let b = Matrix::new(3, 2, &mut b_buf); + /// let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); /// /// let mut c_buf = [0f32; 2 * 2]; - /// let mut c = Matrix::new(2, 2, &mut c_buf); + /// let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); /// /// a.mult(&b, &mut c); /// @@ -316,7 +329,7 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult")] - pub fn mult(&self, b: &Self, c: &mut Self) { + pub fn mult(&self, b: &Matrix<'_, C, U>, c: &mut Matrix<'_, R, U>) { let bcols = b.cols; let ccols = c.cols; let brows = b.rows; @@ -355,7 +368,7 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult_rowvector")] - pub fn mult_rowvector(&self, x: &Self, c: &mut Self) { + pub fn mult_rowvector(&self, x: &Matrix<'_, C, 1>, c: &mut Matrix<'_, R, 1>) { let arows = self.rows; let acols = self.cols; @@ -396,7 +409,7 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_multadd_rowvector")] - pub fn multadd_rowvector(&self, x: &Self, c: &mut Self) { + pub fn multadd_rowvector(&self, x: &Matrix<'_, C, 1>, c: &mut Matrix<'_, R, 1>) { let arows = self.rows; let acols = self.cols; @@ -437,7 +450,7 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult_transb")] - pub fn mult_transb(&self, b: &Self, c: &mut Self) { + pub fn mult_transb(&self, b: &Matrix<'_, U, C>, c: &mut Matrix<'_, R, U>) { let bcols = b.cols; let brows = b.rows; let arows = self.rows; @@ -486,7 +499,7 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_multadd_transb")] - pub fn multadd_transb(&self, b: &Self, c: &mut Self) { + pub fn multadd_transb(&self, b: &Matrix<'_, U, C>, c: &mut Matrix<'_, R, U>) { let bcols = b.cols; let brows = b.rows; let arows = self.rows; @@ -536,7 +549,12 @@ impl<'a> Matrix<'a> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_multscale_transb")] - pub fn multscale_transb(&self, b: &Self, scale: matrix_data_t, c: &mut Self) { + pub fn multscale_transb( + &self, + b: &Matrix<'_, U, C>, + scale: matrix_data_t, + c: &mut Matrix<'_, R, U>, + ) { let bcols = b.cols; let brows = b.rows; let arows = self.rows; @@ -822,7 +840,7 @@ impl<'a> Matrix<'a> { /// 0.5, 1.0, 0.0, /// 0.0, 0.0, 1.0]; /// - /// let mut m = Matrix::new(3, 3, &mut d); + /// let mut m = Matrix::<3, 3>::new(3, 3, &mut d); /// /// // Decompose matrix to lower triangular. /// m.cholesky_decompose_lower(); @@ -889,7 +907,7 @@ impl<'a> Matrix<'a> { } } -impl<'a> Index for Matrix<'a> { +impl<'a, const R: usize, const C: usize> Index for Matrix<'a, R, C> { type Output = matrix_data_t; #[inline(always)] @@ -898,26 +916,26 @@ impl<'a> Index for Matrix<'a> { } } -impl<'a> IndexMut for Matrix<'a> { +impl<'a, const R: usize, const C: usize> IndexMut for Matrix<'a, R, C> { #[inline(always)] fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.data[index] } } -impl<'a> AsRef<[matrix_data_t]> for Matrix<'a> { +impl<'a, const R: usize, const C: usize> AsRef<[matrix_data_t]> for Matrix<'a, R, C> { fn as_ref(&self) -> &[matrix_data_t] { &self.data } } -impl<'a> AsMut<[matrix_data_t]> for Matrix<'a> { +impl<'a, const R: usize, const C: usize> AsMut<[matrix_data_t]> for Matrix<'a, R, C> { fn as_mut(&mut self) -> &mut [matrix_data_t] { &mut self.data } } -impl<'a> MatrixBase for Matrix<'a> { +impl<'a, const R: usize, const C: usize> MatrixBase for Matrix<'a, R, C> { fn rows(&self) -> uint_fast8_t { self.rows } @@ -939,110 +957,6 @@ impl<'a> MatrixBase for Matrix<'a> { } } -impl<'a> MatrixOps for Matrix<'a> { - type Target = Matrix<'a>; - - #[inline(always)] - fn invert_l_cholesky(&self, inverse: &mut Self::Target) { - self.invert_l_cholesky(inverse) - } - - #[inline(always)] - fn mult_buffered(&self, b: &Self::Target, c: &mut Self::Target, baux: &mut [matrix_data_t]) { - self.mult_buffered(b, c, baux) - } - - #[inline(always)] - fn mult(&self, b: &Self::Target, c: &mut Self::Target) { - self.mult(b, c) - } - - #[inline(always)] - fn mult_rowvector(&self, x: &Self::Target, c: &mut Self::Target) { - self.mult_rowvector(x, c) - } - - #[inline(always)] - fn multadd_rowvector(&self, x: &Self::Target, c: &mut Self::Target) { - self.multadd_rowvector(x, c) - } - - #[inline(always)] - fn mult_transb(&self, b: &Self::Target, c: &mut Self::Target) { - self.mult_transb(b, c) - } - - #[inline(always)] - fn multadd_transb(&self, b: &Self::Target, c: &mut Self::Target) { - self.multadd_transb(b, c) - } - - #[inline(always)] - fn multscale_transb(&self, b: &Self::Target, scale: matrix_data_t, c: &mut Self::Target) { - self.multscale_transb(b, scale, c) - } - - #[inline(always)] - fn get(&self, row: uint_fast8_t, column: uint_fast8_t) -> matrix_data_t { - self.get(row, column) - } - - #[inline(always)] - fn set(&mut self, row: uint_fast8_t, column: uint_fast8_t, value: matrix_data_t) { - self.set(row, column, value) - } - - #[inline(always)] - fn set_symmetric(&mut self, row: uint_fast8_t, column: uint_fast8_t, value: matrix_data_t) { - self.set_symmetric(row, column, value) - } - - #[inline(always)] - fn get_column_copy(&self, column: uint_fast8_t, col_data: &mut [matrix_data_t]) { - self.get_column_copy(column, col_data) - } - - #[inline(always)] - fn get_row_copy(&self, row: uint_fast8_t, row_data: &mut [matrix_data_t]) { - self.get_row_copy(row, row_data) - } - - #[inline(always)] - fn copy(&self, target: &mut Self::Target) { - self.copy(target) - } - - #[inline(always)] - fn sub(&self, b: &Self::Target, c: &mut Self::Target) { - self.sub(b, c) - } - - #[inline(always)] - fn sub_inplace_a(&mut self, b: &Self::Target) { - self.sub_inplace_a(b) - } - - #[inline(always)] - fn sub_inplace_b(&self, b: &mut Self::Target) { - self.sub_inplace_b(b) - } - - #[inline(always)] - fn add_inplace_a(&mut self, b: &Self::Target) { - self.add_inplace_a(b) - } - - #[inline(always)] - fn add_inplace_b(&self, b: &mut Self::Target) { - self.add_inplace_b(b) - } - - #[inline(always)] - fn cholesky_decompose_lower(&mut self) -> bool { - self.cholesky_decompose_lower() - } -} - #[cfg(test)] mod tests { use crate::matrix::Matrix; @@ -1058,11 +972,11 @@ mod tests { 10.0, 11.0, 20.0, 21.0, 30.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(3, 2, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); let mut aux = [0f32; 3 * 1]; a.mult_buffered(&b, &mut c, &mut aux); @@ -1082,11 +996,11 @@ mod tests { 10.0, 11.0, 20.0, 21.0, 30.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(3, 2, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); a.mult(&b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1104,11 +1018,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); Matrix::mult_transb(&a, &b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1124,19 +1038,19 @@ mod tests { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, -9.0]; - let a = Matrix::new(3, 3, &mut a_buf); + let a = Matrix::<3, 3>::new(3, 3, &mut a_buf); let mut b_buf = [ -4.0, -1.0, 0.0, 2.0, 3.0, 4.0, 5.0, 9.0, -10.0]; - let b = Matrix::new(3, 3, &mut b_buf); + let b = Matrix::<3, 3>::new(3, 3, &mut b_buf); let mut c_buf = [0f32; 3 * 3]; - let mut c = Matrix::new(3, 3, &mut c_buf); + let mut c = Matrix::<3, 3>::new(3, 3, &mut c_buf); let mut d_buf = [0f32; 3 * 3]; - let mut d = Matrix::new(3, 3, &mut d_buf); + let mut d = Matrix::<3, 3>::new(3, 3, &mut d_buf); // Example P = A*P*A' a.mult(&b, &mut c); // temp = A*P @@ -1162,13 +1076,13 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); let mut c_buf = [ 1000., 2000., 3000., 4000.]; - let mut c = Matrix::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); Matrix::multadd_transb(&a, &b, &mut c); assert_f32_near!(c.get(0, 0), 1000. + 1. * 10. + 2. * 20. + 3. * 30.); // 1140 @@ -1186,11 +1100,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); Matrix::multscale_transb(&a, &b, 2.0, &mut c); assert_f32_near!(c_buf[0], 2.0 * (1. * 10. + 2. * 20. + 3. * 30.)); // 280 @@ -1209,11 +1123,11 @@ mod tests { 10.0, 20.0, 30.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(3, 1, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<3, 1>::new(3, 1, &mut b_buf); let mut c_buf = [0f32; 2 * 1]; - let mut c = Matrix::new(2, 1, &mut c_buf); + let mut c = Matrix::<2, 1>::new(2, 1, &mut c_buf); Matrix::mult_rowvector(&a, &b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1230,11 +1144,11 @@ mod tests { 10.0, 20.0, 30.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(3, 1, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<3, 1>::new(3, 1, &mut b_buf); let mut c_buf = [1000., 2000.]; - let mut c = Matrix::new(2, 1, &mut c_buf); + let mut c = Matrix::<2, 1>::new(2, 1, &mut c_buf); Matrix::multadd_rowvector(&a, &b, &mut c); assert_f32_near!(c.get(0, 0), 1000. + 1. * 10. + 2. * 20. + 3. * 30.); // 1140 @@ -1247,7 +1161,7 @@ mod tests { let mut a_buf = [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let a = Matrix::new(2, 3, &mut a_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); let mut a_out = [0.0; 3].as_slice(); a.get_row_pointer(0, &mut a_out); @@ -1266,11 +1180,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); let mut c_buf = [0f32; 2 * 3]; - let mut c = Matrix::new(2, 3, &mut c_buf); + let mut c = Matrix::<2, 3>::new(2, 3, &mut c_buf); Matrix::sub(&a, &b, &mut c); assert_eq!(c_buf, [ @@ -1287,8 +1201,8 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let mut b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let mut b = Matrix::<2, 3>::new(2, 3, &mut b_buf); Matrix::sub_inplace_b(&a, &mut b); assert_eq!(b_buf, [ @@ -1305,8 +1219,8 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::new(2, 3, &mut a_buf); - let mut b = Matrix::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let mut b = Matrix::<2, 3>::new(2, 3, &mut b_buf); Matrix::add_inplace_b(&a, &mut b); assert_eq!(b_buf, [ @@ -1324,7 +1238,7 @@ mod tests { 0.5, 1.0, 0.0, 0.0, 0.0, 1.0]; - let mut m = Matrix::new(3, 3, &mut d); + let mut m = Matrix::<3, 3>::new(3, 3, &mut d); // Decompose matrix to lower triangular. m.cholesky_decompose_lower(); @@ -1361,11 +1275,11 @@ mod tests { 1.0, 0.5, 0.0, 0.5, 1.0, 0.0, 0.0, 0.0, 1.0]; - let mut m = Matrix::new(3, 3, &mut d); + let mut m = Matrix::<3, 3>::new(3, 3, &mut d); // data buffer for the inverted matrix let mut di = [0.0; 3 * 3]; - let mut mi = Matrix::new(3, 3, &mut di); + let mut mi = Matrix::<3, 3>::new(3, 3, &mut di); // Decompose matrix to lower triangular. m.cholesky_decompose_lower(); diff --git a/src/matrix_owned.rs b/src/matrix_owned.rs index e69de29..8b13789 100644 --- a/src/matrix_owned.rs +++ b/src/matrix_owned.rs @@ -0,0 +1 @@ + diff --git a/src/measurement.rs b/src/measurement.rs index ca40984..bb3bb86 100644 --- a/src/measurement.rs +++ b/src/measurement.rs @@ -3,61 +3,61 @@ use stdint::uint_fast8_t; /// Kalman Filter measurement structure. #[allow(non_snake_case, unused)] -pub struct Measurement<'a> { +pub struct Measurement<'a, const STATES: usize, const MEASUREMENTS: usize> { /// The number of states. pub num_states: uint_fast8_t, /// The number of measurements. pub num_measurements: uint_fast8_t, /// Measurement vector. - pub(crate) z: Matrix<'a>, + pub(crate) z: Matrix<'a, MEASUREMENTS, 1>, /// Measurement transformation matrix. /// /// See also [`R`]. - pub(crate) H: Matrix<'a>, + pub(crate) H: Matrix<'a, MEASUREMENTS, STATES>, /// Process noise covariance matrix. /// /// See also [`A`]. - pub(crate) R: Matrix<'a>, + pub(crate) R: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, /// Innovation vector. - pub(crate) y: Matrix<'a>, + pub(crate) y: Matrix<'a, MEASUREMENTS, 1>, /// Residual covariance matrix. - pub(crate) S: Matrix<'a>, + pub(crate) S: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, /// Kalman gain matrix. - pub(crate) K: Matrix<'a>, + pub(crate) K: Matrix<'a, STATES, MEASUREMENTS>, /// Temporary storage. - pub(crate) temporary: MeasurementTemporary<'a>, + pub(crate) temporary: MeasurementTemporary<'a, STATES, MEASUREMENTS>, } #[allow(non_snake_case)] -pub(crate) struct MeasurementTemporary<'a> { +pub(crate) struct MeasurementTemporary<'a, const STATES: usize, const MEASUREMENTS: usize> { /// S-Sized temporary matrix (number of measurements × number of measurements). /// /// - The backing field for this temporary MAY be aliased with temporary [`KHP`]. /// - The backing field for this temporary MAY be aliased with temporary [`HP`] (if it is not aliased with [`PHt`]). /// - The backing field for this temporary MUST NOT be aliased with temporary [`PHt`]. - pub(crate) S_inv: Matrix<'a>, + pub(crate) S_inv: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, /// H-Sized temporary matrix (number of measurements × number of states). /// /// - The backing field for this temporary MAY be aliased with temporary [`S_inv`]. /// - The backing field for this temporary MAY be aliased with temporary [`PHt`]. /// - The backing field for this temporary MUST NOT be aliased with temporary [`KHP`]. - pub(crate) HP: Matrix<'a>, + pub(crate) HP: Matrix<'a, MEASUREMENTS, STATES>, /// P-Sized temporary matrix (number of states × number of states). /// /// - The backing field for this temporary MAY be aliased with temporary [`S_inv`]. /// - The backing field for this temporary MAY be aliased with temporary [`PHt`]. /// - The backing field for this temporary MUST NOT be aliased with temporary [`HP`]. - pub(crate) KHP: Matrix<'a>, + pub(crate) KHP: Matrix<'a, STATES, STATES>, /// P×H'-Sized (H'-Sized) temporary matrix (number of states × number of measurements). /// /// - The backing field for this temporary MAY be aliased with temporary [`HP`]. /// - The backing field for this temporary MAY be aliased with temporary [`KHP`]. /// - The backing field for this temporary MUST NOT be aliased with temporary [`S_inv`]. - pub(crate) PHt: Matrix<'a>, + pub(crate) PHt: Matrix<'a, STATES, MEASUREMENTS>, } -impl<'a> Measurement<'a> { +impl<'a, const STATES: usize, const MEASUREMENTS: usize> Measurement<'a, STATES, MEASUREMENTS> { /// Initializes a measurement. /// /// ## Arguments @@ -66,7 +66,7 @@ impl<'a> Measurement<'a> { /// * `H` - The measurement transformation matrix (`num_measurements` × `num_states`). /// * `z` - The measurement vector (`num_measurements` × `1`). /// * `R` - The process noise / measurement uncertainty (`num_measurements` × `num_measurements`). - /// * `v` - The innovation (`num_measurements` × `1`). + /// * `y` - The innovation (`num_measurements` × `1`). /// * `S` - The residual covariance (`num_measurements` × `num_measurements`). /// * `K` - The Kalman gain (`num_states` × `num_measurements`). /// * `S_inv` - The temporary vector for predicted states (`num_states` × `1`). @@ -89,6 +89,8 @@ impl<'a> Measurement<'a> { temp_PHt: &'a mut [matrix_data_t], temp_KHP: &'a mut [matrix_data_t], ) -> Self { + debug_assert_eq!(STATES, num_states.into()); + debug_assert_eq!(MEASUREMENTS, num_measurements.into()); Self { num_states, num_measurements, @@ -127,17 +129,19 @@ impl<'a> Measurement<'a> { pub fn new( num_states: uint_fast8_t, num_measurements: uint_fast8_t, - H: Matrix<'a>, - z: Matrix<'a>, - R: Matrix<'a>, - y: Matrix<'a>, - S: Matrix<'a>, - K: Matrix<'a>, - S_inv: Matrix<'a>, - temp_HP: Matrix<'a>, - temp_PHt: Matrix<'a>, - temp_KHP: Matrix<'a>, + H: Matrix<'a, MEASUREMENTS, STATES>, + z: Matrix<'a, MEASUREMENTS, 1>, + R: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, + y: Matrix<'a, MEASUREMENTS, 1>, + S: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, + K: Matrix<'a, STATES, MEASUREMENTS>, + S_inv: Matrix<'a, MEASUREMENTS, MEASUREMENTS>, + temp_HP: Matrix<'a, MEASUREMENTS, STATES>, + temp_PHt: Matrix<'a, STATES, MEASUREMENTS>, + temp_KHP: Matrix<'a, STATES, STATES>, ) -> Self { + debug_assert_eq!(STATES, num_states.into()); + debug_assert_eq!(MEASUREMENTS, num_measurements.into()); debug_assert_eq!( H.rows, num_measurements, "The measurement transformation matrix H requires {} rows and {} columns (i.e. measurements × states)", @@ -268,14 +272,14 @@ impl<'a> Measurement<'a> { /// Gets a reference to the measurement vector z. #[inline(always)] - pub fn measurement_vector_ref(&self) -> &Matrix { + pub fn measurement_vector_ref(&self) -> &Matrix<'_, MEASUREMENTS, 1> { &self.z } /// Gets a mutable reference to the measurement vector z. #[inline(always)] #[doc(alias = "kalman_get_measurement_vector")] - pub fn measurement_vector_mut(&'a mut self) -> &'a mut Matrix { + pub fn measurement_vector_mut(&'a mut self) -> &'a mut Matrix<'_, MEASUREMENTS, 1> { &mut self.z } @@ -283,21 +287,23 @@ impl<'a> Measurement<'a> { #[inline(always)] pub fn measurement_vector_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, MEASUREMENTS, 1>) -> (), { f(&mut self.z) } /// Gets a reference to the measurement transformation matrix H. #[inline(always)] - pub fn measurement_transformation_ref(&self) -> &Matrix { + pub fn measurement_transformation_ref(&self) -> &Matrix<'_, MEASUREMENTS, STATES> { &self.H } /// Gets a mutable reference to the measurement transformation matrix H. #[inline(always)] #[doc(alias = "kalman_get_measurement_transformation")] - pub fn measurement_transformation_mut(&'a mut self) -> &'a mut Matrix { + pub fn measurement_transformation_mut( + &'a mut self, + ) -> &'a mut Matrix<'_, MEASUREMENTS, STATES> { &mut self.H } @@ -305,21 +311,21 @@ impl<'a> Measurement<'a> { #[inline(always)] pub fn measurement_transformation_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, MEASUREMENTS, STATES>) -> (), { f(&mut self.H) } /// Gets a reference to the process noise matrix R. #[inline(always)] - pub fn process_noise_ref(&self) -> &Matrix { + pub fn process_noise_ref(&self) -> &Matrix<'_, MEASUREMENTS, MEASUREMENTS> { &self.R } /// Gets a mutable reference to the process noise matrix R. #[inline(always)] #[doc(alias = "kalman_get_process_noise")] - pub fn process_noise_mut(&'a mut self) -> &'a mut Matrix { + pub fn process_noise_mut(&'a mut self) -> &'a mut Matrix<'_, MEASUREMENTS, MEASUREMENTS> { &mut self.R } @@ -327,7 +333,7 @@ impl<'a> Measurement<'a> { #[inline(always)] pub fn process_noise_apply(&mut self, mut f: F) where - F: FnMut(&mut Matrix<'a>) -> (), + F: FnMut(&mut Matrix<'a, MEASUREMENTS, MEASUREMENTS>) -> (), { f(&mut self.R) } From 27f955112806978dbbdd3ddc0685a8b0a3f98db8 Mon Sep 17 00:00:00 2001 From: Markus Mayer Date: Mon, 6 Nov 2023 14:22:59 +0100 Subject: [PATCH 2/2] Remove redundant row/column count parameters --- examples/gravity.rs | 8 +- src/kalman.rs | 169 +++++++++++--------- src/matrix.rs | 373 +++++++++++++++++++++++--------------------- src/measurement.rs | 94 +++++------ 4 files changed, 339 insertions(+), 305 deletions(-) diff --git a/examples/gravity.rs b/examples/gravity.rs index 2ca1929..2d12f31 100644 --- a/examples/gravity.rs +++ b/examples/gravity.rs @@ -73,9 +73,7 @@ fn main() { let mut gravity_temp_PHt = create_buffer_temp_PHt!(NUM_STATES, NUM_MEASUREMENTS); let mut gravity_temp_KHP = create_buffer_temp_KHP!(NUM_STATES); - let mut filter = Kalman::new_from_buffers( - NUM_STATES as _, - NUM_INPUTS as _, + let mut filter = Kalman::::new_from_buffers( &mut gravity_A, &mut gravity_x, &mut gravity_B, @@ -87,9 +85,7 @@ fn main() { &mut gravity_temp_BQ, ); - let mut measurement = Measurement::new_direct( - NUM_STATES as _, - NUM_MEASUREMENTS as _, + let mut measurement = Measurement::::new_direct( &mut gravity_H, &mut gravity_z, &mut gravity_R, diff --git a/src/kalman.rs b/src/kalman.rs index 3b5f3fd..a4289e7 100644 --- a/src/kalman.rs +++ b/src/kalman.rs @@ -1,14 +1,9 @@ use crate::measurement::Measurement; use crate::{matrix_data_t, Matrix}; -use stdint::uint_fast8_t; /// Kalman Filter structure. #[allow(non_snake_case, unused)] pub struct Kalman<'a, const STATES: usize, const INPUTS: usize> { - /// The number of states. - num_states: uint_fast8_t, - /// The number of inputs. - num_inputs: uint_fast8_t, /// State vector. x: Matrix<'a, STATES, 1>, /// System matrix. @@ -49,25 +44,27 @@ struct KalmanTemporary<'a, const STATES: usize, const INPUTS: usize> { } impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { + /// The number of states. + const NUM_STATES: usize = STATES; + + /// The number of inputs. + const NUM_INPUTS: usize = INPUTS; + /// Initializes a Kalman filter instance. /// /// ## Arguments - /// * `num_states` - The number of states tracked by this filter. - /// * `num_inputs` - The number of inputs available to the filter. - /// * `A` - The state transition matrix (`num_states` × `num_states`). - /// * `x` - The state vector (`num_states` × `1`). - /// * `B` - The input transition matrix (`num_states` × `num_inputs`). - /// * `u` - The input vector (`num_inputs` × `1`). - /// * `P` - The state covariance matrix (`num_states` × `num_states`). - /// * `Q` - The input covariance matrix (`num_inputs` × `num_inputs`). - /// * `predictedX` - The temporary vector for predicted states (`num_states` × `1`). - /// * `temp_P` - The temporary vector for P calculation (`num_states` × `num_states`). - /// * `temp_BQ` - The temporary vector for B×Q calculation (`num_states` × `num_inputs`). + /// * `A` - The state transition matrix (`STATES` × `STATES`). + /// * `x` - The state vector (`STATES` × `1`). + /// * `B` - The input transition matrix (`STATES` × `INPUTS`). + /// * `u` - The input vector (`INPUTS` × `1`). + /// * `P` - The state covariance matrix (`STATES` × `STATES`). + /// * `Q` - The input covariance matrix (`INPUTS` × `INPUTS`). + /// * `predictedX` - The temporary vector for predicted states (`STATES` × `1`). + /// * `temp_P` - The temporary vector for P calculation (`STATES` × `STATES`). + /// * `temp_BQ` - The temporary vector for B×Q calculation (`STATES` × `INPUTS`). #[allow(non_snake_case)] #[doc(alias = "kalman_filter_initialize")] pub fn new_from_buffers( - num_states: uint_fast8_t, - num_inputs: uint_fast8_t, A: &'a mut [matrix_data_t], x: &'a mut [matrix_data_t], B: &'a mut [matrix_data_t], @@ -78,21 +75,17 @@ impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { temp_P: &'a mut [matrix_data_t], temp_BQ: &'a mut [matrix_data_t], ) -> Self { - debug_assert_eq!(STATES, num_states.into()); - debug_assert_eq!(INPUTS, num_inputs.into()); Self { - num_states, - num_inputs, - A: Matrix::new(num_states, num_states, A), - P: Matrix::new(num_states, num_states, P), - x: Matrix::new(num_states, 1, x), - B: Matrix::new(num_states, num_inputs, B), - Q: Matrix::new(num_inputs, num_inputs, Q), - u: Matrix::new(num_inputs, 1, u), + A: Matrix::::new(A), + P: Matrix::::new(P), + x: Matrix::::new(x), + B: Matrix::::new(B), + Q: Matrix::::new(Q), + u: Matrix::::new(u), temporary: KalmanTemporary { - predicted_x: Matrix::new(num_states, 1, predictedX), - P: Matrix::new(num_states, num_states, temp_P), - BQ: Matrix::new(num_states, num_inputs, temp_BQ), + predicted_x: Matrix::::new(predictedX), + P: Matrix::::new(temp_P), + BQ: Matrix::::new(temp_BQ), }, } } @@ -113,8 +106,6 @@ impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { /// * `temp_BQ` - The temporary vector for B×Q calculation (`num_states` × `num_inputs`). #[allow(non_snake_case)] pub fn new( - num_states: uint_fast8_t, - num_inputs: uint_fast8_t, A: Matrix<'a, STATES, STATES>, x: Matrix<'a, STATES, 1>, B: Matrix<'a, STATES, INPUTS>, @@ -125,110 +116,132 @@ impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { temp_P: Matrix<'a, STATES, STATES>, temp_BQ: Matrix<'a, STATES, INPUTS>, ) -> Self { - debug_assert_eq!(STATES, num_states.into()); - debug_assert_eq!(INPUTS, num_inputs.into()); debug_assert_eq!( - A.rows, num_states, + A.rows(), + STATES as _, "The state transition matrix A requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, + STATES ); debug_assert_eq!( - A.cols, num_states, + A.cols(), + STATES as _, "The state transition matrix A requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, + STATES ); debug_assert_eq!( - P.rows, num_states, + P.rows(), + STATES as _, "The system covariance matrix P requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, + STATES ); debug_assert_eq!( - P.cols, num_states, + P.cols(), + STATES as _, "The system covariance matrix P requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, + STATES ); debug_assert_eq!( - x.rows, num_states, + x.rows(), + STATES as _, "The state vector x requires {} rows and 1 column (i.e. states × 1)", - num_states + STATES ); debug_assert_eq!( - x.cols, 1, + x.cols(), + 1, "The state vector x requires {} rows and 1 column (i.e. states × 1)", - num_states + STATES ); debug_assert_eq!( - B.rows, num_states, + B.rows(), + STATES as _, "The input transition matrix B requires {} rows and {} columns (i.e. states × inputs)", - num_states, num_inputs + STATES, + INPUTS ); debug_assert_eq!( - B.cols, num_inputs, + B.cols(), + INPUTS as _, "The input transition matrix B requires {} rows and {} columns (i.e. states × inputs)", - num_states, num_inputs + STATES, + INPUTS ); debug_assert_eq!( - Q.rows, num_inputs, + Q.rows(), + INPUTS as _, "The input covariance matrix Q requires {} rows and {} columns (i.e. inputs × inputs)", - num_inputs, num_inputs + INPUTS, + INPUTS ); debug_assert_eq!( - Q.cols, num_inputs, + Q.cols(), + INPUTS as _, "The input covariance matrix Q requires {} rows and {} columns (i.e. inputs × inputs)", - num_inputs, num_inputs + INPUTS, + INPUTS ); debug_assert_eq!( - u.rows, num_inputs, + u.rows(), + INPUTS as _, "The input vector u requires {} rows and 1 column (i.e. inputs × 1)", - num_inputs + INPUTS ); debug_assert_eq!( - u.cols, 1, + u.cols(), + 1, "The input vector u requires {} rows and 1 column (i.e. inputs × 1)", - num_inputs + INPUTS ); debug_assert_eq!( - predictedX.rows, num_states, + predictedX.rows(), + STATES as _, "The temporary state prediction vector requires {} rows and 1 column (i.e. states × 1)", - num_states + STATES ); debug_assert_eq!( - predictedX.cols, 1, + predictedX.cols(), + 1, "The temporary state prediction vector requires {} rows and 1 column (i.e. states × 1)", - num_states + STATES ); debug_assert_eq!( - temp_P.rows, num_states, + temp_P.rows(), STATES as _, "The temporary system covariance matrix requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, STATES ); debug_assert_eq!( - temp_P.cols, num_states, + temp_P.cols(), STATES as _, "The temporary system covariance matrix requires {} rows and {} columns (i.e. states × states)", - num_states, num_states + STATES, STATES ); debug_assert_eq!( - temp_BQ.rows, num_states, + temp_BQ.rows(), + STATES as _, "The temporary B×Q matrix requires {} rows and {} columns (i.e. states × inputs)", - num_states, num_inputs + STATES, + INPUTS ); debug_assert_eq!( - temp_BQ.cols, num_inputs, + temp_BQ.cols(), + INPUTS as _, "The temporary B×Q matrix requires {} rows and {} columns (i.e. states × inputs)", - num_states, num_inputs + STATES, + INPUTS ); Self { - num_states, - num_inputs, A, P, x, @@ -243,6 +256,16 @@ impl<'a, const STATES: usize, const INPUTS: usize> Kalman<'a, STATES, INPUTS> { } } + /// Returns the number of states. + pub const fn states(&self) -> usize { + Self::NUM_STATES + } + + /// Returns the number of inputs. + pub const fn inputs(&self) -> usize { + Self::NUM_INPUTS + } + /// Gets a reference to the state vector x. #[inline(always)] pub fn state_vector_ref(&self) -> &Matrix<'_, STATES, 1> { diff --git a/src/matrix.rs b/src/matrix.rs index 009ce7f..636853d 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -13,9 +13,7 @@ use micromath::F32Ext; pub type matrix_data_t = f32; /// A matrix wrapping a data buffer. -pub struct Matrix<'a, const R: usize, const C: usize> { - pub rows: uint_fast8_t, - pub cols: uint_fast8_t, +pub struct Matrix<'a, const ROWS: usize, const COLS: usize> { pub data: &'a mut [matrix_data_t], } @@ -26,29 +24,30 @@ macro_rules! idx { }; } -impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { +impl<'a, const ROWS: usize, const COLS: usize> Matrix<'a, ROWS, COLS> { /// Initializes a matrix structure. /// /// ## Arguments - /// * `mat` - The matrix to initialize - /// * `rows` - The number of rows - /// * `cols` - The number of columns /// * `buffer` - The data buffer (of size `rows` x `cols`). - pub fn new(rows: uint_fast8_t, cols: uint_fast8_t, buffer: &'a mut [matrix_data_t]) -> Self { + pub fn new(buffer: &'a mut [matrix_data_t]) -> Self { debug_assert!( - buffer.len() >= (rows * cols) as _, + buffer.len() >= (ROWS * COLS) as _, "Buffer needs to be large enough to keep at least {} × {} = {} elements", - rows, - cols, - rows * cols + ROWS, + COLS, + ROWS * COLS ); - debug_assert_eq!(R, rows.into()); - debug_assert_eq!(C, cols.into()); - Self { - rows, - cols, - data: buffer, - } + Self { data: buffer } + } + + /// Returns the number of rows of this matrix. + pub const fn rows(&self) -> uint_fast8_t { + ROWS as _ + } + + /// Returns the number of columns of this matrix. + pub const fn cols(&self) -> uint_fast8_t { + COLS as _ } /// Initializes a matrix structure from a pointer to a buffer. @@ -63,41 +62,27 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// * `buffer` - The data buffer (of size `rows` x `cols`). #[cfg_attr(docsrs, doc(cfg(feature = "unsafe")))] #[cfg(feature = "unsafe")] - pub unsafe fn new_unchecked( - rows: uint_fast8_t, - cols: uint_fast8_t, - ptr: *mut [matrix_data_t], - ) -> Self { + pub unsafe fn new_unchecked(ptr: *mut [matrix_data_t]) -> Self { let buffer = unsafe { &mut *ptr }; if ptr.is_null() { - debug_assert_eq!(rows, 0, "For null buffers, the row count must be zero"); - debug_assert_eq!(cols, 0, "For null buffers, the column count must be zero"); - return Self { - rows, - cols, - data: buffer, - }; + debug_assert_eq!(ROWS, 0, "For null buffers, the row count must be zero"); + debug_assert_eq!(COLS, 0, "For null buffers, the column count must be zero"); + return Self { data: buffer }; } debug_assert!( - buffer.len() >= (rows * cols) as _, + buffer.len() >= (ROWS * COLS) as _, "Buffer needs to be large enough to keep at least {} × {} = {} elements", - rows, - cols, - rows * cols + ROWS, + COLS, + ROWS * COLS ); - debug_assert_eq!(R, rows.into()); - debug_assert_eq!(C, cols.into()); - Self { - rows, - cols, - data: buffer, - } + Self { data: buffer } } /// Gets the number of elements of this matrix. pub const fn len(&self) -> uint_fast16_t { - self.rows as uint_fast16_t * self.cols as uint_fast16_t + ROWS as uint_fast16_t * COLS as uint_fast16_t } /// Determines if this matrix has zero elements. @@ -130,11 +115,11 @@ impl<'a, const N: usize> Matrix<'a, N, N> { /// 1.0, 0.5, 0.0, /// 0.5, 1.0, 0.0, /// 0.0, 0.0, 1.0]; - /// let mut m = Matrix::<3, 3>::new(3, 3, &mut d); + /// let mut m = Matrix::<3, 3>::new(&mut d); /// /// // data buffer for the inverted matrix /// let mut di = [0.0; 3 * 3]; - /// let mut mi = Matrix::<3, 3>::new(3, 3, &mut di); + /// let mut mi = Matrix::<3, 3>::new(&mut di); /// /// // Decompose matrix to lower triangular. /// m.cholesky_decompose_lower(); @@ -163,9 +148,7 @@ impl<'a, const N: usize> Matrix<'a, N, N> { /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_invert_lower")] pub fn invert_l_cholesky(&self, inverse: &mut Self) { - debug_assert_eq!(self.rows, self.cols); - - let n = self.rows; + let n = N; let mat = self.data.as_ref(); // t let inv = inverse.data.as_mut(); // a @@ -211,7 +194,7 @@ impl<'a, const N: usize> Matrix<'a, N, N> { } } -impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { +impl<'a, const ROWS: usize, const COLS: usize> Matrix<'a, ROWS, COLS> { /// Performs a matrix multiplication such that `C = A * B`. This method /// uses an auxiliary buffer for keeping one row of `B` cached. This might /// improve performance on very wide matrices but is generally slower than @@ -230,16 +213,16 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// let mut a_buf = [ /// 1.0, 2.0, 3.0, /// 4.0, 5.0, 6.0]; - /// let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + /// let a = Matrix::<2, 3>::new(&mut a_buf); /// /// let mut b_buf = [ /// 10.0, 11.0, /// 20.0, 21.0, /// 30.0, 31.0]; - /// let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); + /// let b = Matrix::<3, 2>::new(&mut b_buf); /// /// let mut c_buf = [0f32; 2 * 2]; - /// let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + /// let mut c = Matrix::<2, 2>::new(&mut c_buf); /// /// let mut aux = [0f32; 3 * 1]; /// a.mult_buffered(&b, &mut c, &mut aux); @@ -254,32 +237,33 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[doc(alias = "matrix_mult_buffered")] pub fn mult_buffered( &self, - b: &Matrix<'_, C, U>, - c: &mut Matrix<'_, R, U>, + b: &Matrix<'_, COLS, U>, + c: &mut Matrix<'_, ROWS, U>, baux: &mut [matrix_data_t], ) { - let bcols = b.cols; - let ccols = c.cols; - let brows = b.rows; - let arows = self.rows; + let arows = self.rows(); + let brows = b.rows(); + let bcols = b.cols(); + let ccols = c.cols(); + let crows = c.rows(); let adata = self.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, b.rows); + debug_assert_eq!(COLS, brows as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(b.cols, c.cols); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(bcols, ccols as _); // Test aux dimensions. - debug_assert_eq!(baux.len(), self.cols as _); - debug_assert_eq!(baux.len(), b.rows as _); + debug_assert_eq!(baux.len(), COLS as _); + debug_assert_eq!(baux.len(), brows as _); for j in (0..bcols).rev() { // create a copy of the column in B to avoid cache issues - b.get_column_copy(j, baux); + b.get_column_copy(j as _, baux); let mut index_a: uint_fast16_t = 0; for i in 0..arows { @@ -308,16 +292,16 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// let mut a_buf = [ /// 1.0, 2.0, 3.0, /// 4.0, 5.0, 6.0]; - /// let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + /// let a = Matrix::<2, 3>::new(&mut a_buf); /// /// let mut b_buf = [ /// 10.0, 11.0, /// 20.0, 21.0, /// 30.0, 31.0]; - /// let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); + /// let b = Matrix::<3, 2>::new(&mut b_buf); /// /// let mut c_buf = [0f32; 2 * 2]; - /// let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + /// let mut c = Matrix::<2, 2>::new(&mut c_buf); /// /// a.mult(&b, &mut c); /// @@ -329,29 +313,30 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult")] - pub fn mult(&self, b: &Matrix<'_, C, U>, c: &mut Matrix<'_, R, U>) { - let bcols = b.cols; - let ccols = c.cols; - let brows = b.rows; - let arows = self.rows; + pub fn mult(&self, b: &Matrix<'_, COLS, U>, c: &mut Matrix<'_, ROWS, U>) { + let arows = ROWS; + let bcols = b.cols() as usize; + let brows = b.rows() as usize; + let ccols = c.cols() as usize; + let crows = c.rows() as usize; let adata = self.data.as_ref(); let bdata = b.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, b.rows); + debug_assert_eq!(COLS, brows as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(b.cols, c.cols); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(bcols, ccols); for j in (0..bcols).rev() { let mut index_a: uint_fast16_t = 0; for i in 0..arows { let mut total = 0 as matrix_data_t; for k in 0..brows { - total += adata[idx!(index_a)] * bdata[idx!(k * b.cols + j)]; + total += adata[idx!(index_a)] * bdata[idx!(k * bcols + j)]; index_a += 1; } cdata[idx!(i * ccols + j)] = total; @@ -368,20 +353,25 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult_rowvector")] - pub fn mult_rowvector(&self, x: &Matrix<'_, C, 1>, c: &mut Matrix<'_, R, 1>) { - let arows = self.rows; - let acols = self.cols; + pub fn mult_rowvector(&self, x: &Matrix<'_, COLS, 1>, c: &mut Matrix<'_, ROWS, 1>) { + let arows = self.rows(); + let acols = self.cols(); + + let xrows = x.rows(); + + let crows = c.rows(); + let ccols = c.cols(); let adata = self.data.as_ref(); let xdata = x.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, x.rows); + debug_assert_eq!(COLS, xrows as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(c.cols, 1); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(ccols, 1); let mut index_a: uint_fast16_t = 0; let mut index_c: uint_fast16_t = 0; @@ -409,20 +399,25 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_multadd_rowvector")] - pub fn multadd_rowvector(&self, x: &Matrix<'_, C, 1>, c: &mut Matrix<'_, R, 1>) { - let arows = self.rows; - let acols = self.cols; + pub fn multadd_rowvector(&self, x: &Matrix<'_, COLS, 1>, c: &mut Matrix<'_, ROWS, 1>) { + let arows = self.rows(); + let acols = self.cols(); + + let xrows = x.rows(); + + let crows = c.rows(); + let ccols = c.cols(); let adata = self.data.as_ref(); let xdata = x.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, x.rows); + debug_assert_eq!(COLS, xrows as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(c.cols, 1); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(ccols, 1); let mut index_a: uint_fast16_t = 0; let mut index_c: uint_fast16_t = 0; @@ -450,22 +445,28 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_mult_transb")] - pub fn mult_transb(&self, b: &Matrix<'_, U, C>, c: &mut Matrix<'_, R, U>) { - let bcols = b.cols; - let brows = b.rows; - let arows = self.rows; - let acols = self.cols; + pub fn mult_transb( + &self, + b: &Matrix<'_, U, COLS>, + c: &mut Matrix<'_, ROWS, U>, + ) { + let arows = self.rows(); + let acols = self.cols(); + let bcols = b.cols(); + let brows = b.rows(); + let ccols = c.cols(); + let crows = c.rows(); let adata = self.data.as_ref(); let bdata = b.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(COLS, bcols as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(b.rows, c.cols); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(b.rows(), ccols); let mut c_index: uint_fast16_t = 0; let mut a_index_start: uint_fast16_t = 0; @@ -499,22 +500,28 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library #[doc(alias = "matrix_multadd_transb")] - pub fn multadd_transb(&self, b: &Matrix<'_, U, C>, c: &mut Matrix<'_, R, U>) { - let bcols = b.cols; - let brows = b.rows; - let arows = self.rows; - let acols = self.cols; + pub fn multadd_transb( + &self, + b: &Matrix<'_, U, COLS>, + c: &mut Matrix<'_, ROWS, U>, + ) { + let arows = self.rows(); + let acols = self.cols(); + let bcols = b.cols(); + let brows = b.rows(); + let ccols = c.cols(); + let crows = c.rows(); let adata = self.data.as_ref(); let bdata = b.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(COLS, bcols as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(b.rows, c.cols); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(brows, ccols); let mut c_index: uint_fast16_t = 0; let mut a_index_start: uint_fast16_t = 0; @@ -551,25 +558,27 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[doc(alias = "matrix_multscale_transb")] pub fn multscale_transb( &self, - b: &Matrix<'_, U, C>, + b: &Matrix<'_, U, COLS>, scale: matrix_data_t, - c: &mut Matrix<'_, R, U>, + c: &mut Matrix<'_, ROWS, U>, ) { - let bcols = b.cols; - let brows = b.rows; - let arows = self.rows; - let acols = self.cols; + let arows = self.rows(); + let acols = self.cols(); + let bcols = b.cols(); + let brows = b.rows(); + let ccols = c.cols(); + let crows = c.rows(); let adata = self.data.as_ref(); let bdata = b.data.as_ref(); let cdata = c.data.as_mut(); // test dimensions of a and b - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(COLS, bcols as _); // test dimension of c - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(b.rows, c.cols); + debug_assert_eq!(ROWS, crows as _); + debug_assert_eq!(brows, ccols); let mut c_index: uint_fast16_t = 0; let mut a_index_start: uint_fast16_t = 0; @@ -605,7 +614,7 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline(always)] #[doc(alias = "matrix_get")] pub fn get(&self, row: uint_fast8_t, column: uint_fast8_t) -> matrix_data_t { - self.data[idx!(row * self.cols + column)] + self.data[idx!(row * self.cols() + column)] } /// Sets a matrix element @@ -617,7 +626,7 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline(always)] #[doc(alias = "matrix_set")] pub fn set(&mut self, row: uint_fast8_t, column: uint_fast8_t, value: matrix_data_t) { - self.data[idx!(row * self.cols + column)] = value; + self.data[idx!(row * self.cols() + column)] = value; } /// Sets matrix elements in a symmetric matrix @@ -642,7 +651,7 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// * `row_data` - A pointer to the given matrix row #[doc(alias = "matrix_get_row_pointer")] pub fn get_row_pointer<'b>(&'a self, row: uint_fast8_t, row_data: &'b mut &'a [matrix_data_t]) { - *row_data = &self.data[idx!(row * self.cols)..idx!((row + 1) * self.cols)]; + *row_data = &self.data[idx!(row * self.cols())..idx!((row + 1) * self.cols())]; } /// Gets a copy of a matrix column @@ -654,10 +663,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[doc(alias = "matrix_get_column_copy")] pub fn get_column_copy(&self, column: uint_fast8_t, col_data: &mut [matrix_data_t]) { // start from the back, so target index is equal to the index of the last row. - let mut target_index: int_fast16_t = (self.rows - 1) as _; + let mut target_index: int_fast16_t = (self.rows() - 1) as _; // also, the source index is the column..th index - let stride: int_fast16_t = self.cols as _; + let stride: int_fast16_t = self.cols() as _; let mut source_index = (target_index as int_fast16_t) * stride + (column as int_fast16_t); let src = self.data.as_ref(); @@ -680,9 +689,9 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// * `row_data` - Pointer to an array of the correct length to hold a row of matrix `mat`. #[doc(alias = "matrix_get_row_copy")] pub fn get_row_copy(&self, row: uint_fast8_t, row_data: &mut [matrix_data_t]) { - let mut target_index: uint_fast16_t = (self.cols - 1) as _; + let mut target_index: uint_fast16_t = (self.cols() - 1) as _; let mut source_index: uint_fast16_t = - (row as uint_fast16_t + 1) * (self.cols - 1) as uint_fast16_t; + (row as uint_fast16_t + 1) * (self.cols() - 1) as uint_fast16_t; row_data[idx!(target_index)] = self.data[idx!(source_index)]; while target_index != 0 { @@ -700,10 +709,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_copy")] pub fn copy(&self, target: &mut Self) { - debug_assert_eq!(self.rows, target.rows); - debug_assert_eq!(self.cols, target.cols); + debug_assert_eq!(self.rows(), target.rows()); + debug_assert_eq!(self.cols(), target.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_ref(); let bdata = target.data.as_mut(); @@ -722,12 +731,12 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_sub")] pub fn sub(&self, b: &Self, c: &mut Self) { - debug_assert_eq!(self.rows, b.rows); - debug_assert_eq!(self.cols, b.cols); - debug_assert_eq!(self.rows, c.rows); - debug_assert_eq!(self.cols, c.cols); + debug_assert_eq!(self.rows(), b.rows()); + debug_assert_eq!(self.cols(), b.cols()); + debug_assert_eq!(self.rows(), c.rows()); + debug_assert_eq!(self.cols(), c.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_ref(); let bdata = b.data.as_ref(); @@ -746,10 +755,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_sub_inplace_b")] pub fn sub_inplace_a(&mut self, b: &Self) { - debug_assert_eq!(self.rows, b.rows); - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(self.rows(), b.rows()); + debug_assert_eq!(self.cols(), b.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_mut(); let bdata = b.data.as_ref(); @@ -767,10 +776,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_sub_inplace_b")] pub fn sub_inplace_b(&self, b: &mut Self) { - debug_assert_eq!(self.rows, b.rows); - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(self.rows(), b.rows()); + debug_assert_eq!(self.cols(), b.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_ref(); let bdata = b.data.as_mut(); @@ -788,10 +797,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_add_inplace_b")] pub fn add_inplace_a(&mut self, b: &Self) { - debug_assert_eq!(self.rows, b.rows); - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(self.rows(), b.rows()); + debug_assert_eq!(self.cols(), b.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_mut(); let bdata = b.data.as_ref(); @@ -809,10 +818,10 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { #[inline] #[doc(alias = "matrix_add_inplace_b")] pub fn add_inplace_b(&self, b: &mut Self) { - debug_assert_eq!(self.rows, b.rows); - debug_assert_eq!(self.cols, b.cols); + debug_assert_eq!(self.rows(), b.rows()); + debug_assert_eq!(self.cols(), b.cols()); - let count: uint_fast16_t = (self.cols as uint_fast16_t) * (self.rows as uint_fast16_t); + let count = self.len(); let adata = self.data.as_ref(); let bdata = b.data.as_mut(); @@ -840,7 +849,7 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// 0.5, 1.0, 0.0, /// 0.0, 0.0, 1.0]; /// - /// let mut m = Matrix::<3, 3>::new(3, 3, &mut d); + /// let mut m = Matrix::<3, 3>::new(&mut d); /// /// // Decompose matrix to lower triangular. /// m.cholesky_decompose_lower(); @@ -860,13 +869,13 @@ impl<'a, const R: usize, const C: usize> Matrix<'a, R, C> { /// /// Kudos: https://code.google.com/p/efficient-java-matrix-library pub fn cholesky_decompose_lower(&mut self) -> bool { - let n = self.rows; + let n = self.rows(); let t: &mut [matrix_data_t] = self.data; let mut div_el_ii = 0 as matrix_data_t; - debug_assert_eq!(self.rows, self.cols); - debug_assert!(self.rows > 0); + debug_assert_eq!(ROWS, COLS); + debug_assert!(ROWS > 0); for i in 0..n { for j in i..n { @@ -937,11 +946,11 @@ impl<'a, const R: usize, const C: usize> AsMut<[matrix_data_t]> for Matrix<'a, R impl<'a, const R: usize, const C: usize> MatrixBase for Matrix<'a, R, C> { fn rows(&self) -> uint_fast8_t { - self.rows + self.rows() } fn columns(&self) -> uint_fast8_t { - self.cols + self.cols() } fn len(&self) -> uint_fast16_t { @@ -972,11 +981,11 @@ mod tests { 10.0, 11.0, 20.0, 21.0, 30.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<3, 2>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(&mut c_buf); let mut aux = [0f32; 3 * 1]; a.mult_buffered(&b, &mut c, &mut aux); @@ -996,11 +1005,11 @@ mod tests { 10.0, 11.0, 20.0, 21.0, 30.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<3, 2>::new(3, 2, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<3, 2>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(&mut c_buf); a.mult(&b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1018,11 +1027,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<2, 3>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(&mut c_buf); Matrix::mult_transb(&a, &b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1038,19 +1047,19 @@ mod tests { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, -9.0]; - let a = Matrix::<3, 3>::new(3, 3, &mut a_buf); + let a = Matrix::<3, 3>::new(&mut a_buf); let mut b_buf = [ -4.0, -1.0, 0.0, 2.0, 3.0, 4.0, 5.0, 9.0, -10.0]; - let b = Matrix::<3, 3>::new(3, 3, &mut b_buf); + let b = Matrix::<3, 3>::new(&mut b_buf); let mut c_buf = [0f32; 3 * 3]; - let mut c = Matrix::<3, 3>::new(3, 3, &mut c_buf); + let mut c = Matrix::<3, 3>::new(&mut c_buf); let mut d_buf = [0f32; 3 * 3]; - let mut d = Matrix::<3, 3>::new(3, 3, &mut d_buf); + let mut d = Matrix::<3, 3>::new(&mut d_buf); // Example P = A*P*A' a.mult(&b, &mut c); // temp = A*P @@ -1076,13 +1085,13 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<2, 3>::new(&mut b_buf); let mut c_buf = [ 1000., 2000., 3000., 4000.]; - let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(&mut c_buf); Matrix::multadd_transb(&a, &b, &mut c); assert_f32_near!(c.get(0, 0), 1000. + 1. * 10. + 2. * 20. + 3. * 30.); // 1140 @@ -1100,11 +1109,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<2, 3>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 2]; - let mut c = Matrix::<2, 2>::new(2, 2, &mut c_buf); + let mut c = Matrix::<2, 2>::new(&mut c_buf); Matrix::multscale_transb(&a, &b, 2.0, &mut c); assert_f32_near!(c_buf[0], 2.0 * (1. * 10. + 2. * 20. + 3. * 30.)); // 280 @@ -1123,11 +1132,11 @@ mod tests { 10.0, 20.0, 30.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<3, 1>::new(3, 1, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<3, 1>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 1]; - let mut c = Matrix::<2, 1>::new(2, 1, &mut c_buf); + let mut c = Matrix::<2, 1>::new(&mut c_buf); Matrix::mult_rowvector(&a, &b, &mut c); assert_f32_near!(c_buf[0], 1. * 10. + 2. * 20. + 3. * 30.); // 140 @@ -1144,11 +1153,11 @@ mod tests { 10.0, 20.0, 30.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<3, 1>::new(3, 1, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<3, 1>::new(&mut b_buf); let mut c_buf = [1000., 2000.]; - let mut c = Matrix::<2, 1>::new(2, 1, &mut c_buf); + let mut c = Matrix::<2, 1>::new(&mut c_buf); Matrix::multadd_rowvector(&a, &b, &mut c); assert_f32_near!(c.get(0, 0), 1000. + 1. * 10. + 2. * 20. + 3. * 30.); // 1140 @@ -1161,7 +1170,7 @@ mod tests { let mut a_buf = [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); let mut a_out = [0.0; 3].as_slice(); a.get_row_pointer(0, &mut a_out); @@ -1180,11 +1189,11 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let b = Matrix::<2, 3>::new(&mut b_buf); let mut c_buf = [0f32; 2 * 3]; - let mut c = Matrix::<2, 3>::new(2, 3, &mut c_buf); + let mut c = Matrix::<2, 3>::new(&mut c_buf); Matrix::sub(&a, &b, &mut c); assert_eq!(c_buf, [ @@ -1201,8 +1210,8 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let mut b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let mut b = Matrix::<2, 3>::new(&mut b_buf); Matrix::sub_inplace_b(&a, &mut b); assert_eq!(b_buf, [ @@ -1219,8 +1228,8 @@ mod tests { let mut b_buf = [ 10.0, 20.0, 30.0, 11.0, 21.0, 31.0]; - let a = Matrix::<2, 3>::new(2, 3, &mut a_buf); - let mut b = Matrix::<2, 3>::new(2, 3, &mut b_buf); + let a = Matrix::<2, 3>::new(&mut a_buf); + let mut b = Matrix::<2, 3>::new(&mut b_buf); Matrix::add_inplace_b(&a, &mut b); assert_eq!(b_buf, [ @@ -1238,7 +1247,7 @@ mod tests { 0.5, 1.0, 0.0, 0.0, 0.0, 1.0]; - let mut m = Matrix::<3, 3>::new(3, 3, &mut d); + let mut m = Matrix::<3, 3>::new(&mut d); // Decompose matrix to lower triangular. m.cholesky_decompose_lower(); @@ -1275,11 +1284,11 @@ mod tests { 1.0, 0.5, 0.0, 0.5, 1.0, 0.0, 0.0, 0.0, 1.0]; - let mut m = Matrix::<3, 3>::new(3, 3, &mut d); + let mut m = Matrix::<3, 3>::new(&mut d); // data buffer for the inverted matrix let mut di = [0.0; 3 * 3]; - let mut mi = Matrix::<3, 3>::new(3, 3, &mut di); + let mut mi = Matrix::<3, 3>::new(&mut di); // Decompose matrix to lower triangular. m.cholesky_decompose_lower(); diff --git a/src/measurement.rs b/src/measurement.rs index bb3bb86..27cda17 100644 --- a/src/measurement.rs +++ b/src/measurement.rs @@ -4,10 +4,6 @@ use stdint::uint_fast8_t; /// Kalman Filter measurement structure. #[allow(non_snake_case, unused)] pub struct Measurement<'a, const STATES: usize, const MEASUREMENTS: usize> { - /// The number of states. - pub num_states: uint_fast8_t, - /// The number of measurements. - pub num_measurements: uint_fast8_t, /// Measurement vector. pub(crate) z: Matrix<'a, MEASUREMENTS, 1>, /// Measurement transformation matrix. @@ -76,8 +72,6 @@ impl<'a, const STATES: usize, const MEASUREMENTS: usize> Measurement<'a, STATES, #[allow(non_snake_case)] #[doc(alias = "kalman_measurement_initialize")] pub fn new_direct( - num_states: uint_fast8_t, - num_measurements: uint_fast8_t, H: &'a mut [matrix_data_t], z: &'a mut [matrix_data_t], R: &'a mut [matrix_data_t], @@ -89,22 +83,18 @@ impl<'a, const STATES: usize, const MEASUREMENTS: usize> Measurement<'a, STATES, temp_PHt: &'a mut [matrix_data_t], temp_KHP: &'a mut [matrix_data_t], ) -> Self { - debug_assert_eq!(STATES, num_states.into()); - debug_assert_eq!(MEASUREMENTS, num_measurements.into()); Self { - num_states, - num_measurements, - H: Matrix::new(num_measurements, num_states, H), - R: Matrix::new(num_measurements, num_measurements, R), - z: Matrix::new(num_measurements, 1, z), - K: Matrix::new(num_states, num_measurements, K), - S: Matrix::new(num_measurements, num_measurements, S), - y: Matrix::new(num_measurements, 1, y), + H: Matrix::::new(H), + R: Matrix::::new(R), + z: Matrix::::new(z), + K: Matrix::::new(K), + S: Matrix::::new(S), + y: Matrix::::new(y), temporary: MeasurementTemporary { - S_inv: Matrix::new(num_measurements, num_measurements, S_inv), - HP: Matrix::new(num_measurements, num_states, temp_HP), - PHt: Matrix::new(num_states, num_measurements, temp_PHt), - KHP: Matrix::new(num_states, num_states, temp_KHP), + S_inv: Matrix::::new(S_inv), + HP: Matrix::::new(temp_HP), + PHt: Matrix::::new(temp_PHt), + KHP: Matrix::::new(temp_KHP), }, } } @@ -143,118 +133,124 @@ impl<'a, const STATES: usize, const MEASUREMENTS: usize> Measurement<'a, STATES, debug_assert_eq!(STATES, num_states.into()); debug_assert_eq!(MEASUREMENTS, num_measurements.into()); debug_assert_eq!( - H.rows, num_measurements, + H.rows(), num_measurements, "The measurement transformation matrix H requires {} rows and {} columns (i.e. measurements × states)", num_measurements, num_states ); debug_assert_eq!( - H.cols, num_states, + H.cols(), num_states, "The measurement transformation matrix H requires {} rows and {} columns (i.e. measurements × states)", num_measurements, num_states ); debug_assert_eq!( - z.rows, num_measurements, + z.rows(), + num_measurements, "The measurement vector z requires {} rows and 1 column (i.e. measurements × 1)", num_measurements ); debug_assert_eq!( - z.cols, 1, + z.cols(), + 1, "The measurement vector z requires {} rows and 1 column (i.e. measurements × 1)", num_measurements ); debug_assert_eq!( - R.rows, num_measurements, + R.rows(), num_measurements, "The process noise / measurement uncertainty matrix R requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - R.cols, num_measurements, + R.cols(), num_measurements, "The process noise / measurement uncertainty matrix R requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - y.rows, num_measurements, + y.rows(), + num_measurements, "The innovation vector y requires {} rows and 1 column (i.e. measurements × 1)", num_measurements ); debug_assert_eq!( - y.cols, 1, + y.cols(), + 1, "The innovation vector y requires {} rows and 1 column (i.e. measurements × 1)", num_measurements ); debug_assert_eq!( - S.rows, num_measurements, + S.rows(), num_measurements, "The residual covariance matrix S requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - S.cols, num_measurements, + S.cols(), num_measurements, "The residual covariance S requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - K.rows, num_states, + K.rows(), + num_states, "The Kalman gain matrix S requires {} rows and {} columns (i.e. states × measurements)", - num_states, num_measurements + num_states, + num_measurements ); debug_assert_eq!( - K.cols, num_measurements, + K.cols(), + num_measurements, "The Kalman gain matrix K requires {} rows and {} columns (i.e. states × measurements)", - num_states, num_measurements + num_states, + num_measurements ); debug_assert_eq!( - S_inv.rows, num_measurements, + S_inv.rows(), num_measurements, "The temporary S-inverted matrix requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - S_inv.cols, num_measurements, + S_inv.cols(), num_measurements, "The temporary S-inverted matrix requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_measurements ); debug_assert_eq!( - temp_HP.rows, num_measurements, + temp_HP.rows(), num_measurements, "The temporary H×P calculation matrix requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_states ); debug_assert_eq!( - temp_HP.cols, num_states, + temp_HP.cols(), num_states, "The temporary H×P calculation matrix requires {} rows and {} columns (i.e. measurements × measurements)", num_measurements, num_states ); debug_assert_eq!( - temp_PHt.rows, num_states, + temp_PHt.rows(), num_states, "The temporary P×H' calculation matrix requires {} rows and {} columns (i.e. states × measurements)", num_states, num_measurements ); debug_assert_eq!( - temp_PHt.cols, num_measurements, + temp_PHt.cols(), num_measurements, "The temporary P×H' calculation matrix requires {} rows and {} columns (i.e. states × measurements)", num_states, num_measurements ); debug_assert_eq!( - temp_KHP.rows, num_states, + temp_KHP.rows(), num_states, "The temporary K×H×P calculation matrix requires {} rows and {} columns (i.e. states × states)", num_states, num_states ); debug_assert_eq!( - temp_KHP.cols, num_states, + temp_KHP.cols(), num_states, "The temporary K×H×P calculation matrix requires {} rows and {} columns (i.e. states × states)", num_states, num_states ); Self { - num_states, - num_measurements, H, R, z, @@ -270,6 +266,16 @@ impl<'a, const STATES: usize, const MEASUREMENTS: usize> Measurement<'a, STATES, } } + /// Returns then number of measurements. + pub const fn measurements() -> uint_fast8_t { + MEASUREMENTS as _ + } + + /// Returns then number of states. + pub const fn states() -> uint_fast8_t { + STATES as _ + } + /// Gets a reference to the measurement vector z. #[inline(always)] pub fn measurement_vector_ref(&self) -> &Matrix<'_, MEASUREMENTS, 1> {