Skip to content

Commit

Permalink
updating aligned to include case and adding testing
Browse files Browse the repository at this point in the history
  • Loading branch information
pique0822 committed Oct 9, 2023
1 parent 966a4a0 commit 85ae0dd
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 10 deletions.
24 changes: 16 additions & 8 deletions src/fstalign.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ wer_alignment Fstalign(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine

vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> hyp_ctm_rows = {},
vector<RawNlpRecord> hyp_nlp_rows = {},
vector<string> one_best_tokens = {}) {
vector<string> one_best_tokens = {},
bool use_case = false) {
auto logger = logger::GetOrCreateLogger("fstalign");

// Go through top alignment and create stitches
Expand Down Expand Up @@ -287,7 +288,11 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h

part.hyp_orig = ctmPart.word;
// sanity check
std::string ctmCopy = UnicodeLowercase(ctmPart.word);
std::string ctmCopy = std::string(ctmPart.word);
if (!use_case) {
ctmCopy = UnicodeLowercase(ctmPart.word);
}

if (hyp_tk != ctmCopy) {
logger->warn(
"hum, looks like the ctm and the alignment got out of sync? [{}] vs "
Expand Down Expand Up @@ -326,7 +331,10 @@ vector<Stitching> make_stitches(wer_alignment &alignment, vector<RawCtmRecord> h
part.hyp_orig = token;

// sanity check
std::string token_copy = UnicodeLowercase(token);
std::string token_copy = std::string(token);
if (!use_case) {
token_copy = UnicodeLowercase(token);
}
if (hyp_tk != token_copy) {
logger->warn(
"hum, looks like the text and the alignment got out of sync? [{}] vs "
Expand Down Expand Up @@ -633,7 +641,7 @@ void write_stitches_to_nlp(vector<Stitching>& stitches, ofstream &output_nlp_fil
}

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp) {
AlignerOptions alignerOptions, bool add_inserts_nlp, bool use_case) {
// int speaker_switch_context_size, int numBests, int pr_threshold, string symbols_filename,
// string composition_approach, bool record_case_stats) {
auto logger = logger::GetOrCreateLogger("fstalign");
Expand All @@ -648,19 +656,19 @@ void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine
NlpFstLoader *nlp_hyp_loader = dynamic_cast<NlpFstLoader *>(&hypLoader);
OneBestFstLoader *best_loader = dynamic_cast<OneBestFstLoader *>(&hypLoader);
if (ctm_hyp_loader) {
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {});
stitches = make_stitches(topAlignment, ctm_hyp_loader->mCtmRows, {}, {}, use_case);
} else if (nlp_hyp_loader) {
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows);
stitches = make_stitches(topAlignment, {}, nlp_hyp_loader->mNlpRows, {}, use_case);
} else if (best_loader) {
vector<string> tokens;
tokens.reserve(best_loader->TokensSize());
for (int i = 0; i < best_loader->TokensSize(); i++) {
string token = best_loader->getToken(i);
tokens.push_back(token);
}
stitches = make_stitches(topAlignment, {}, {}, tokens);
stitches = make_stitches(topAlignment, {}, {}, tokens, use_case);
} else {
stitches = make_stitches(topAlignment);
stitches = make_stitches(topAlignment, {}, {}, {}, use_case);
}

NlpFstLoader *nlp_ref_loader = dynamic_cast<NlpFstLoader *>(&refLoader);
Expand Down
2 changes: 1 addition & 1 deletion src/fstalign.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct AlignerOptions {
// int numBests, string symbols_filename, string composition_approach);

void HandleWer(FstLoader& refLoader, FstLoader& hypLoader, SynonymEngine &engine, const string& output_sbs, const string& output_nlp,
AlignerOptions alignerOptions, bool add_inserts_nlp = false);
AlignerOptions alignerOptions, bool add_inserts_nlp = false, bool use_case = false);
void HandleAlign(NlpFstLoader &refLoader, CtmFstLoader &hypLoader, SynonymEngine &engine, ofstream &output_nlp_file,
AlignerOptions alignerOptions);

Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ int main(int argc, char **argv) {
}

if (command == "wer") {
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp);
HandleWer(*ref, *hyp, engine, output_sbs, output_nlp, alignerOptions, add_inserts_nlp, use_case);
} else if (command == "align") {
if (output_nlp.empty()) {
console->error("the output nlp file must be specified");
Expand Down
33 changes: 33 additions & 0 deletions test/data/short.aligned.case.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment|confidence
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
yeah|1|||,||LC|[]|[]|||del|
right|1|0.0000|0.0000|.||LC|[]|[]||||
Yeah|1|||,||UC|[]|[]|||del|
all|1|||||LC|[]|[]|||del|
right|1|0.0000|0.0000|,||LC|[]|[]|||sub(I'll)|
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)|
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|.||LC|[]|[]||||
Are|3|0.0000|0.0000|||UC|[]|[]||||
there|3|0.0000|0.0000|||LC|[]|[]||||
any|3|0.0000|0.0000|||LC|[]|[]||||
visuals|3|0.0000|0.0000|||LC|[]|[]||||
that|3|0.0000|0.0000|||LC|[]|[]||||
come|3|0.0000|0.0000|||LC|[]|[]||||
to|3|0.0000|0.0000|||LC|[]|[]||||
mind|3|0.0000|0.0000|||LC|[]|[]||||
or|3|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
sure|1|0.0000|0.0000|.||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
about|1|0.0000|0.0000|||LC|[]|[]||||
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|:||LC|[]|[]||||
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(Foobar)|
a|1|0.0000|0.0000|||LC|[]|[]||||
43 changes: 43 additions & 0 deletions test/data/short.aligned.punc_case.nlp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
token|speaker|ts|endTs|punctuation|prepunctuation|case|tags|wer_tags|oldTs|oldEndTs|ali_comment|confidence
<crosstalk>|2|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
yeah|1|||,||LC|[]|[]|||del|
,|1|||||LC|[]|[]|||del|
right|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|||||LC|[]|[]|||del|
Yeah|1|||,||UC|[]|[]|||del|
,|1|||||UC|[]|[]|||del|
all|1|||||LC|[]|[]|||del|
right|1|||,||LC|[]|[]|||del|
,|1|0.0000|0.0000|||LC|[]|[]|||sub(I'll)|
probably|1|0.0000|0.0000|||LC|[]|[]|||sub(do)|
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|0.0000|0.0000|||LC|[]|[]|||sub(?)|
Are|3|0.0000|0.0000|||UC|[]|[]||||
there|3|0.0000|0.0000|||LC|[]|[]||||
any|3|0.0000|0.0000|||LC|[]|[]||||
visuals|3|0.0000|0.0000|||LC|[]|[]||||
that|3|0.0000|0.0000|||LC|[]|[]||||
come|3|0.0000|0.0000|||LC|[]|[]||||
to|3|0.0000|0.0000|||LC|[]|[]||||
mind|3|0.0000|0.0000|||LC|[]|[]||||
or|3|0.0000|0.0000|||LC|[]|[]||||
Yeah|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
sure|1|0.0000|0.0000|.||LC|[]|[]||||
.|1|0.0000|0.0000|||LC|[]|[]||||
When|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
hear|1|0.0000|0.0000|||LC|[]|[]||||
Foobar|1|0.0000|0.0000|,||UC|[]|[]||||
,|1|0.0000|0.0000|||UC|[]|[]||||
I|1|0.0000|0.0000|||CA|[]|[]||||
think|1|0.0000|0.0000|||LC|[]|[]||||
about|1|0.0000|0.0000|||LC|[]|[]||||
just|1|0.0000|0.0000|||LC|[]|[]||||
that|1|0.0000|0.0000|:||LC|[]|[]||||
:|1|0.0000|0.0000|||LC|[]|[]||||
foo|1|0.0000|0.0000|||LC|[]|[]|||sub(,)|
a|1|0.0000|0.0000|||LC|[]|[]||||
20 changes: 20 additions & 0 deletions test/fstalign_Test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,26 @@ TEST_CASE_METHOD(UniqueTestsFixture, "main-adapted-composition()") {
REQUIRE_THAT(result, Contains("WER: INS:2 DEL:7 SUB:4"));
}

// Runs the "wer" command on short_punc.ref.nlp vs short_punc.hyp.nlp with the
// --use-case flag appended, then checks both the NLP file the command wrote
// and the WER summary captured from the command's output.
// NOTE(review): presumably --use-case makes the comparison case-sensitive;
// the flag parsing is not visible here -- confirm against main.cpp.
SECTION("wer with case(nlp output)") {
const auto result =
exec(command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS)+" --use-case");
// Path of the stored expected-output fixture, built from the TEST_DATA prefix.
const auto testFile = std::string{TEST_DATA} + "short.aligned.case.nlp";

// The NLP file produced by the run is checked against the fixture via the
// compareFiles helper (defined elsewhere in the test sources).
REQUIRE(compareFiles(nlp_output.c_str(), testFile.c_str()));
// The two summary lines must agree with each other: 6 = INS 0 + DEL 3 + SUB 3.
REQUIRE_THAT(result, Contains("WER: 6/32 = 0.1875"));
REQUIRE_THAT(result, Contains("WER: INS:0 DEL:3 SUB:3"));
}

// Same inputs as the case-only section, but with both --use-punctuation and
// --use-case appended to the "wer" command line. Checks the written NLP file
// and the WER summary captured from the command's output.
// NOTE(review): the expected fixture contains punctuation marks as separate
// token rows, which suggests --use-punctuation scores punctuation as tokens --
// confirm against the option handling in main.cpp.
SECTION("wer with case and punctuation(nlp output)") {
const auto result =
exec(command("wer", approach, "short_punc.ref.nlp", "short_punc.hyp.nlp", sbs_output, nlp_output, TEST_SYNONYMS)+" --use-punctuation --use-case");
// Path of the stored expected-output fixture, built from the TEST_DATA prefix.
const auto testFile = std::string{TEST_DATA} + "short.aligned.punc_case.nlp";

// The NLP file produced by the run is checked against the fixture via the
// compareFiles helper (defined elsewhere in the test sources).
REQUIRE(compareFiles(nlp_output.c_str(), testFile.c_str()));
// The two summary lines must agree with each other: 13 = INS 2 + DEL 7 + SUB 4.
REQUIRE_THAT(result, Contains("WER: 13/42 = 0.3095"));
REQUIRE_THAT(result, Contains("WER: INS:2 DEL:7 SUB:4"));
}

// alignment tests

SECTION("align_1") {
Expand Down

0 comments on commit 85ae0dd

Please sign in to comment.