diff --git a/src/input.rs b/src/input.rs index e232c13..5afb09d 100644 --- a/src/input.rs +++ b/src/input.rs @@ -6,6 +6,7 @@ use midir::ConnectErrorKind; // XXX: do we expose this? pub use midir::{Ignore, MidiInputPort}; use std::error::Error; use std::fmt::Display; +use std::future::Future; use MidiInputError::{ConnectionError, PortRefreshError}; pub struct MidiInputPlugin; @@ -171,7 +172,13 @@ fn setup(mut commands: Commands, settings: Res) { let thread_pool = IoTaskPool::get(); thread_pool - .spawn(midi_input(m_receiver, r_sender, settings.clone())) + .spawn(MidiInputTask { + receiver: m_receiver, + sender: r_sender, + settings: settings.clone(), + input: None, + connection: None, + }) .detach(); commands.insert_resource(MidiInput { @@ -195,74 +202,44 @@ enum Reply { Midi(MidiData), } -async fn midi_input( +struct MidiInputTask { receiver: Receiver, sender: Sender, settings: MidiInputSettings, -) -> Result<(), crossbeam_channel::SendError> { - use Message::{ConnectToPort, DisconnectFromPort, RefreshPorts}; - - let input = midir::MidiInput::new(settings.client_name).unwrap(); - sender.send(get_available_ports(&input))?; // Invariant: exactly one of `input` or `connection` is Some - let mut input: Option = Some(input); - let mut connection: Option<(midir::MidiInputConnection<()>, MidiInputPort)> = None; + input: Option, + connection: Option<(midir::MidiInputConnection<()>, MidiInputPort)>, +} - while let Ok(msg) = receiver.recv() { - match msg { - ConnectToPort(port) => { - let was_connected = input.is_none(); - let s = sender.clone(); - let i = input.unwrap_or_else(|| connection.unwrap().0.close().0); - let conn = i.connect( - &port, - settings.port_name, - move |stamp, message, _| { - let _ = s.send(Reply::Midi(MidiData { - stamp, - message: [message[0], message[1], message[2]].into(), - })); - }, - (), - ); - match conn { - Ok(conn) => { - sender.send(Reply::Connected)?; - connection = Some((conn, port)); - input = None; - } - Err(conn_err) => { - sender.send(Reply::Error(ConnectionError(conn_err.kind())))?; - if was_connected { - sender.send(Reply::Disconnected)?; - } - connection = None; - input = Some(conn_err.into_inner()); - } - } - } - DisconnectFromPort => { - if let Some((conn, _)) = connection { - input = Some(conn.close().0); - connection = None; - sender.send(Reply::Disconnected)?; - } - } - RefreshPorts => match &input { - Some(i) => { - sender.send(get_available_ports(i))?; - } - None => { - let (conn, port) = connection.unwrap(); - let i = conn.close().0; +impl Future for MidiInputTask { + type Output = (); - sender.send(get_available_ports(&i))?; + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.input.is_none() && self.connection.is_none() { + self.input = midir::MidiInput::new(self.settings.client_name).ok(); + self.sender + .send(get_available_ports(self.input.as_ref().unwrap())) + .unwrap(); + } + + if let Ok(msg) = self.receiver.recv() { + use Message::{ConnectToPort, DisconnectFromPort, RefreshPorts}; - let s = sender.clone(); + match msg { + ConnectToPort(port) => { + let was_connected = self.input.is_none(); + let s = self.sender.clone(); + let i = self + .input + .take() + .unwrap_or_else(|| self.connection.take().unwrap().0.close().0); let conn = i.connect( &port, - settings.port_name, + self.settings.port_name, move |stamp, message, _| { let _ = s.send(Reply::Midi(MidiData { stamp, @@ -273,21 +250,72 @@ async fn midi_input( ); match conn { Ok(conn) => { - connection = Some((conn, port)); - input = None; + self.sender.send(Reply::Connected).unwrap(); + self.connection = Some((conn, port)); + self.input = None; } Err(conn_err) => { - sender.send(Reply::Error(ConnectionError(conn_err.kind())))?; - sender.send(Reply::Disconnected)?; - connection = None; - input = Some(conn_err.into_inner()); + self.sender + .send(Reply::Error(ConnectionError(conn_err.kind()))) + .unwrap(); + if was_connected { + self.sender.send(Reply::Disconnected).unwrap(); + } + self.connection = None; + self.input = Some(conn_err.into_inner()); } } } - }, + DisconnectFromPort => { + if let Some((conn, _)) = self.connection.take() { + self.input = Some(conn.close().0); + self.connection = None; + self.sender.send(Reply::Disconnected).unwrap(); + } + } + RefreshPorts => match &self.input { + Some(i) => { + self.sender.send(get_available_ports(i)).unwrap(); + } + None => { + let (conn, port) = self.connection.take().unwrap(); + let i = conn.close().0; + + self.sender.send(get_available_ports(&i)).unwrap(); + + let s = self.sender.clone(); + let conn = i.connect( + &port, + self.settings.port_name, + move |stamp, message, _| { + let _ = s.send(Reply::Midi(MidiData { + stamp, + message: [message[0], message[1], message[2]].into(), + })); + }, + (), + ); + match conn { + Ok(conn) => { + self.connection = Some((conn, port)); + self.input = None; + } + Err(conn_err) => { + self.sender + .send(Reply::Error(ConnectionError(conn_err.kind()))) + .unwrap(); + self.sender.send(Reply::Disconnected).unwrap(); + self.connection = None; + self.input = Some(conn_err.into_inner()); + } + } + } + }, + } } + cx.waker().wake_by_ref(); + std::task::Poll::Pending } - Ok(()) } // Helper for above. diff --git a/src/output.rs b/src/output.rs index c164edb..7f0774f 100644 --- a/src/output.rs +++ b/src/output.rs @@ -3,8 +3,8 @@ use bevy::{prelude::*, tasks::IoTaskPool}; use crossbeam_channel::{Receiver, Sender}; use midir::ConnectErrorKind; pub use midir::MidiOutputPort; -use std::error::Error; use std::fmt::Display; +use std::{error::Error, future::Future}; use MidiOutputError::{ConnectionError, PortRefreshError, SendDisconnectedError, SendError}; pub struct MidiOutputPlugin; @@ -136,7 +136,13 @@ fn setup(mut commands: Commands, settings: Res) { let thread_pool = IoTaskPool::get(); thread_pool - .spawn(midi_output(m_receiver, r_sender, settings.port_name)) + .spawn(MidiOutputTask { + receiver: m_receiver, + sender: r_sender, + settings: settings.clone(), + output: None, + connection: None, + }) .detach(); commands.insert_resource(MidiOutput { @@ -184,84 +190,108 @@ enum Reply { Disconnected, } -async fn midi_output( +struct MidiOutputTask { receiver: Receiver, sender: Sender, - name: &str, -) -> Result<(), crossbeam_channel::SendError> { - use Message::{ConnectToPort, DisconnectFromPort, Midi, RefreshPorts}; - - let output = midir::MidiOutput::new(name).unwrap(); - sender.send(get_available_ports(&output))?; + settings: MidiOutputSettings, // Invariant: exactly one of `output` or `connection` is Some - let mut output: Option = Some(output); - let mut connection: Option<(midir::MidiOutputConnection, MidiOutputPort)> = None; + output: Option, + connection: Option<(midir::MidiOutputConnection, MidiOutputPort)>, +} - while let Ok(msg) = receiver.recv() { - match msg { - ConnectToPort(port) => { - let was_connected = output.is_none(); - let out = output.unwrap_or_else(|| connection.unwrap().0.close()); - match out.connect(&port, name) { - Ok(conn) => { - connection = Some((conn, port)); - output = None; - sender.send(Reply::Connected)?; - } - Err(conn_err) => { - sender.send(Reply::Error(ConnectionError(conn_err.kind())))?; - if was_connected { - sender.send(Reply::Disconnected)?; - } - connection = None; - output = Some(conn_err.into_inner()); - } - } - } - DisconnectFromPort => { - if let Some((conn, _)) = connection { - output = Some(conn.close()); - connection = None; - sender.send(Reply::Disconnected)?; - } - } - RefreshPorts => match &output { - Some(out) => { - sender.send(get_available_ports(out))?; - } - None => { - let (conn, port) = connection.unwrap(); - let out = conn.close(); +impl Future for MidiOutputTask { + type Output = (); - sender.send(get_available_ports(&out))?; + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.output.is_none() && self.connection.is_none() { + self.output = midir::MidiOutput::new(self.settings.port_name).ok(); + self.sender + .send(get_available_ports(self.output.as_ref().unwrap())) + .unwrap(); + } - match out.connect(&port, name) { + if let Ok(msg) = self.receiver.recv() { + use Message::{ConnectToPort, DisconnectFromPort, Midi, RefreshPorts}; + + match msg { + ConnectToPort(port) => { + let was_connected = self.output.is_none(); + let out = self + .output + .take() + .unwrap_or_else(|| self.connection.take().unwrap().0.close()); + match out.connect(&port, self.settings.port_name) { Ok(conn) => { - connection = Some((conn, port)); - output = None; + self.connection = Some((conn, port)); + self.output = None; + self.sender.send(Reply::Connected).unwrap(); } Err(conn_err) => { - sender.send(Reply::Error(ConnectionError(conn_err.kind())))?; - sender.send(Reply::Disconnected)?; - connection = None; - output = Some(conn_err.into_inner()); + self.sender + .send(Reply::Error(ConnectionError(conn_err.kind()))) + .unwrap(); + if was_connected { + self.sender.send(Reply::Disconnected).unwrap(); + } + self.connection = None; + self.output = Some(conn_err.into_inner()); } } } - }, - Midi(message) => { - if let Some((conn, _)) = &mut connection { - if let Err(e) = conn.send(&message.msg) { - sender.send(Reply::Error(SendError(e)))?; + DisconnectFromPort => { + if let Some((conn, _)) = self.connection.take() { + self.output = Some(conn.close()); + self.connection = None; + self.sender.send(Reply::Disconnected).unwrap(); + } + } + RefreshPorts => match &self.output { + Some(out) => { + self.sender.send(get_available_ports(out)).unwrap(); + } + None => { + let (conn, port) = self.connection.take().unwrap(); + let out = conn.close(); + + self.sender.send(get_available_ports(&out)).unwrap(); + + match out.connect(&port, self.settings.port_name) { + Ok(conn) => { + self.connection = Some((conn, port)); + self.output = None; + } + Err(conn_err) => { + self.sender + .send(Reply::Error(ConnectionError(conn_err.kind()))) + .unwrap(); + self.sender.send(Reply::Disconnected).unwrap(); + self.connection = None; + self.output = Some(conn_err.into_inner()); + } + } + } + }, + Midi(message) => { + if let Some((conn, _)) = &mut self.connection { + if let Err(e) = conn.send(&message.msg) { + self.sender.send(Reply::Error(SendError(e))).unwrap(); + } + } else { + self.sender + .send(Reply::Error(SendDisconnectedError(message))) + .unwrap(); } - } else { - sender.send(Reply::Error(SendDisconnectedError(message)))?; } } } + + cx.waker().wake_by_ref(); + std::task::Poll::Pending } - Ok(()) } // Helper for above.