added 2 new implementations from users who commented on my reddit post as a comparison

This commit is contained in:
Fabian Schmidt 2024-07-31 13:58:42 +02:00
parent 25d20169aa
commit 0ffbff4cbf
6 changed files with 636 additions and 0 deletions

View File

@ -548,6 +548,7 @@ version = "0.1.0"
dependencies = [
"bstr",
"fast-float",
"libc",
"memchr",
"memmap",
"polars",

View File

@ -13,6 +13,7 @@ memmap = "0.7.0"
polars = { version = "0.36.2", features = ["csv", "lazy", "nightly", "streaming"]}
rayon = "1.10.0"
rustc-hash = "2.0.0"
libc = "0.2.155"
[features]
json = []

View File

@ -0,0 +1,286 @@
#![feature(hash_raw_entry)]
use std::collections::HashMap;
use std::env::args;
use std::fs::File;
use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
use std::ops::{Neg, Range};
use std::os::unix::fs::FileExt;
use std::sync::mpsc::{channel, Sender};
use std::thread;
use std::thread::{available_parallelism, JoinHandle};
use std::time::Instant;
/// Running statistics for one weather station.
///
/// All temperatures are kept as integers in tenths of a degree, so `12.3`
/// is stored as `123`. Conversion back to floating point only happens in the
/// `min`/`mean`/`max` accessors.
#[derive(Copy, Clone, Debug)]
struct City {
    min: i64,
    max: i64,
    sum: i64,
    occurrences: u32,
}

impl Default for City {
    /// Starts with inverted extremes so the first recorded value replaces both.
    fn default() -> Self {
        City {
            min: i64::MAX,
            max: i64::MIN,
            sum: 0,
            occurrences: 0,
        }
    }
}

impl City {
    /// Parses a textual measurement such as `b"-12.3"` and records it.
    ///
    /// The decimal point is simply skipped, which yields the value in tenths.
    ///
    /// # Panics
    /// Panics on any byte other than an ASCII digit, `-` or `.`.
    pub fn add_new(&mut self, input: &[u8]) {
        let mut negative = false;
        let mut magnitude: i64 = 0;
        for &byte in input {
            if byte.is_ascii_digit() {
                magnitude *= 10;
                magnitude += i64::from(byte - b'0');
            } else if byte == b'-' {
                negative = true;
            } else if byte != b'.' {
                panic!("encountered {} in value", char::from(byte))
            }
        }
        let value = if negative { magnitude.neg() } else { magnitude };
        self.add_new_value(value);
    }

    /// Records one measurement that is already expressed in tenths of a degree.
    pub fn add_new_value(&mut self, new: i64) {
        if new < self.min {
            self.min = new;
        }
        if new > self.max {
            self.max = new;
        }
        self.sum += new;
        self.occurrences += 1;
    }

    /// Lowest recorded temperature in degrees.
    pub fn min(&self) -> f64 {
        let tenths = self.min as f64;
        tenths / 10.0
    }

    /// Average recorded temperature in degrees (NaN when nothing was recorded).
    pub fn mean(&self) -> f64 {
        let average_tenths = self.sum as f64 / self.occurrences as f64;
        average_tenths / 10.0
    }

    /// Highest recorded temperature in degrees.
    pub fn max(&self) -> f64 {
        let tenths = self.max as f64;
        tenths / 10.0
    }

    /// Folds the statistics of `other` into `self`.
    pub fn add_result(&mut self, other: Self) {
        let City {
            min,
            max,
            sum,
            occurrences,
        } = other;
        self.min = self.min.min(min);
        self.max = self.max.max(max);
        self.sum += sum;
        self.occurrences += occurrences;
    }
}
/// Collection of per-station statistics.
///
/// Entries are keyed by a `u32` derived from the station name (`Citymap::lookup`
/// chooses the key); the full name is stored next to the statistics so it can
/// be printed later.
#[derive(Default, Clone, Debug)]
struct Citymap {
// Value layout: (station name, running statistics for that station).
pub map: HashMap<u32, (String, City)>,
}
/// Cheap 32-bit fingerprint of a station name.
///
/// Layout (little endian): the name length truncated to one byte, followed by
/// the first three bytes of the name. Names shorter than three bytes are
/// padded with zero bytes instead of panicking on an out-of-bounds index.
///
/// This is NOT collision free: two names with the same length and the same
/// first three bytes produce the same value, so callers must not treat the
/// result as a unique identifier.
fn hashstr(s: &str) -> u32 {
    let bytes = s.as_bytes();
    let byte_at = |index: usize| bytes.get(index).copied().unwrap_or(0);
    // Truncation of the length to `u8` is intentional; it is only a hash input.
    u32::from_le_bytes([s.len() as u8, byte_at(0), byte_at(1), byte_at(2)])
}
impl Citymap {
pub fn lookup(&mut self, lookup: &str) -> &mut City {
let hash = hashstr(lookup);
let builder = self.map.raw_entry_mut();
&mut builder.from_key(&hash).or_insert(hash, (lookup.to_owned(), Default::default())).1.1
}
pub fn new() -> Self {
Self {
map: Default::default(),
}
}
pub fn into_key_values(self) -> Vec<(String, City)> {
self.map.into_iter().map(|(_, s)| s).collect()
}
pub fn merge_with(&mut self, rhs: Self) {
for (k, v) in rhs.map.into_iter() {
self.map
.entry(k)
.and_modify(|lhs| {
lhs.1.add_result(v.1);
})
.or_insert(v);
}
}
}
/// Entry point: aggregates `../../../measurements.txt`, prints the per-station
/// results and then the elapsed wall-clock time.
///
/// Passing `st` anywhere on the command line selects the single-threaded path.
fn main() {
    let mut cli_args = args();
    let start = Instant::now();
    let input = "../../../measurements.txt";
    let single_threaded = cli_args.any(|arg| arg == "st");
    let results = match single_threaded {
        true => citymap_single_thread(input),
        false => citymap_multi_threaded(input),
    };
    print_results(results);
    let elapsed = start.elapsed();
    println!("{:?}", elapsed);
}
/// Aggregates the whole file at `path` on the calling thread.
///
/// The file is read through a 100 MB (`10^8` bytes) buffer.
///
/// # Panics
/// Panics if the file cannot be opened.
fn citymap_single_thread(path: &str) -> Citymap {
    let capacity = 10_usize.pow(8);
    let file = File::open(path).unwrap();
    let mut reader = BufReader::with_capacity(capacity, file);
    citymap_naive(&mut reader)
}
/// Aggregates the file at `path` using one worker thread per available CPU.
///
/// The file is cut into equally sized byte ranges; every worker aligns its
/// range to line boundaries, reports the aligned range back and processes it.
/// The reported ranges are checked to be gap-free before the partial results
/// are merged.
///
/// Compared to a plain `size / cpus` split this makes sure that
/// * the last range always extends to the end of the file, so the
///   `size % cpus` trailing bytes are not left out of every range,
/// * no more workers than bytes are started (tiny or empty files),
/// * a worker that dies before reporting its range makes `recv` fail instead
///   of blocking forever, because main drops its own `Sender`.
///
/// # Panics
/// Panics on I/O errors, if a worker panics, or if the aligned ranges overlap
/// or leave gaps.
fn citymap_multi_threaded(path: &str) -> Citymap {
    let cpus = available_parallelism().unwrap().get();
    let size = File::open(path).unwrap().metadata().unwrap().len();
    // At least one worker, at most one per byte.
    let workers = (cpus as u64).min(size).max(1) as usize;
    let per_thread = size / workers as u64;
    let (sender, receiver) = channel();
    let threads: Vec<_> = (0..workers)
        .map(|i| {
            let start = i as u64 * per_thread;
            // The last worker absorbs the remainder of the integer division.
            let end = if i + 1 == workers {
                size
            } else {
                start + per_thread
            };
            citymap_thread(path.to_owned(), start..end, i, sender.clone())
        })
        .collect();
    // Only the workers hold senders from here on.
    drop(sender);
    let mut ranges = (0..workers)
        .map(|_| receiver.recv().unwrap())
        .collect::<Vec<_>>();
    // Sorting by (start, end) keeps empty ranges in front of the non-empty
    // range that begins at the same offset.
    ranges.sort_unstable_by_key(|range| (range.start, range.end));
    assert!(
        ranges.windows(2).all(|pair| pair[0].end == pair[1].start),
        "Ranges overlap or have gaps: {ranges:?}"
    );
    threads
        .into_iter()
        .map(|handle| handle.join().unwrap())
        .reduce(|mut left, right| {
            left.merge_with(right);
            left
        })
        .unwrap()
}
/// Returns the offset of the first byte after the next `\n` located at or
/// after `offset`, or the end-of-file offset when no newline follows.
fn next_line_start(file: &File, mut offset: u64) -> std::io::Result<u64> {
    let mut window = [0u8; 128];
    loop {
        let len = file.read_at(&mut window, offset)?;
        if len == 0 {
            // Reached the end of the file without seeing a newline.
            return Ok(offset);
        }
        if let Some(pos) = window[..len].iter().position(|&byte| byte == b'\n') {
            return Ok(offset + pos as u64 + 1);
        }
        offset += len as u64;
    }
}

/// Spawns a worker that aggregates the lines of `path` belonging to `range`.
///
/// The worker first moves both ends of `range` forward to the start of the
/// next line (the start is left alone when it is `0`), sends the aligned range
/// through `range_feedback`, and then processes exactly those bytes.
///
/// Because neighbouring workers align the shared boundary with the same
/// function, their ranges stay contiguous. Every range therefore begins at a
/// line start and ends right after a newline (or at end of file), so no line
/// is split or lost. The search continues until a newline is found rather
/// than giving up after a fixed 50 bytes, which is shorter than a legal line.
///
/// `i` is only used for the thread name.
///
/// # Panics
/// Panics if the thread cannot be spawned; the worker itself panics on I/O
/// errors or when the receiving side of `range_feedback` is gone.
fn citymap_thread(
    path: String,
    mut range: Range<u64>,
    i: usize,
    range_feedback: Sender<Range<u64>>,
) -> JoinHandle<Citymap> {
    thread::Builder::new()
        .name(format!("process_thread id: {i} assigned: {range:?}"))
        .spawn(move || {
            let mut file = File::open(path).unwrap();
            // The first range already starts on a line boundary.
            if range.start != 0 {
                range.start = next_line_start(&file, range.start).unwrap();
            }
            range.end = next_line_start(&file, range.end).unwrap();
            // Notify main about alignment
            range_feedback.send(range.clone()).unwrap();
            // Ensure we remain within bounds of the designated file range
            file.seek(SeekFrom::Start(range.start)).unwrap();
            let limited = BufReader::with_capacity(10_usize.pow(5), file);
            let mut buffered = limited.take(range.end - range.start);
            citymap_naive(&mut buffered)
        })
        .unwrap()
}
/// Reads `name;value` lines from `input` until it is exhausted and returns the
/// aggregated statistics.
///
/// Robust against the two artefacts a byte-range split can produce:
/// * empty lines (a range that starts on a `\n`) are skipped, and the line
///   buffer is cleared before EVERY read so the skipped newline can no longer
///   leak into the next station name,
/// * the final line of the input may lack a trailing `\n`; the newline is only
///   stripped when it is actually there, so the last digit of the value is
///   not cut off.
///
/// # Panics
/// Panics on read errors, on a non-empty line without `;`, on a station name
/// that is not valid UTF-8 (unless the `unsafe` feature is enabled) and on a
/// malformed value.
fn citymap_naive(input: &mut impl BufRead) -> Citymap {
    let mut map = Citymap::new();
    let mut buf = Vec::with_capacity(50);
    loop {
        buf.clear();
        let read = input.read_until(b'\n', &mut buf).unwrap();
        // Stream has finished
        if read == 0 {
            break;
        }
        let line = match buf.last() {
            Some(b'\n') => &buf[..buf.len() - 1],
            _ => &buf[..],
        };
        // Skip over empty lines that get created by the alignment process
        if line.is_empty() {
            continue;
        }
        let separator = line
            .iter()
            .position(|&byte| byte == b';')
            .expect("measurement line is missing the ';' separator");
        let city = &line[..separator];
        let val = &line[separator + 1..];
        #[cfg(not(feature = "unsafe"))]
        let entry = map.lookup(std::str::from_utf8(city).unwrap());
        #[cfg(feature = "unsafe")]
        // SAFETY: only sound if the input is known to be valid UTF-8; that is
        // the contract of opting into the `unsafe` feature.
        let entry = map.lookup(unsafe { std::str::from_utf8_unchecked(city) });
        entry.add_new(val);
    }
    map
}
/// Prints all stations sorted by name as `{name=min/mean/max, ...}` on one
/// line, every number with one decimal place.
///
/// Each entry, including the last one, is followed by `", "`.
fn print_results(map: Citymap) {
    let mut res = map.into_key_values();
    res.sort_unstable_by(|left, right| left.0.cmp(&right.0));
    let body: String = res
        .iter()
        .map(|(city, vals)| {
            let (min, mean, max) = (vals.min(), vals.mean(), vals.max());
            format!("{city}={min:.1}/{mean:.1}/{max:.1}, ")
        })
        .collect();
    println!("{{{body}}}");
}

View File

@ -0,0 +1,201 @@
use std::collections::HashMap;
use std::env;
use std::fmt;
use std::fs::File;
use std::io;
use std::io::prelude::*;
use std::thread::{self, Scope, ScopedJoinHandle};
use onebrc::mmap::Mmap;
use onebrc::mmap::MmapChunkIterator;
// Defined in challenge spec. Used as the initial capacity of the station maps.
const MAX_STATIONS: usize = 10000;
// Number of pieces the input is divided into; `process_mmap` spawns one thread
// per chunk the iterator yields.
const NUM_CONSUMERS: usize = 32;
/// Temperatures are stored as integers in tenths of a degree; dividing by this
/// converts them back to degrees.
const FIXED_POINT_DIVISOR: f64 = 10.0;

/// Efficiently handles station statistics. Avoids using floating-point
/// arithmetic to speed-up processing. The mean is only calculated on demand,
/// so we avoid calculating it as we read the file.
///
/// `temp_sum` is 64 bit wide: a few million readings of one station are
/// enough to push a sum of tenths past `i32::MAX`, which would wrap silently
/// in release builds.
struct StationData {
    min_temp: i32,
    max_temp: i32,
    count: i32,
    temp_sum: i64,
}

impl fmt::Display for StationData {
    /// Formats as `min/mean/max`, each with one decimal place.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "{:.1}/{:.1}/{:.1}",
            (self.min_temp as f64 / FIXED_POINT_DIVISOR),
            self.get_mean(),
            (self.max_temp as f64 / FIXED_POINT_DIVISOR)
        )
    }
}

