Remove the need for a mutex by using channels

This commit is contained in:
Fabian Schmidt 2024-07-23 13:23:26 +02:00
parent 327fe8564e
commit 393f802741

View File

@ -1,10 +1,10 @@
use std::{ use std::{
fs::File, fs::File,
io::{BufRead, BufReader}, io::{BufRead, BufReader},
sync::{Mutex},
thread, thread,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::mpsc;
use std::time::Instant; use std::time::Instant;
use onebrc::format_nums; use onebrc::format_nums;
@ -12,22 +12,17 @@ const DEFAULT_HASHMAP_LENGTH: usize = 10000;
fn main() { fn main() {
print!("\x1b[2J"); print!("\x1b[2J");
let stations:Mutex<HashMap<String, onebrc::StationMeasurements>> =
Mutex::new(HashMap::with_capacity(DEFAULT_HASHMAP_LENGTH));
thread::scope(|s| { thread::scope(|s| {
// Doing this allows us to not "move" stations into the closure let mut stations: HashMap<String, onebrc::StationMeasurements> = HashMap::with_capacity(DEFAULT_HASHMAP_LENGTH);
// and remove the necessity for reference counting (Arc) let (tx, rx) = mpsc::channel();
// no performance improvement but less complex
let stations = &stations;
let now = Instant::now(); let now = Instant::now();
let cores: usize = thread::available_parallelism().unwrap().into(); let cores: usize = thread::available_parallelism().unwrap().into();
let chunk_length = 1_000_000_000 / cores; let chunk_length = 1_000_000_000 / cores;
for i in 0..cores { for i in 0..cores {
let file = File::open("../../../measurements.txt").expect("File measurements.txt not found"); let file = File::open("../../../measurements.txt").expect("File measurements.txt not found");
let reader = BufReader::new(file); let reader = BufReader::new(file);
let line_chunk = reader.lines().skip(chunk_length * i).take(chunk_length); let line_chunk = reader.lines().skip(chunk_length * i).take(chunk_length);
let tx = tx.clone();
s.spawn(move || { s.spawn(move || {
let mut t_stations: HashMap<String, onebrc::StationMeasurements> = let mut t_stations: HashMap<String, onebrc::StationMeasurements> =
HashMap::with_capacity(DEFAULT_HASHMAP_LENGTH); HashMap::with_capacity(DEFAULT_HASHMAP_LENGTH);
@ -61,18 +56,21 @@ fn main() {
} }
}); });
print!("\x1b[{print_line};60HTime reading lines in thread {i}={} ms", now_read_line.elapsed().as_millis()); print!("\x1b[{print_line};60HTime reading lines in thread {i}={} ms", now_read_line.elapsed().as_millis());
for (station, measurements) in t_stations.iter() { let _ = tx.send(t_stations);
let mut stations_guard = stations.lock().expect("Error while locking");
let joined_measurements_options = stations_guard.get_mut(station.as_str());
if let Some(joined_measurements) = joined_measurements_options {
joined_measurements.merge(measurements);
} else {
stations_guard.insert(station.to_owned(), *measurements);
}
}
}); });
} }
let mut stations: Vec<String> = stations.lock().unwrap().iter().map(|(station, measurements)| { drop(tx);
while let Ok(t_stations) = rx.recv() {
for (station, measurements) in t_stations.iter() {
let joined_measurements_options = stations.get_mut(station.as_str());
if let Some(joined_measurements) = joined_measurements_options {
joined_measurements.merge(measurements);
} else {
stations.insert(station.to_owned(), *measurements);
}
}
}
let mut stations: Vec<String> = stations.iter().map(|(station, measurements)| {
let measurements = measurements.to_string(); let measurements = measurements.to_string();
format!("{station}={measurements}") format!("{station}={measurements}")
}).collect(); }).collect();