Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions src/uu/tr/src/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,14 @@ impl ChunkProcessor for DeleteOperation {
/// Byte-for-byte translation driven by a 256-entry lookup table.
#[derive(Debug)]
pub struct TranslateOperation {
    /// Output byte for every possible input byte; entry `i` is what byte `i` is translated to.
    pub(crate) translation_table: [u8; 256],
    /// Constant-offset description of the table when `detect_ascii_range_translate` recognizes
    /// one at construction time; `None` otherwise (including the identity table built for two
    /// empty sets). Chunk processing uses it to bypass the per-byte table lookup.
    pub(crate) ascii_range: Option<AsciiRangeTranslate>,
}

/// A translation that adds one constant to every byte of a single contiguous byte range and
/// leaves all other bytes untouched.
#[derive(Debug, Clone, Copy)]
pub(crate) struct AsciiRangeTranslate {
    /// First byte of the translated range (inclusive).
    pub(crate) start: u8,
    /// Last byte of the translated range (inclusive).
    pub(crate) end: u8,
    /// Amount added, with wrap-around, to each byte inside `start..=end`.
    pub(crate) delta: u8,
}

impl TranslateOperation {
Expand All @@ -686,16 +694,60 @@ impl TranslateOperation {
translation_table[from as usize] = to;
}

Ok(Self { translation_table })
Ok(Self {
ascii_range: detect_ascii_range_translate(&translation_table),
translation_table,
})
} else if set1.is_empty() && set2.is_empty() {
// Identity mapping for empty sets
Ok(Self { translation_table })
Ok(Self {
ascii_range: None,
translation_table,
})
} else {
Err(BadSequence::EmptySet2WhenNotTruncatingSet1)
}
}
}

/// Recognizes a translation table that shifts one contiguous run of ASCII bytes by a constant.
///
/// Returns `Some` only when all of the following hold:
/// * every remapped byte (`table[b] != b`) lies in one unbroken run `start..=end`,
/// * the run contains at least two bytes (a lone remapped byte is rejected),
/// * the run ends at or below `0x7f`,
/// * every byte in the run maps to `byte.wrapping_add(delta)` for one shared `delta`.
///
/// Any other table, including the identity table, yields `None`.
fn detect_ascii_range_translate(table: &[u8; 256]) -> Option<AsciiRangeTranslate> {
    // Table indices never exceed 255, so narrowing them to `u8` is lossless.
    let offset_at = |index: usize| table[index].wrapping_sub(index as u8);
    let is_remapped = |index: usize| offset_at(index) != 0;

    let first = (0..table.len()).find(|&index| is_remapped(index))?;
    let last = (0..table.len()).rfind(|&index| is_remapped(index))?;

    // Need at least two remapped bytes, none of them above the ASCII range.
    if first == last || last > 0x7f {
        return None;
    }

    // A shared non-zero offset across `first..=last` also proves the run has no
    // unchanged byte inside it, i.e. that the remapped bytes are contiguous.
    let delta = offset_at(first);
    let uniform = (first..=last).all(|index| offset_at(index) == delta);

    uniform.then_some(AsciiRangeTranslate {
        start: first as u8,
        end: last as u8,
        delta,
    })
}