impl StationData {
    /// Creates the statistics for a station from its first reading (in tenths).
    fn new(temp: i32) -> Self {
        Self {
            min_temp: temp,
            max_temp: temp,
            count: 1,
            temp_sum: i64::from(temp),
        }
    }

    /// Returns the `Display` representation as raw bytes.
    fn to_bytes(&self) -> Vec<u8> {
        self.to_string().into_bytes()
    }

    /// Mean temperature in degrees.
    fn get_mean(&self) -> f64 {
        (self.temp_sum as f64 / self.count as f64) / FIXED_POINT_DIVISOR
    }

    /// Records one more reading (in tenths).
    fn update_from(&mut self, temp: i32) {
        self.max_temp = self.max_temp.max(temp);
        self.min_temp = self.min_temp.min(temp);
        self.count += 1;
        self.temp_sum += i64::from(temp);
    }

    /// Folds the statistics of `src` into `self`. `src` is not modified.
    fn update_from_station(&mut self, src: &mut Self) {
        self.max_temp = self.max_temp.max(src.max_temp);
        self.min_temp = self.min_temp.min(src.min_temp);
        self.temp_sum += src.temp_sum;
        self.count += src.count;
    }

    /// Parses a reading such as `b"-12.3"` into tenths (`-123`). The decimal
    /// point is skipped rather than interpreted.
    ///
    /// # Panics
    /// Panics on any byte other than an ASCII digit, `.` or `-`.
    #[inline]
    fn parse_temp(bytes: &[u8]) -> i32 {
        let mut result: i32 = 0;
        let mut negative: bool = false;
        for &b in bytes {
            match b {
                b'0'..=b'9' => {
                    result = result * 10 + (b as i32 - b'0' as i32);
                }
                b'.' => {}
                b'-' => {
                    negative = true;
                }
                _ => panic!("wrong format for temperature"),
            }
        }
        if negative {
            -result
        } else {
            result
        }
    }

