diff --git a/src/main/rust/Cargo.lock b/src/main/rust/Cargo.lock
index 7fc8b04..e1545cb 100644
--- a/src/main/rust/Cargo.lock
+++ b/src/main/rust/Cargo.lock
@@ -548,6 +548,7 @@ version = "0.1.0"
 dependencies = [
  "bstr",
  "fast-float",
+ "libc",
  "memchr",
  "memmap",
  "polars",
diff --git a/src/main/rust/Cargo.toml b/src/main/rust/Cargo.toml
index 69a6e2b..e643055 100644
--- a/src/main/rust/Cargo.toml
+++ b/src/main/rust/Cargo.toml
@@ -13,6 +13,7 @@ memmap = "0.7.0"
 polars = { version = "0.36.2", features = ["csv", "lazy", "nightly", "streaming"]}
 rayon = "1.10.0"
 rustc-hash = "2.0.0"
+libc = "0.2.155"
 
 [features]
 json = []
diff --git a/src/main/rust/src/bin/FlareFlo.rs b/src/main/rust/src/bin/FlareFlo.rs
new file mode 100644
index 0000000..fd95f9c
--- /dev/null
+++ b/src/main/rust/src/bin/FlareFlo.rs
@@ -0,0 +1,286 @@
+#![feature(hash_raw_entry)]
+
+use std::collections::HashMap;
+use std::env::args;
+use std::fs::File;
+use std::io::{BufRead, BufReader, Read, Seek, SeekFrom};
+use std::ops::{Neg, Range};
+use std::os::unix::fs::FileExt;
+use std::sync::mpsc::{channel, Sender};
+use std::thread;
+use std::thread::{available_parallelism, JoinHandle};
+use std::time::Instant;
+
+#[derive(Copy, Clone, Debug)]
+struct City {
+    min: i64,
+    max: i64,
+    sum: i64,
+    occurrences: u32,
+}
+
+impl City {
+    pub fn add_new(&mut self, input: &[u8]) {
+        let mut val = 0;
+        let mut is_neg = false;
+        for &char in input {
+            match char {
+                b'0'..=b'9' => {
+                    val *= 10;
+                    let digit = char - b'0';
+                    val += digit as i64;
+                }
+                b'-' => {
+                    is_neg = true;
+                }
+                b'.' => {}
+                _ => {
+                    panic!("encountered {} in value", char::from(char))
+                }
+            }
+        }
+        if is_neg {
+            val = val.neg();
+        }
+        self.add_new_value(val);
+    }
+
+    pub fn add_new_value(&mut self, new: i64) {
+        self.min = self.min.min(new);
+        self.max = self.max.max(new);
+        self.sum += new;
+        self.occurrences += 1;
+    }
+    pub fn min(&self) -> f64 {
+        self.min as f64 / 10.0
+    }
+    pub fn mean(&self) -> f64 {
+        self.sum as f64 / self.occurrences as f64 / 10.0
+    }
+    pub fn max(&self) -> f64 {
+        self.max as f64 / 10.0
+    }
+
+    pub fn add_result(&mut self, other: Self) {
+        self.min = self.min.min(other.min);
+        self.max = self.max.max(other.max);
+        self.sum += other.sum;
+        self.occurrences += other.occurrences;
+    }
+}
+
+impl Default for City {
+    fn default() -> Self {
+        Self {
+            min: i64::MAX,
+            max: i64::MIN,
+            sum: 0,
+            occurrences: 0,
+        }
+    }
+}
+
+#[derive(Default, Clone, Debug)]
+struct Citymap {
+    // Length then values
+    pub map: HashMap<u32, (String, City)>,
+}
+
+fn hashstr(s: &str) -> u32 {
+    let b = s.as_bytes();
+    u32::from_le_bytes([s.len() as u8, b[0], b[1], b[2]])
+}
+
+impl Citymap {
+    pub fn lookup(&mut self, lookup: &str) -> &mut City {
+        let hash = hashstr(lookup);
+        let builder = self.map.raw_entry_mut();
+        &mut builder.from_key(&hash).or_insert(hash, (lookup.to_owned(), Default::default())).1.1
+    }
+    pub fn new() -> Self {
+        Self {
+            map: Default::default(),
+        }
+    }
+    pub fn into_key_values(self) -> Vec<(String, City)> {
+        self.map.into_iter().map(|(_, s)| s).collect()
+    }
+    pub fn merge_with(&mut self, rhs: Self) {
+        for (k, v) in rhs.map.into_iter() {
+            self.map
+                .entry(k)
+                .and_modify(|lhs| {
+                    lhs.1.add_result(v.1);
+                })
+                .or_insert(v);
+        }
+    }
+}
+
+fn main() {
+    let mut args = args();
+
+    let start = Instant::now();
+    let input = "../../../measurements.txt";
+
+    let results = if args.find(|e| e == "st").is_some() {
+        citymap_single_thread(input)
+    } else {
+        citymap_multi_threaded(input)
+    };
+
+    print_results(results);
+
+    println!("{:?}", start.elapsed());
+}
+
+fn citymap_single_thread(path: &str) -> Citymap {
+    let f = File::open(path).unwrap();
+
+    let mut buf = BufReader::with_capacity(10_usize.pow(8), f);
+    citymap_naive(&mut buf)
+}
+
+fn citymap_multi_threaded(path: &str) -> Citymap {
+    let cpus = available_parallelism().unwrap().get();
+    let size = File::open(path).unwrap().metadata().unwrap().len();
+    let per_thread = size / cpus as u64;
+
+    let mut index = 0;
+    let mut threads = vec![];
+    let (sender, receiver) = channel();
+    for i in 0..cpus {
+        let range = index..({
+            index += per_thread;
+            index.min(size)
+        });
+        threads.push(citymap_thread(path.to_owned(), range, i, sender.clone()));
+    }
+    let mut ranges = (0..cpus)
+        .into_iter()
+        .map(|_| receiver.recv().unwrap())
+        .collect::<Vec<_>>();
+    ranges.sort_unstable_by_key(|e| e.start);
+    assert!(
+        ranges.windows(2).all(|e| {
+            let first = &e[0];
+            let second = &e[1];
+            first.end == second.start
+        }),
+        "Ranges overlap or have gaps: {ranges:?}"
+    );
+    let results = threads
+        .into_iter()
+        .map(|e| e.join().unwrap())
+        //.map(|e|dbg!(e))
+        .reduce(|mut left, right| {
+            left.merge_with(right);
+            left
+        })
+        .unwrap();
+    results
+}
+
+fn citymap_thread(
+    path: String,
+    mut range: Range<u64>,
+    i: usize,
+    range_feedback: Sender<Range<u64>>,
+) -> JoinHandle<Citymap> {
+    thread::Builder::new()
+        .name(format!("process_thread id: {i} assigned: {range:?}"))
+        .spawn(move || {
+            let mut file = File::open(path).unwrap();
+            //println!("Before: {range:?}");
+
+            // Perform alignment of buffer/range at the start
+            {
+                // Skip head alignment for start of file
+                if range.start != 0 {
+                    let mut head = vec![0; 50];
+                    let len = file.read_at(&mut head, range.start).unwrap();
+                    head.truncate(len);
+
+                    for (i, &pos) in head.iter().enumerate() {
+                        if pos == '\n' as u8 {
+                            range.start += i as u64;
+                            break;
+                        }
+                    }
+                }
+
+                // tail alignment
+                {
+                    let mut head = vec![0; 50];
+                    let len = file.read_at(&mut head, range.end).unwrap();
+                    head.truncate(len);
+
+                    for (i, &pos) in head.iter().enumerate() {
+                        if pos == '\n' as u8 {
+                            range.end += i as u64;
+                            break;
+                        }
+                    }
+                }
+            }
+
+            // Notify main about alignment
+            range_feedback.send(range.clone()).unwrap();
+            // Ensure we remain within bounds of the designated file range
+            file.seek(SeekFrom::Start(range.start)).unwrap();
+
+            let limited = BufReader::with_capacity(10_usize.pow(5), file);
+            let mut buffered = limited.take(range.end - range.start);
+            citymap_naive(&mut buffered)
+        })
+        .unwrap()
+}
+
+fn citymap_naive(input: &mut impl BufRead) -> Citymap {
+    let mut map = Citymap::new();
+    let mut buf = Vec::with_capacity(50);
+    loop {
+        let read = input.read_until(b'\n', &mut buf).unwrap();
+        // Stream has finished
+        if read == 0 {
+            break;
+        }
+
+        // Skip over just newline strings that get created by the alignment process
+        if buf == &[b'\n'] {
+            continue;
+        }
+
+        let mut city = None;
+        let mut val = None;
+        for (i, &char) in buf.iter().enumerate() {
+            if char == b';' {
+                city = Some(&buf[0..i]);
+                val = Some(&buf[(i + 1)..(buf.len() - 1)]);
+                break;
+            }
+        }
+        #[cfg(not(feature = "unsafe"))]
+        let entry = map.lookup(std::str::from_utf8(city.unwrap()).unwrap());
+
+        #[cfg(feature = "unsafe")]
+        let entry = map.lookup(unsafe { std::str::from_utf8_unchecked(city.unwrap()) });
+
+        entry.add_new(val.unwrap());
+        buf.clear();
+    }
+    map
+}
+
+fn print_results(map: Citymap) {
+    let mut res = map.into_key_values();
+    res.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
+    print!("{{");
+    for (city, vals) in res {
+        let min = vals.min();
+        let mean = vals.mean();
+        let max = vals.max();
+        print!("{city}={min:.1}/{mean:.1}/{max:.1}, ")
+    }
+    println!("}}");
+}
diff --git a/src/main/rust/src/bin/phcs.rs b/src/main/rust/src/bin/phcs.rs
new file mode 100644
index 0000000..5f54945
--- /dev/null
+++ b/src/main/rust/src/bin/phcs.rs
@@ -0,0 +1,201 @@
+use std::collections::HashMap;
+use std::env;
+use std::fmt;
+use std::fs::File;
+use std::io;
+use std::io::prelude::*;
+use std::thread::{self, Scope, ScopedJoinHandle};
+
+use onebrc::mmap::Mmap;
+use onebrc::mmap::MmapChunkIterator;
+
+// Defined in challenge spec
+const MAX_STATIONS: usize = 10000;
+const NUM_CONSUMERS: usize = 32;
+const FIXED_POINT_DIVISOR: f64 = 10.0;
+
+struct StationData {
+    min_temp: i32,
+    max_temp: i32,
+    count: i32,
+    temp_sum: i32,
+}
+
+impl fmt::Display for StationData {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{:.1}/{:.1}/{:.1}",
+            (self.min_temp as f64 / FIXED_POINT_DIVISOR),
+            self.get_mean(),
+            (self.max_temp as f64 / FIXED_POINT_DIVISOR)
+        )
+    }
+}
+
+/// Efficiently handles station statistics. Avoids using floating-point arithmetic to speed-up processing.
+/// The mean is only calculated on demand, so we avoid calculating it as we read the file
+impl StationData {
+    fn new(temp: i32) -> Self {
+        Self {
+            min_temp: temp,
+            max_temp: temp,
+            count: 1,
+            temp_sum: temp,
+        }
+    }
+
+    fn to_bytes(&self) -> Vec<u8> {
+        format!(
+            "{:.1}/{:.1}/{:.1}",
+            (self.min_temp as f64 / FIXED_POINT_DIVISOR),
+            self.get_mean(),
+            (self.max_temp as f64 / FIXED_POINT_DIVISOR)
+        )
+        .into_bytes()
+    }
+
+    fn get_mean(&self) -> f64 {
+        (self.temp_sum as f64 / self.count as f64) / FIXED_POINT_DIVISOR
+    }
+
+    fn update_from(&mut self, temp: i32) {
+        self.max_temp = self.max_temp.max(temp);
+        self.min_temp = self.min_temp.min(temp);
+        self.count += 1;
+        self.temp_sum += temp;
+    }
+    fn update_from_station(&mut self, src: &mut Self) {
+        self.max_temp = self.max_temp.max(src.max_temp);
+        self.min_temp = self.min_temp.min(src.min_temp);
+        self.temp_sum += src.temp_sum;
+        self.count += src.count;
+    }
+
+    #[inline]
+    fn parse_temp(bytes: &[u8]) -> i32 {
+        let mut result: i32 = 0;
+        let mut negative: bool = false;
+        for &b in bytes {
+            match b {
+                b'0'..=b'9' => {
+                    result = result * 10 + (b as i32 - b'0' as i32);
+                }
+                b'.' => {}
+                b'-' => {
+                    negative = true;
+                }
+                _ => panic!("wrong format for temperature"),
+            }
+        }
+        if negative {
+            return -result;
+        }
+        result
+    }
+
+    #[inline]
+    fn parse_data(line: &[u8]) -> (&[u8], i32) {
+        let semicolon_pos = line.iter().position(|&b| b == b';').unwrap();
+        let name = &line[..semicolon_pos];
+        let temp = &line[semicolon_pos + 1..];
+        (name, Self::parse_temp(temp))
+    }
+}
+
+fn merge_hashmaps<'a>(
+    mut dest: HashMap<&'a [u8], StationData>,
+    src: HashMap<&'a [u8], StationData>,
+) -> HashMap<&'a [u8], StationData> {
+    for (k, mut v) in src {
+        dest.entry(k)
+            .and_modify(|e| e.update_from_station(&mut v))
+            .or_insert(v);
+    }
+    dest
+}
+
+/// Parses a chunk of the input as StationData values.
+fn process_chunk<'a>(current_chunk_slice: &'a [u8]) -> HashMap<&'a [u8], StationData> {
+    let mut station_map: HashMap<&[u8], StationData> = HashMap::with_capacity(MAX_STATIONS);
+    let mut start = 0;
+    while let Some(end) = current_chunk_slice[start..].iter().position(|&b| b == b'\n') {
+        let line = &current_chunk_slice[start..start + end];
+        let (name, temp) = StationData::parse_data(line);
+        station_map
+            .entry(name)
+            .and_modify(|e| e.update_from(temp))
+            .or_insert(StationData::new(temp));
+        start += end + 1; // move to the start of the next line
+    }
+    // If we don't find a \n, process the remaining data
+    if start < current_chunk_slice.len() {
+        let line = &current_chunk_slice[start..];
+        let (name, temp) = StationData::parse_data(line);
+        station_map
+            .entry(name)
+            .and_modify(|e| e.update_from(temp))
+            .or_insert(StationData::new(temp));
+    }
+    station_map
+}
+
+fn process_mmap<'scope, 'env>(
+    mmap: Mmap<'env>,
+    s: &'scope Scope<'scope, 'env>,
+) -> HashMap<&'env [u8], StationData> {
+    let mut handlers: Vec<ScopedJoinHandle<HashMap<&'env [u8], StationData>>> = Vec::new();
+
+    for chunk in MmapChunkIterator::new(mmap, NUM_CONSUMERS) {
+        let h = s.spawn(move || process_chunk(chunk));
+        handlers.push(h);
+    }
+
+    let mut station_map: HashMap<&[u8], StationData> = HashMap::with_capacity(MAX_STATIONS);
+    for h in handlers {
+        let inner_station = h.join().unwrap();
+        station_map = merge_hashmaps(station_map, inner_station);
+    }
+    station_map
+}
+
+fn write_output_to_stdout(station_map: HashMap<&[u8], StationData>) -> io::Result<()> {
+    let mut stdout = io::stdout().lock();
+    let mut buffer = Vec::new();
+
+    buffer.extend_from_slice(b"{");
+
+    let mut sorted_key_value_vec: Vec<_> = station_map.iter().collect();
+    sorted_key_value_vec.sort_by_key(|e| e.0);
+
+    for (i, (name, data)) in sorted_key_value_vec.iter().enumerate() {
+        if i > 0 {
+            buffer.extend_from_slice(b", ");
+        }
+        buffer.extend_from_slice(name);
+        buffer.extend_from_slice(b"=");
+        buffer.extend(data.to_bytes());
+    }
+
+    buffer.extend_from_slice(b"}");
+
+    stdout.write_all(&buffer)
+}
+
+fn main() -> io::Result<()> {
+    // won't accept non-utf-8 args
+    let args: Vec<String> = env::args().collect();
+    let file_name = match args.get(2).clone() {
+        Some(fname) => fname,
+        None => "../../../measurements.txt",
+    };
+    let f = File::open(file_name)?;
+    let mmap = Mmap::from_file(f);
+
+    thread::scope(|s| {
+        let station_map = process_mmap(mmap, s);
+        write_output_to_stdout(station_map).unwrap();
+    });
+
+    Ok(())
+}
diff --git a/src/main/rust/src/lib.rs b/src/main/rust/src/lib.rs
index 588c78d..b65777f 100644
--- a/src/main/rust/src/lib.rs
+++ b/src/main/rust/src/lib.rs
@@ -1,5 +1,7 @@
 #![feature(slice_as_chunks)]
 
