Skip to content

Commit

Permalink
Merge pull request #4862 from jeddenlea/yes
Browse files Browse the repository at this point in the history
yes: support non-UTF-8 args
  • Loading branch information
sylvestre authored May 16, 2023
2 parents 21b9cf8 + 3870ee2 commit f26bf98
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/uu/yes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ path = "src/yes.rs"

[dependencies]
clap = { workspace=true }
itertools = { workspace=true }

[target.'cfg(unix)'.dependencies]
uucore = { workspace=true, features=["pipes", "signals"] }
Expand Down
161 changes: 136 additions & 25 deletions src/uu/yes/src/yes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

/* last synced with: yes (GNU coreutils) 8.13 */

use clap::{Arg, ArgAction, Command};
use std::borrow::Cow;
// cSpell:ignore strs

use clap::{builder::ValueParser, Arg, ArgAction, Command};
use std::error::Error;
use std::ffi::OsString;
use std::io::{self, Write};
use uucore::error::{UResult, USimpleError};
#[cfg(unix)]
Expand All @@ -28,19 +31,11 @@ const BUF_SIZE: usize = 16 * 1024;
pub fn uumain(args: impl uucore::Args) -> UResult<()> {
let matches = uu_app().try_get_matches_from(args)?;

let string = if let Some(values) = matches.get_many::<String>("STRING") {
let mut result = values.fold(String::new(), |res, s| res + s + " ");
result.pop();
result.push('\n');
Cow::from(result)
} else {
Cow::from("y\n")
};
let mut buffer = Vec::with_capacity(BUF_SIZE);
args_into_buffer(&mut buffer, matches.get_many::<OsString>("STRING")).unwrap();
prepare_buffer(&mut buffer);

let mut buffer = [0; BUF_SIZE];
let bytes = prepare_buffer(&string, &mut buffer);

match exec(bytes) {
match exec(&buffer) {
Ok(()) => Ok(()),
Err(err) if err.kind() == io::ErrorKind::BrokenPipe => Ok(()),
Err(err) => Err(USimpleError::new(1, format!("standard output: {err}"))),
Expand All @@ -51,21 +46,73 @@ pub fn uu_app() -> Command {
Command::new(uucore::util_name())
.about(ABOUT)
.override_usage(format_usage(USAGE))
.arg(Arg::new("STRING").action(ArgAction::Append))
.arg(
Arg::new("STRING")
.value_parser(ValueParser::os_string())
.action(ArgAction::Append),
)
.infer_long_args(true)
}

fn prepare_buffer<'a>(input: &'a str, buffer: &'a mut [u8; BUF_SIZE]) -> &'a [u8] {
if input.len() < BUF_SIZE / 2 {
let mut size = 0;
while size < BUF_SIZE - input.len() {
let (_, right) = buffer.split_at_mut(size);
right[..input.len()].copy_from_slice(input.as_bytes());
size += input.len();
}
&buffer[..size]
// Copies words from `i` into `buf`, separated by spaces.
fn args_into_buffer<'a>(
buf: &mut Vec<u8>,
i: Option<impl Iterator<Item = &'a OsString>>,
) -> Result<(), Box<dyn Error>> {
// TODO: this should be replaced with let/else once available in the MSRV.
let i = if let Some(i) = i {
i
} else {
input.as_bytes()
buf.extend_from_slice(b"y\n");
return Ok(());
};

// On Unix (and wasi), OsStrs are just &[u8]'s underneath...
#[cfg(any(unix, target_os = "wasi"))]
{
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_os = "wasi")]
use std::os::wasi::ffi::OsStrExt;

for part in itertools::intersperse(i.map(|a| a.as_bytes()), b" ") {
buf.extend_from_slice(part);
}
}

// But, on Windows, we must hop through a String.
#[cfg(not(any(unix, target_os = "wasi")))]
{
for part in itertools::intersperse(i.map(|a| a.to_str()), Some(" ")) {
let bytes = match part {
Some(part) => part.as_bytes(),
None => return Err("arguments contain invalid UTF-8".into()),
};
buf.extend_from_slice(bytes);
}
}

buf.push(b'\n');

Ok(())
}

// Assumes buf holds a single output line forged from the command line arguments, copies it
// repeatedly until the buffer holds as many copies as it can under BUF_SIZE.
fn prepare_buffer(buf: &mut Vec<u8>) {
if buf.len() * 2 > BUF_SIZE {
return;
}

assert!(!buf.is_empty());

let line_len = buf.len();
let target_size = line_len * (BUF_SIZE / line_len);

while buf.len() < target_size {
let to_copy = std::cmp::min(target_size - buf.len(), buf.len());
debug_assert_eq!(to_copy % line_len, 0);
buf.extend_from_within(..to_copy);
}
}

Expand All @@ -88,3 +135,67 @@ pub fn exec(bytes: &[u8]) -> io::Result<()> {
stdout.write_all(bytes)?;
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_prepare_buffer() {
let tests = [
(150, 16350),
(1000, 16000),
(4093, 16372),
(4099, 12297),
(4111, 12333),
(2, 16384),
(3, 16383),
(4, 16384),
(5, 16380),
(8192, 16384),
(8191, 16382),
(8193, 8193),
(10000, 10000),
(15000, 15000),
(25000, 25000),
];

for (line, final_len) in tests {
let mut v = std::iter::repeat(b'a').take(line).collect::<Vec<_>>();
prepare_buffer(&mut v);
assert_eq!(v.len(), final_len);
}
}

#[test]
fn test_args_into_buf() {
{
let mut v = Vec::with_capacity(BUF_SIZE);
args_into_buffer(&mut v, None::<std::slice::Iter<OsString>>).unwrap();
assert_eq!(String::from_utf8(v).unwrap(), "y\n");
}

{
let mut v = Vec::with_capacity(BUF_SIZE);
args_into_buffer(&mut v, Some([OsString::from("foo")].iter())).unwrap();
assert_eq!(String::from_utf8(v).unwrap(), "foo\n");
}

{
let mut v = Vec::with_capacity(BUF_SIZE);
args_into_buffer(
&mut v,
Some(
[
OsString::from("foo"),
OsString::from("bar baz"),
OsString::from("qux"),
]
.iter(),
),
)
.unwrap();
assert_eq!(String::from_utf8(v).unwrap(), "foo bar baz qux\n");
}
}
}
26 changes: 23 additions & 3 deletions tests/by-util/test_yes.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ffi::OsStr;
use std::process::{ExitStatus, Stdio};

#[cfg(unix)]
Expand All @@ -15,8 +16,10 @@ fn check_termination(result: ExitStatus) {
assert!(result.success(), "yes did not exit successfully");
}

const NO_ARGS: &[&str] = &[];

/// Run `yes`, capture some of the output, close the pipe, and verify it.
fn run(args: &[&str], expected: &[u8]) {
fn run(args: &[impl AsRef<OsStr>], expected: &[u8]) {
let mut cmd = new_ucmd!();
let mut child = cmd.args(args).set_stdout(Stdio::piped()).run_no_wait();
let buf = child.stdout_exact_bytes(expected.len());
Expand All @@ -34,7 +37,7 @@ fn test_invalid_arg() {

#[test]
fn test_simple() {
run(&[], b"y\ny\ny\ny\n");
run(NO_ARGS, b"y\ny\ny\ny\n");
}

#[test]
Expand All @@ -44,7 +47,7 @@ fn test_args() {

#[test]
fn test_long_output() {
run(&[], "y\n".repeat(512 * 1024).as_bytes());
run(NO_ARGS, "y\n".repeat(512 * 1024).as_bytes());
}

/// Test with an output that seems likely to get mangled in case of incomplete writes.
Expand Down Expand Up @@ -88,3 +91,20 @@ fn test_piped_to_dev_full() {
}
}
}

#[test]
#[cfg(any(unix, target_os = "wasi"))]
fn test_non_utf8() {
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_os = "wasi")]
use std::os::wasi::ffi::OsStrExt;

run(
&[
OsStr::from_bytes(b"\xbf\xff\xee"),
OsStr::from_bytes(b"bar"),
],
&b"\xbf\xff\xee bar\n".repeat(5000),
);
}

0 comments on commit f26bf98

Please sign in to comment.