    /// Splits a `name;value` line (without the trailing newline) into the
    /// station name and the parsed reading.
    ///
    /// # Panics
    /// Panics if the line contains no `;` or the value is malformed.
    #[inline]
    fn parse_data(line: &[u8]) -> (&[u8], i32) {
        let semicolon_pos = line.iter().position(|&b| b == b';').unwrap();
        let name = &line[..semicolon_pos];
        let temp = &line[semicolon_pos + 1..];
        (name, Self::parse_temp(temp))
    }
}
/// Folds every station of `src` into `dest` and returns `dest`.
///
/// Stations present in both maps are combined; stations only present in `src`
/// are moved over unchanged.
fn merge_hashmaps<'a>(
    mut dest: HashMap<&'a [u8], StationData>,
    src: HashMap<&'a [u8], StationData>,
) -> HashMap<&'a [u8], StationData> {
    for (name, mut incoming) in src {
        match dest.get_mut(&name) {
            Some(existing) => existing.update_from_station(&mut incoming),
            None => {
                dest.insert(name, incoming);
            }
        }
    }
    dest
}
/// Aggregates every `name;value` line of one input chunk.
///
/// Lines are separated by `\n`. A final line without a trailing newline is
/// processed as well; an empty remainder after the last newline is ignored.
///
/// # Panics
/// Panics if a line (including an empty line in the middle of the chunk) has
/// no `;` or carries a malformed value.
fn process_chunk<'a>(current_chunk_slice: &'a [u8]) -> HashMap<&'a [u8], StationData> {
    let mut station_map: HashMap<&'a [u8], StationData> = HashMap::with_capacity(MAX_STATIONS);
    let mut lines = current_chunk_slice.split(|&b| b == b'\n').peekable();
    while let Some(line) = lines.next() {
        // `split` reports an empty trailing piece when the chunk ends with a
        // newline (or is empty); that piece is not a line.
        let is_last_piece = lines.peek().is_none();
        if is_last_piece && line.is_empty() {
            break;
        }
        let (name, temp) = StationData::parse_data(line);
        station_map
            .entry(name)
            .and_modify(|stats| stats.update_from(temp))
            .or_insert_with(|| StationData::new(temp));
    }
    station_map
}
/// Splits `mmap` into chunks, aggregates each chunk on its own scoped thread
/// and merges the partial results.
///
/// All threads are spawned before the first one is joined.
///
/// # Panics
/// Panics if any worker thread panicked.
fn process_mmap<'scope, 'env>(
    mmap: Mmap<'env>,
    s: &'scope Scope<'scope, 'env>,
) -> HashMap<&'env [u8], StationData> {
    let workers: Vec<ScopedJoinHandle<HashMap<&[u8], StationData>>> =
        MmapChunkIterator::new(mmap, NUM_CONSUMERS)
            .map(|chunk| s.spawn(move || process_chunk(chunk)))
            .collect();
    workers.into_iter().fold(
        HashMap::with_capacity(MAX_STATIONS),
        |merged, worker| merge_hashmaps(merged, worker.join().unwrap()),
    )
}
/// Writes `{name=min/mean/max, ...}` to stdout, stations ordered by the bytes
/// of their name. No trailing newline is written.
///
/// The whole output is assembled in memory and written with a single call.
fn write_output_to_stdout(station_map: HashMap<&[u8], StationData>) -> io::Result<()> {
    let mut entries: Vec<_> = station_map.iter().collect();
    entries.sort_by(|left, right| left.0.cmp(right.0));

    let mut out: Vec<u8> = Vec::new();
    out.push(b'{');
    let mut needs_separator = false;
    for (name, data) in entries {
        if needs_separator {
            out.extend_from_slice(b", ");
        }
        needs_separator = true;
        out.extend_from_slice(name);
        out.push(b'=');
        out.extend(data.to_bytes());
    }
    out.push(b'}');

    io::stdout().lock().write_all(&out)
}
/// Entry point: maps the measurements file into memory, aggregates it on
/// scoped worker threads and prints the result.
///
/// The file name is taken from the argument at index 2 and defaults to
/// `../../../measurements.txt`.
/// NOTE(review): index 2 skips the first user-supplied argument (index 1) —
/// confirm that this matches how the binary is invoked.
fn main() -> io::Result<()> {
    // won't accept non-utf-8 args
    let args: Vec<String> = env::args().collect();
    let file_name = args
        .get(2)
        .map(String::as_str)
        .unwrap_or("../../../measurements.txt");
    let mmap = Mmap::from_file(File::open(file_name)?);
    thread::scope(|s| {
        write_output_to_stdout(process_mmap(mmap, s)).unwrap();
    });
    Ok(())
}

