From 14bee2493c1b096a1197dbfb72897b2b0bcc544a Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Wed, 2 Oct 2024 11:50:33 +0200 Subject: [PATCH] fix loading of multi-channel sound --- stimuli/audio/_base.py | 3 +-- stimuli/audio/am.py | 1 + stimuli/audio/noise.py | 1 + stimuli/audio/tests/test_base.py | 35 ++++++++++++++++++++++++++++++- stimuli/audio/tests/test_sound.py | 4 ++-- stimuli/audio/tone.py | 1 + 6 files changed, 40 insertions(+), 5 deletions(-) diff --git a/stimuli/audio/_base.py b/stimuli/audio/_base.py index 299c36a..c89b8c2 100644 --- a/stimuli/audio/_base.py +++ b/stimuli/audio/_base.py @@ -86,10 +86,9 @@ def _set_times(self) -> None: @abstractmethod def _set_signal(self, signal: NDArray) -> None: """Set the signal array.""" - signal = np.vstack([signal] * self._n_channels).T if self._window is not None: assert self._window.size == signal.shape[0] # sanity-check - signal = np.multiply(signal, self._window[:, np.newaxis]) + signal = signal * self._window[:, np.newaxis] assert self._volume.ndim == 1 # sanity-check assert self._volume.size == self._n_channels # sanity-check signal = np.ascontiguousarray(signal * self._volume / 100, dtype=np.float32) diff --git a/stimuli/audio/am.py b/stimuli/audio/am.py index c17f63d..b2e4cca 100644 --- a/stimuli/audio/am.py +++ b/stimuli/audio/am.py @@ -105,6 +105,7 @@ def _set_signal(self) -> None: 2 * np.pi * self._frequency_carrier * self._times ) signal /= np.max(np.abs(signal)) # normalize + signal = np.vstack([signal] * self._n_channels).T super()._set_signal(signal) @property diff --git a/stimuli/audio/noise.py b/stimuli/audio/noise.py index 2edfd75..a6c083f 100644 --- a/stimuli/audio/noise.py +++ b/stimuli/audio/noise.py @@ -85,6 +85,7 @@ def _set_signal(self) -> None: S /= np.sqrt(np.mean(S**2)) signal = np.fft.irfft(dft * S) signal /= np.max(np.abs(signal)) # normalize + signal = np.vstack([signal] * self._n_channels).T super()._set_signal(signal) # make sure we have the correct times as the rFFT and irFFT could get us off if self._times.size != self._signal.size: diff --git a/stimuli/audio/tests/test_base.py b/stimuli/audio/tests/test_base.py index 23f21c2..c633449 100644 --- a/stimuli/audio/tests/test_base.py +++ b/stimuli/audio/tests/test_base.py @@ -2,9 +2,11 @@ import numpy as np import pytest +import sounddevice as sd from numpy.testing import assert_allclose +from scipy.io import wavfile -from stimuli.audio import Noise, SoundAM, Tone +from stimuli.audio import Noise, Sound, SoundAM, Tone from stimuli.audio._base import _check_duration, _ensure_volume, _ensure_window @@ -127,3 +129,34 @@ def test_ensure_window(): window = _ensure_window(None, 101) assert window is None + + +@pytest.mark.parametrize( + ("sound"), + [ + (Tone, dict(frequency=440)), + (Noise, dict(color="white")), + ( + SoundAM, + dict(frequency_carrier=1000, frequency_modulation=40, method="dsbsc"), + ), + ], +) +def test_invalid_n_channels(sound): + """Test invalid number of channels.""" + with pytest.raises(ValueError, match="The number of channels must be"): + sound = sound[0](volume=10, duration=1, n_channels=0, **sound[1]) + + +def test_window(tmp_path): + """Test applying a window the the sound.""" + sfreq = int(sd.query_devices()[sd.default.device["output"]]["default_samplerate"]) + data = np.ones((sfreq, 2)) # stereo + fname = tmp_path / "test.wav" + wavfile.write(fname, sfreq, data) + sound = Sound(fname) + window = np.zeros(sound.times.size, dtype=np.float32) + window[::10] = 1 + sound.window = window + assert_allclose(sound.signal[::10, :], 1) + assert_allclose(sound.signal, data * window[:, np.newaxis]) diff --git a/stimuli/audio/tests/test_sound.py b/stimuli/audio/tests/test_sound.py index 901cf49..9d9b6e1 100644 --- a/stimuli/audio/tests/test_sound.py +++ b/stimuli/audio/tests/test_sound.py @@ -9,13 +9,13 @@ def test_sound_io_mono(tmp_path): sound = Noise("pink", volume=100, duration=0.5) sound.save(tmp_path / "test.wav") sound_loaded = Sound(tmp_path / "test.wav") - assert_allclose(sound.signal, sound_loaded.signal) + assert_allclose(sound.signal.squeeze(), sound_loaded.signal) sound = Noise("pink", volume=50, duration=0.5) sound.save(tmp_path / "test.wav", overwrite=True) sound_loaded = Sound(tmp_path / "test.wav") sound_loaded.volume = 50 - assert_allclose(sound.signal, sound_loaded.signal) + assert_allclose(sound.signal.squeeze(), sound_loaded.signal) # test representation assert str(tmp_path / "test.wav") in repr(sound_loaded) # test duration setter diff --git a/stimuli/audio/tone.py b/stimuli/audio/tone.py index 0355482..dd5f99f 100644 --- a/stimuli/audio/tone.py +++ b/stimuli/audio/tone.py @@ -68,6 +68,7 @@ def __repr__(self) -> str: def _set_signal(self) -> None: signal = np.sin(2 * np.pi * self._frequency * self._times, dtype=np.float32) signal /= np.max(np.abs(signal)) # normalize + signal = np.vstack([signal] * self._n_channels).T super()._set_signal(signal) @property