Skip to content

Commit

Permalink
fix loading of multi-channel sound
Browse files Browse the repository at this point in the history
  • Loading branch information
mscheltienne committed Oct 2, 2024
1 parent 4753b98 commit 14bee24
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 5 deletions.
3 changes: 1 addition & 2 deletions stimuli/audio/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,9 @@ def _set_times(self) -> None:
@abstractmethod
def _set_signal(self, signal: NDArray) -> None:
"""Set the signal array."""
signal = np.vstack([signal] * self._n_channels).T
if self._window is not None:
assert self._window.size == signal.shape[0] # sanity-check
signal = np.multiply(signal, self._window[:, np.newaxis])
signal = signal * self._window[:, np.newaxis]
assert self._volume.ndim == 1 # sanity-check
assert self._volume.size == self._n_channels # sanity-check
signal = np.ascontiguousarray(signal * self._volume / 100, dtype=np.float32)
Expand Down
1 change: 1 addition & 0 deletions stimuli/audio/am.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _set_signal(self) -> None:
2 * np.pi * self._frequency_carrier * self._times
)
signal /= np.max(np.abs(signal)) # normalize
signal = np.vstack([signal] * self._n_channels).T
super()._set_signal(signal)

@property
Expand Down
1 change: 1 addition & 0 deletions stimuli/audio/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _set_signal(self) -> None:
S /= np.sqrt(np.mean(S**2))
signal = np.fft.irfft(dft * S)
signal /= np.max(np.abs(signal)) # normalize
signal = np.vstack([signal] * self._n_channels).T
super()._set_signal(signal)
# make sure we have the correct times as the rFFT and irFFT could get us off
if self._times.size != self._signal.size:
Expand Down
35 changes: 34 additions & 1 deletion stimuli/audio/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import numpy as np
import pytest
import sounddevice as sd
from numpy.testing import assert_allclose
from scipy.io import wavfile

from stimuli.audio import Noise, SoundAM, Tone
from stimuli.audio import Noise, Sound, SoundAM, Tone
from stimuli.audio._base import _check_duration, _ensure_volume, _ensure_window


Expand Down Expand Up @@ -127,3 +129,34 @@ def test_ensure_window():

window = _ensure_window(None, 101)
assert window is None


@pytest.mark.parametrize(
("sound"),
[
(Tone, dict(frequency=440)),
(Noise, dict(color="white")),
(
SoundAM,
dict(frequency_carrier=1000, frequency_modulation=40, method="dsbsc"),
),
],
)
def test_invalid_n_channels(sound):
"""Test invalid number of channels."""
with pytest.raises(ValueError, match="The number of channels must be"):
sound = sound[0](volume=10, duration=1, n_channels=0, **sound[1])


def test_window(tmp_path):
"""Test applying a window the the sound."""
sfreq = int(sd.query_devices()[sd.default.device["output"]]["default_samplerate"])
data = np.ones((sfreq, 2)) # stereo
fname = tmp_path / "test.wav"
wavfile.write(fname, sfreq, data)
sound = Sound(fname)
window = np.zeros(sound.times.size, dtype=np.float32)
window[::10] = 1
sound.window = window
assert_allclose(sound.signal[::10, :], 1)
assert_allclose(sound.signal, data * window[:, np.newaxis])
4 changes: 2 additions & 2 deletions stimuli/audio/tests/test_sound.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ def test_sound_io_mono(tmp_path):
sound = Noise("pink", volume=100, duration=0.5)
sound.save(tmp_path / "test.wav")
sound_loaded = Sound(tmp_path / "test.wav")
assert_allclose(sound.signal, sound_loaded.signal)
assert_allclose(sound.signal.squeeze(), sound_loaded.signal)

sound = Noise("pink", volume=50, duration=0.5)
sound.save(tmp_path / "test.wav", overwrite=True)
sound_loaded = Sound(tmp_path / "test.wav")
sound_loaded.volume = 50
assert_allclose(sound.signal, sound_loaded.signal)
assert_allclose(sound.signal.squeeze(), sound_loaded.signal)
# test representation
assert str(tmp_path / "test.wav") in repr(sound_loaded)
# test duration setter
Expand Down
1 change: 1 addition & 0 deletions stimuli/audio/tone.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __repr__(self) -> str:
def _set_signal(self) -> None:
signal = np.sin(2 * np.pi * self._frequency * self._times, dtype=np.float32)
signal /= np.max(np.abs(signal)) # normalize
signal = np.vstack([signal] * self._n_channels).T
super()._set_signal(signal)

@property
Expand Down

0 comments on commit 14bee24

Please sign in to comment.