View File

@ -1,5 +1,7 @@
#![feature(slice_as_chunks)]
pub mod mmap;
use std::fmt::Display;
#[derive(Copy, Clone)]

145
src/main/rust/src/mmap.rs Normal file
View File

@ -0,0 +1,145 @@
use std::ops::Deref;
use std::os::fd::AsRawFd;
use std::ptr::null_mut;
use std::{fs::File, os::raw::c_void};
use libc::{madvise, mmap, munmap, size_t, MADV_WILLNEED, MAP_FAILED, MAP_SHARED, PROT_READ};
/// Smart pointer type for a mmap. Handles munmap call.
///
/// NOTE(review): the lifetime `'a` is picked by the caller of `from_file`;
/// nothing in the type ties it to how long the mapping really stays mapped.
/// Slices handed out by `MmapChunkIterator` carry `'a` and can therefore
/// outlive this value.
pub struct Mmap<'a> {
// Bytes currently covered by this handle. `MmapChunkIterator::next` overwrites
// this field with the not-yet-consumed remainder, so it does not necessarily
// describe the whole mapping any more.
mmap_slice: &'a [u8],
}
/// To properly dispose of the mmap we have to manually call munmap.
/// So implementing drop for this smart-pointer type is necessary.
///
/// NOTE(review): `munmap` receives the CURRENT `mmap_slice`. Once the chunk
/// iterator has consumed part of it, pointer and length no longer match the
/// values returned by `mmap`. After a complete iteration the length is 0, and
/// the call presumably fails with `EINVAL` so the mapping stays alive for the
/// chunks still in use — verify against the munmap(2) manual before changing
/// this. The return value of `munmap` is ignored.
impl<'a> Drop for Mmap<'a> {
fn drop(&mut self) {
unsafe {
munmap(
self.mmap_slice.as_ptr() as *mut c_void,
self.mmap_slice.len(),
);
}
}
}
// anti-pattern for non-smart pointer types.
// ref: https://rust-unofficial.github.io/patterns/anti_patterns/deref.html
/// Gives read access to the bytes still covered by this handle.
impl<'a> Deref for Mmap<'a> {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.mmap_slice
}
}
impl<'a> Mmap<'a> {
    /// Wraps an already mapped region.
    fn new(data: &'a [u8]) -> Self {
        Self { mmap_slice: data }
    }