+pub mod mmap;
+
 use std::fmt::Display;
 
 #[derive(Copy, Clone)]
diff --git a/src/main/rust/src/mmap.rs b/src/main/rust/src/mmap.rs
new file mode 100644
index 0000000..b10688d
--- /dev/null
+++ b/src/main/rust/src/mmap.rs
@@ -0,0 +1,145 @@
+use std::ops::Deref;
+use std::os::fd::AsRawFd;
+use std::ptr::null_mut;
+use std::{fs::File, os::raw::c_void};
+
+use libc::{madvise, mmap, munmap, size_t, MADV_WILLNEED, MAP_FAILED, MAP_SHARED, PROT_READ};
+
+/// Smart pointer type for a mmap. Handles munmap call.
+pub struct Mmap<'a> {
+    mmap_slice: &'a [u8],
+}
+
+/// To properly dispose of the mmap we have to manually call munmap.
+/// So implementing drop for this smart-pointer type is necessary.
+impl<'a> Drop for Mmap<'a> {
+    fn drop(&mut self) {
+        unsafe {
+            munmap(
+                self.mmap_slice.as_ptr() as *mut c_void,
+                self.mmap_slice.len(),
+            );
+        }
+    }
+}
+
+// anti-pattern for non-smart pointer types.
+// ref: https://rust-unofficial.github.io/patterns/anti_patterns/deref.html
+impl<'a> Deref for Mmap<'a> {
+    type Target = [u8];
+
+    fn deref(&self) -> &Self::Target {
+        self.mmap_slice
+    }
+}
+
+impl<'a> Mmap<'a> {
+    fn new(data: &'a [u8]) -> Self {
+        Self { mmap_slice: data }
+    }
+
+    pub fn from_file(f: File) -> Self {
+        let size = f.metadata().unwrap().len() as size_t;
+        let prot = PROT_READ;
+        let flags = MAP_SHARED;
+        unsafe {
+            let m = mmap(null_mut(), size, prot, flags, f.as_raw_fd(), 0);
+            if m == MAP_FAILED {
+                panic!("mmap failed");
+            }
+            // We can advise the kernel on how we intend to use the mmap.
+            // But this did not improve my read performance in a meaningful way
+            madvise(m, size, MADV_WILLNEED);
+            return Self::new(std::slice::from_raw_parts(m as *const u8, size));
+        }
+    }
+}
+
+pub struct MmapChunkIterator<'a> {
+    data: Mmap<'a>,
+    chunk_size: usize,
+}
+
+impl<'a> MmapChunkIterator<'a> {
+    fn with_consumers(mut self, consumers: usize) -> Self {
+        self.chunk_size = self.data.len() / consumers;
+        self
+    }
+
+    pub fn new(data: Mmap<'a>, num_consumers: usize) -> Self {
+        Self {
+            data,
+            chunk_size: 1,
+        }
+        .with_consumers(num_consumers)
+    }
+}
+
+impl<'a> IntoIterator for Mmap<'a> {
+    type IntoIter = MmapChunkIterator<'a>;
+    type Item = &'a [u8];
+
+    fn into_iter(self) -> Self::IntoIter {
+        MmapChunkIterator {
+            data: self,
+            chunk_size: 1,
+        }
+    }
+}
+
+impl<'a> Iterator for MmapChunkIterator<'a> {
+    type Item = &'a [u8];
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.data.is_empty() {
+            return None;
+        }
+
+        let chunk_end = self.chunk_size.min(self.data.len());
+        let chunk = &self.data[..chunk_end];
+
+        // Find the last newline in the chunk
+        let split_at = chunk
+            .iter()
+            .rposition(|&x| x == b'\n')
+            .map(|i| i + 1)
+            .unwrap_or(chunk_end);
+
+        let (result, rest) = self.data.mmap_slice.split_at(split_at);
+        self.data.mmap_slice = rest;
+
+        Some(result)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::fs;
+    use std::io::Write;
+    use std::path::Path;
+
+    fn create_test_file(path: &Path, content: &[u8]) {
+        let mut file = File::create(path).unwrap();
+        file.write_all(content).unwrap();
+    }
+
+    fn remove_test_file(path: &Path) {
+        if path.exists() {
+            fs::remove_file(path).unwrap();
+        }
+    }
+
+    #[test]
+    fn test_from_file() {
+        let test_file_path = Path::new("test_file.txt");
+        let test_content = b"Hello, mmap!";
+        create_test_file(test_file_path, test_content);
+        let file = File::open(test_file_path).unwrap();
+
+        let mmap = Mmap::from_file(file);
+
+        assert_eq!(&*mmap, test_content);
+        remove_test_file(test_file_path);
+    }
+}