impl SymbolTranslator for TranslateOperation {
fn translate(&mut self, current: u8) -> Option<u8> {
Some(self.translation_table[current as usize])
Expand All @@ -704,14 +756,16 @@ impl SymbolTranslator for TranslateOperation {

impl ChunkProcessor for TranslateOperation {
fn process_chunk(&self, input: &[u8], output: &mut Vec<u8>) {
use crate::simd::{find_single_change, process_single_char_replace};
use crate::simd::{find_single_change, process_single_char_replace, translate_ascii_range};

// Check if this is a simple single-character translation
if let Some((source, target)) =
find_single_change(&self.translation_table, |i, &val| val != i as u8)
{
// Use SIMD-optimized single character replacement
process_single_char_replace(input, output, source, target);
} else if let Some(range) = self.ascii_range {
translate_ascii_range(input, output, range);
} else {
// Standard translation using table lookup
output.extend(input.iter().map(|&b| self.translation_table[b as usize]));
Expand Down
87 changes: 86 additions & 1 deletion src/uu/tr/src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
// For the full copyright and license information, please view the LICENSE
// file that was distributed with this source code.

// spell-checker:ignore (intrinsics) blendv cmpgt loadu storeu

//! I/O processing infrastructure for tr operations with SIMD optimizations

use crate::operation::ChunkProcessor;
use crate::operation::{AsciiRangeTranslate, ChunkProcessor};
use std::io::{BufRead, Write};
use uucore::error::{FromIo, UResult};
use uucore::translate;
Expand Down Expand Up @@ -59,6 +61,89 @@ pub fn process_single_delete(input: &[u8], output: &mut Vec<u8>, delete_char: u8
// If count == input.len(), all deleted, output nothing
}

/// Translate a contiguous ASCII byte range by a constant wrapping delta.
///
/// Appends the translated form of `input` to `output`. On x86/x86_64 the AVX2 implementation
/// is chosen when the chunk holds at least one full 32-byte vector and the CPU reports AVX2
/// support at runtime; every other case goes through the scalar implementation.
#[inline]
pub(crate) fn translate_ascii_range(
    input: &[u8],
    output: &mut Vec<u8>,
    range: AsciiRangeTranslate,
) {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        let fills_a_vector = input.len() >= 32;
        if fills_a_vector && std::is_x86_feature_detected!("avx2") {
            // SAFETY: AVX2 availability was confirmed at runtime on the line above.
            unsafe { translate_ascii_range_avx2(input, output, range) };
            return;
        }
    }

    translate_ascii_range_scalar(input, output, range);
}

#[inline]
fn translate_ascii_range_scalar(input: &[u8], output: &mut Vec<u8>, range: AsciiRangeTranslate) {
output.extend(input.iter().map(|&byte| {
if (range.start..=range.end).contains(&byte) {
byte.wrapping_add(range.delta)
} else {
byte
}
}));
}

/// AVX2 implementation of the range translation: appends the translated `input` to `output`,
/// handling 32 bytes per vector step and finishing the final partial block byte by byte.
///
/// The range test uses signed byte comparisons, so it is only correct when `range.start` and
/// `range.end` are both ASCII (`<= 0x7f`); bytes `>= 0x80` then compare as negative and are
/// never treated as being inside the range.
///
/// # Safety
///
/// Callers must only call this function when AVX2 is available on the current CPU.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn translate_ascii_range_avx2(
    input: &[u8],
    output: &mut Vec<u8>,
    range: AsciiRangeTranslate,
) {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::{
        _mm256_add_epi8, _mm256_and_si256, _mm256_blendv_epi8, _mm256_cmpgt_epi8,
        _mm256_loadu_si256, _mm256_set1_epi8, _mm256_storeu_si256, _mm256_xor_si256,
    };
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::{
        _mm256_add_epi8, _mm256_and_si256, _mm256_blendv_epi8, _mm256_cmpgt_epi8,
        _mm256_loadu_si256, _mm256_set1_epi8, _mm256_storeu_si256, _mm256_xor_si256,
    };

    /// Number of bytes held by one 256-bit vector.
    const LANES: usize = 32;

    // Grow the output up front and write the results into the freshly added tail.
    let already_written = output.len();
    output.resize(already_written + input.len(), 0);
    let destination = &mut output[already_written..];

    // AVX2 only offers a signed "greater than" byte compare, so `byte >= start` is
    // expressed as `byte > start - 1` and `byte <= end` as `!(byte > end)`.
    let below_start = _mm256_set1_epi8(range.start.wrapping_sub(1) as i8);
    let range_end = _mm256_set1_epi8(range.end as i8);
    let shift = _mm256_set1_epi8(range.delta as i8);
    let all_ones = _mm256_set1_epi8(-1);

    let mut blocks_in = input.chunks_exact(LANES);
    let mut blocks_out = destination.chunks_exact_mut(LANES);

    for (source, target) in blocks_in.by_ref().zip(blocks_out.by_ref()) {
        // SAFETY: `source` is exactly `LANES` (32) readable bytes and the load is unaligned.
        let bytes = unsafe { _mm256_loadu_si256(source.as_ptr().cast()) };

        let at_or_above_start = _mm256_cmpgt_epi8(bytes, below_start);
        let above_end = _mm256_cmpgt_epi8(bytes, range_end);
        let at_or_below_end = _mm256_xor_si256(above_end, all_ones);
        let inside = _mm256_and_si256(at_or_above_start, at_or_below_end);

        // Shift every lane, then keep the shifted value only where the mask is set.
        let shifted = _mm256_add_epi8(bytes, shift);
        let selected = _mm256_blendv_epi8(bytes, shifted, inside);

        // SAFETY: `target` is exactly `LANES` (32) writable bytes and the store is unaligned.
        unsafe { _mm256_storeu_si256(target.as_mut_ptr().cast(), selected) };
    }

    // Fewer than 32 bytes remain; both remainders have the same length.
    let leftover_in = blocks_in.remainder();
    let leftover_out = blocks_out.into_remainder();
    for (slot, &byte) in leftover_out.iter_mut().zip(leftover_in) {
        *slot = if (range.start..=range.end).contains(&byte) {
            byte.wrapping_add(range.delta)
        } else {
            byte
        };
    }
}

/// Unified I/O processing for all operations
pub fn process_input<R, W, P>(input: &mut R, output: &mut W, processor: &P) -> UResult<()>
where
Expand Down
63 changes: 63 additions & 0 deletions tests/by-util/test_tr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,69 @@ fn test_to_upper() {
.stdout_is("!ABCD!");
}

/// Checks `tr a-z A-Z` on inputs whose lengths sit around the 32-byte vector width used by the
/// AVX2 range translation, plus inputs containing non-letters and non-ASCII bytes. On Unix,
/// the same inputs are also fed to a GNU `tr` binary when one is found on the system.
#[test]
fn test_ascii_range_translate_alignment_boundaries() {
    // `len` lowercase letters cycling through the alphabet.
    let alphabet_cycle =
        |len: usize| -> Vec<u8> { (0..len).map(|i| b'a' + (i % 26) as u8).collect() };

    let cases: Vec<Vec<u8>> = vec![
        Vec::new(),
        b"a".to_vec(),
        alphabet_cycle(31),
        alphabet_cycle(32),
        alphabet_cycle(33),
        b"zZ!\xc3\xa9a".to_vec(),
        (0u8..=255).collect(),
    ];

    // Pick the first candidate that runs and identifies itself as GNU coreutils.
    #[cfg(unix)]
    let gnu_tr = ["tr", "gtr"].into_iter().find(|candidate| {
        match std::process::Command::new(candidate)
            .arg("--version")
            .output()
        {
            Ok(probe) => {
                probe.status.success()
                    && String::from_utf8_lossy(&probe.stdout).contains("(GNU coreutils)")
            }
            Err(_) => false,
        }
    });

    for input in cases {
        // Only ASCII lowercase letters change; every other byte passes through untouched.
        let expected: Vec<u8> = input.to_ascii_uppercase();

        new_ucmd!()
            .args(&["a-z", "A-Z"])
            .pipe_in(input.clone())
            .succeeds()
            .stdout_is_bytes(&expected)
            .no_stderr();

        #[cfg(unix)]
        if let Some(gnu_tr) = gnu_tr {
            use std::io::Write as _;
            use std::process::{Command, Stdio};

            let mut child = Command::new(gnu_tr)
                .args(["a-z", "A-Z"])
                .stdin(Stdio::piped())
                .stdout(Stdio::piped())
                .stderr(Stdio::piped())
                .spawn()
                .unwrap();
            child.stdin.as_mut().unwrap().write_all(&input).unwrap();
            let reference = child.wait_with_output().unwrap();

            assert!(reference.status.success());
            assert_eq!(reference.stdout, expected);
            assert!(reference.stderr.is_empty());
        }
    }
}

#[test]
fn test_small_set2() {
new_ucmd!()
Expand Down
Loading