    /// Maps the whole of `f` read-only (`PROT_READ`, `MAP_SHARED`).
    ///
    /// An empty file yields an empty `Mmap` without calling `mmap`, because a
    /// zero-length mapping is rejected by the kernel. `f` is closed when this
    /// function returns; the mapping does not need the descriptor any more.
    ///
    /// # Panics
    /// Panics if the file metadata cannot be read or if `mmap` fails; the
    /// panic message then includes the OS error.
    pub fn from_file(f: File) -> Self {
        let size = f.metadata().unwrap().len() as size_t;
        if size == 0 {
            // NOTE(review): dropping this value calls `munmap` with length 0,
            // which is expected to fail harmlessly with `EINVAL`.
            return Self::new(&[]);
        }
        let prot = PROT_READ;
        let flags = MAP_SHARED;
        // SAFETY: a null hint, a non-zero length and a descriptor that is open
        // for the duration of the call are valid arguments; the result is
        // checked against `MAP_FAILED` before it is used.
        let m = unsafe { mmap(null_mut(), size, prot, flags, f.as_raw_fd(), 0) };
        if m == MAP_FAILED {
            panic!("mmap failed: {}", std::io::Error::last_os_error());
        }
        // We can advise the kernel on how we intend to use the mmap.
        // But this did not improve my read performance in a meaningful way
        // SAFETY: `m` is the start of a live mapping of exactly `size` bytes.
        // The advice is best effort, so the result is deliberately ignored.
        unsafe {
            madvise(m, size, MADV_WILLNEED);
        }
        // SAFETY: `m` points to `size` readable bytes. The caller-chosen
        // lifetime is NOT checked against the lifetime of the mapping.
        Self::new(unsafe { std::slice::from_raw_parts(m as *const u8, size) })
    }
}
/// Iterator that hands out consecutive pieces of a memory map, cut at line
/// boundaries.
pub struct MmapChunkIterator<'a> {
    data: Mmap<'a>,
    /// Size in bytes of the window in which each cut position is searched.
    chunk_size: usize,
}

