Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

std: Fix partial writes in LineWriter #38062

Merged
merged 1 commit into from
Dec 24, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 86 additions & 13 deletions src/libstd/io/buffered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ impl<W> fmt::Display for IntoInnerError<W> {
#[stable(feature = "rust1", since = "1.0.0")]
pub struct LineWriter<W: Write> {
inner: BufWriter<W>,
need_flush: bool,
}

impl<W: Write> LineWriter<W> {
Expand Down Expand Up @@ -692,7 +693,10 @@ impl<W: Write> LineWriter<W> {
/// ```
#[stable(feature = "rust1", since = "1.0.0")]
pub fn with_capacity(cap: usize, inner: W) -> LineWriter<W> {
LineWriter { inner: BufWriter::with_capacity(cap, inner) }
LineWriter {
inner: BufWriter::with_capacity(cap, inner),
need_flush: false,
}
}

/// Gets a reference to the underlying writer.
Expand Down Expand Up @@ -759,28 +763,57 @@ impl<W: Write> LineWriter<W> {
#[stable(feature = "rust1", since = "1.0.0")]
pub fn into_inner(self) -> Result<W, IntoInnerError<LineWriter<W>>> {
self.inner.into_inner().map_err(|IntoInnerError(buf, e)| {
IntoInnerError(LineWriter { inner: buf }, e)
IntoInnerError(LineWriter {
inner: buf,
need_flush: false,
}, e)
})
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl<W: Write> Write for LineWriter<W> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match memchr::memrchr(b'\n', buf) {
Some(i) => {
let n = self.inner.write(&buf[..i + 1])?;
if n != i + 1 || self.inner.flush().is_err() {
// Do not return errors on partial writes.
return Ok(n);
}
self.inner.write(&buf[i + 1..]).map(|i| n + i)
}
None => self.inner.write(buf),
if self.need_flush {
self.flush()?;
}

// Find the last newline character in the buffer provided. If found then
// we're going to write all the data up to that point and then flush,
// otherewise we just write the whole block to the underlying writer.
let i = match memchr::memrchr(b'\n', buf) {
Some(i) => i,
None => return self.inner.write(buf),
};


// Ok, we're going to write a partial amount of the data given first
// followed by flushing the newline. After we've successfully written
// some data then we *must* report that we wrote that data, so future
// errors are ignored. We set our internal `need_flush` flag, though, in
// case flushing fails and we need to try it first next time.
let n = self.inner.write(&buf[..i + 1])?;
self.need_flush = true;
if self.flush().is_err() || n != i + 1 {
return Ok(n)
}

// At this point we successfully wrote `i + 1` bytes and flushed it out,
// meaning that the entire line is now flushed out on the screen. While
// we can attempt to finish writing the rest of the data provided.
// Remember though that we ignore errors here as we've successfully
// written data, so we need to report that.
match self.inner.write(&buf[i + 1..]) {
Ok(i) => Ok(n + i),
Err(_) => Ok(n),
}
}

fn flush(&mut self) -> io::Result<()> { self.inner.flush() }
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()?;
self.need_flush = false;
Ok(())
}
}

#[stable(feature = "rust1", since = "1.0.0")]
Expand Down Expand Up @@ -1153,4 +1186,44 @@ mod tests {
BufWriter::new(io::sink())
});
}

struct AcceptOneThenFail {
written: bool,
flushed: bool,
}

impl Write for AcceptOneThenFail {
fn write(&mut self, data: &[u8]) -> io::Result<usize> {
if !self.written {
assert_eq!(data, b"a\nb\n");
self.written = true;
Ok(data.len())
} else {
Err(io::Error::new(io::ErrorKind::NotFound, "test"))
}
}

fn flush(&mut self) -> io::Result<()> {
assert!(self.written);
assert!(!self.flushed);
self.flushed = true;
Err(io::Error::new(io::ErrorKind::Other, "test"))
}
}

#[test]
fn erroneous_flush_retried() {
let a = AcceptOneThenFail {
written: false,
flushed: false,
};

let mut l = LineWriter::new(a);
assert_eq!(l.write(b"a\nb\na").unwrap(), 4);
assert!(l.get_ref().written);
assert!(l.get_ref().flushed);
l.get_mut().flushed = false;

assert_eq!(l.write(b"a").unwrap_err().kind(), io::ErrorKind::Other)
}
}