Skip to content

Commit

Permalink
Stabilize sample command (#641)
Browse files Browse the repository at this point in the history
  • Loading branch information
nwagner84 authored Jul 4, 2023
1 parent 1f0e81d commit ccfb6d6
Show file tree
Hide file tree
Showing 33 changed files with 111 additions and 505 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

* #637 Stabilize `print` command
* #641 Stabilize `sample` command

### Removed

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ $ cargo install --git https://github.com/deutsche-nationalbibliothek/pica-rs --t
| [invalid](#invalid) | stable | write input lines, which can't be decoded as normalized PICA+ |
| [partition](#partition) | stable | partition a list of records based on subfield values |
| [print](#print) | stable | print records in human readable format |
| [sample](#sample) | beta | selects a random permutation of records |
| [sample](#sample) | stable | selects a random permutation of records |
| [select](#select) | beta | select subfield values from records |
| [slice](#slice) | stable | return records within a range (half-open interval) |
| [split](#split) | stable | split a list of records into chunks |
Expand Down
4 changes: 4 additions & 0 deletions pica-record/src/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,10 @@ impl<'a> ByteRecord<'a> {
let result = hasher.finalize();
result.to_vec()
}

/// Consumes this `ByteRecord` and returns the `RecordRef` held in its
/// `record` field.
pub fn into_inner(self) -> RecordRef<'a> {
self.record
}
}

impl<'a> Deref for ByteRecord<'a> {
Expand Down
79 changes: 45 additions & 34 deletions src/bin/pica/commands/sample.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::ffi::OsString;
use std::io::{self, Read};

use clap::{value_parser, Parser};
use pica::{
ByteRecord, PicaWriter, Reader, ReaderBuilder, WriterBuilder,
};
use pica_record::io::{ReaderBuilder, RecordsIterator, WriterBuilder};
use pica_record::ByteRecord;
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use serde::{Deserialize, Serialize};
Expand All @@ -20,9 +18,11 @@ pub(crate) struct SampleConfig {
pub(crate) gzip: Option<bool>,
}

/// Selects a random permutation of records of the given sample size
/// using reservoir sampling.
#[derive(Parser, Debug)]
pub(crate) struct Sample {
/// Skip invalid records that can't be decoded
/// Skip invalid records that can't be decoded as normalized PICA+
#[arg(short, long)]
skip_invalid: bool,

Expand All @@ -34,14 +34,18 @@ pub(crate) struct Sample {
#[arg(short, long, value_name = "filename")]
output: Option<OsString>,

/// RNG seed
/// Initialize the RNG with a seed value to get deterministic
/// random records.
#[arg(long, value_name = "number")]
seed: Option<u64>,

/// Number of random records
#[arg(value_parser = value_parser!(u16).range(1..))]
sample_size: u16,

/// Read one or more files in normalized PICA+ format.
/// Read one or more files in normalized PICA+ format. If no
/// filenames were given or a filename is "-", data is read from
/// standard input (stdin)
#[arg(default_value = "-", hide_default_value = true)]
filenames: Vec<OsString>,
}
Expand All @@ -55,49 +59,56 @@ impl Sample {
config.global
);

let mut writer: Box<dyn PicaWriter> = WriterBuilder::new()
let mut writer = WriterBuilder::new()
.gzip(gzip_compression)
.from_path_or_stdout(self.output)?;

let sample_size = self.sample_size as usize;
let mut reservoir: Vec<ByteRecord> =
Vec::with_capacity(sample_size);

let mut rng: StdRng = match self.seed {
None => StdRng::from_rng(thread_rng()).unwrap(),
Some(seed) => StdRng::seed_from_u64(seed),
};

let sample_size = self.sample_size as usize;
let mut reservoir: Vec<Vec<u8>> =
Vec::with_capacity(sample_size);

let mut i = 0;

for filename in self.filenames {
let builder =
ReaderBuilder::new().skip_invalid(skip_invalid);
let mut reader: Reader<Box<dyn Read>> = match filename
.to_str()
{
Some("-") => builder.from_reader(Box::new(io::stdin())),
_ => builder.from_path(filename)?,
};

for result in reader.byte_records() {
let record = result?;

if i < sample_size {
reservoir.push(record);
} else {
let j = rng.gen_range(0..i);
if j < sample_size {
reservoir[j] = record;
let mut reader =
ReaderBuilder::new().from_path(filename)?;

while let Some(result) = reader.next() {
match result {
Err(e) => {
if e.is_invalid_record() && skip_invalid {
continue;
} else {
return Err(e.into());
}
}
Ok(record) => {
let mut data = Vec::<u8>::new();
record.write_to(&mut data)?;

if i < sample_size {
reservoir.push(data);
} else {
let j = rng.gen_range(0..i);
if j < sample_size {
reservoir[j] = data;
}
}

i += 1;
}
}

i += 1;
}
}

for record in &reservoir {
writer.write_byte_record(record)?;
for data in &reservoir {
let record = ByteRecord::from_bytes(data).unwrap();
writer.write_byte_record(&record)?;
}

writer.finish()?;
Expand Down
2 changes: 0 additions & 2 deletions src/bin/pica/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ enum Commands {
Invalid(Invalid),
Partition(Partition),
Print(Print),

/// Selects a random permutation of records
Sample(Sample),

/// Select subfield values from records
Expand Down
43 changes: 0 additions & 43 deletions tests/common/mod.rs

This file was deleted.

4 changes: 0 additions & 4 deletions tests/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
pub use self::common::TestResult;

mod common;
mod path;
mod pica;
mod snapshot;
26 changes: 0 additions & 26 deletions tests/pica/mod.rs

This file was deleted.

Loading

0 comments on commit ccfb6d6

Please sign in to comment.