impl<'a> MmapChunkIterator<'a> {
    /// Sets the window size to the data length divided by `consumers`.
    ///
    /// `consumers == 0` is treated like `1` instead of dividing by zero, and
    /// the window never drops below one byte, so inputs shorter than
    /// `consumers` bytes still make progress.
    fn with_consumers(mut self, consumers: usize) -> Self {
        self.chunk_size = (self.data.len() / consumers.max(1)).max(1);
        self
    }

    /// Creates an iterator that splits `data` into roughly `num_consumers`
    /// chunks.
    pub fn new(data: Mmap<'a>, num_consumers: usize) -> Self {
        Self {
            data,
            chunk_size: 1,
        }
        .with_consumers(num_consumers)
    }
}
/// Turns the map into a chunk iterator with a window of a single byte.
///
/// With `chunk_size: 1` the iterator inspects one byte per step, so this
/// yields much smaller pieces than `MmapChunkIterator::new`.
impl<'a> IntoIterator for Mmap<'a> {
type IntoIter = MmapChunkIterator<'a>;
type Item = &'a [u8];
fn into_iter(self) -> Self::IntoIter {
MmapChunkIterator {
data: self,
chunk_size: 1,
}
}
}
impl<'a> Iterator for MmapChunkIterator<'a> {
    type Item = &'a [u8];

    /// Returns the next chunk, or `None` once all data has been handed out.
    ///
    /// A chunk ends right after the last newline inside the next `chunk_size`
    /// bytes. If that window holds no newline at all (a line longer than the
    /// window), the chunk is extended forward to the end of that line instead
    /// of cutting the line in two. A window size of 0 is treated as 1, so the
    /// iterator always advances and terminates.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining: &'a [u8] = self.data.mmap_slice;
        if remaining.is_empty() {
            return None;
        }
        let window = self.chunk_size.clamp(1, remaining.len());
        let split_at = match remaining[..window].iter().rposition(|&x| x == b'\n') {
            // Cut after the last complete line inside the window.
            Some(last_newline) => last_newline + 1,
            // No line ends inside the window: take everything up to and
            // including the next newline, or all that is left.
            None => remaining[window..]
                .iter()
                .position(|&x| x == b'\n')
                .map_or(remaining.len(), |offset| window + offset + 1),
        };
        let (result, rest) = remaining.split_at(split_at);
        self.data.mmap_slice = rest;
        Some(result)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::Write;
    use std::path::Path;

    /// Creates (or truncates) `path` and fills it with `content`.
    fn create_test_file(path: &Path, content: &[u8]) {
        let mut file = File::create(path).unwrap();
        file.write_all(content).unwrap();
    }

    /// Deletes `path` if it exists.
    fn remove_test_file(path: &Path) {
        if path.exists() {
            fs::remove_file(path).unwrap();
        }
    }

    /// Mapping a file exposes exactly the bytes that were written to it.
    #[test]
    fn test_from_file() {
        // Use a per-process file in the temp directory so the crate directory
        // is not polluted and concurrent test runs cannot collide.
        let test_file_path = std::env::temp_dir()
            .join(format!("onebrc_mmap_test_{}.txt", std::process::id()));
        let test_content = b"Hello, mmap!";
        create_test_file(&test_file_path, test_content);
        let file = File::open(&test_file_path).unwrap();
        // Copy the bytes out and unmap before cleaning up, so the file is
        // removed even when the assertion below fails.
        let mapped = Mmap::from_file(file).to_vec();
        remove_test_file(&test_file_path);
        assert_eq!(mapped, test_content);
    }
}