Add timing utilities to AbsTokenizer (#103)
* add timing util

* fix functions
loubbrad authored Feb 22, 2024
1 parent b7c31de commit a8989a5
Showing 2 changed files with 70 additions and 1 deletion.
aria/sample.py (24 additions & 1 deletion)
@@ -101,12 +101,13 @@ def greedy_sample(
     prompts: List[list],
     max_new_tokens: int,
     device: torch.device | None = None,
-    cfg_gamma: float | None = 1.4,
+    cfg_gamma: float | None = 1.05,
     cfg_mode: str | None = None,
     neg_prompts: List[list] | None = None,
     neg_prompt_len: int | None = None,
     alpha: float | None = None,
     force_end=False,
+    dim_tok_pos: list[int] | None = None,
     temperature: float = 0.95,
     top_p: float = 0.95,
 ):
@@ -143,10 +144,16 @@ def greedy_sample(
     device = device or torch.device("cuda")
     model.eval()
 
+    if dim_tok_pos:
+        assert len(dim_tok_pos) == len(prompts), "Lengths don't match"
+
     pad_id = tokenizer.pad_id
     pad_tok = tokenizer.pad_tok
     eos_id = tokenizer.tok_to_id[tokenizer.eos_tok]
 
+    if cfg_gamma == 1.0:
+        cfg_gamma = None
+
     padded_combined_prompts = _process_prompts(
         prompts,
         pad_tok,
@@ -251,6 +258,22 @@ def greedy_sample(
                         next_token[_idx].item()
                     ][0] not in ("dur", "onset"):
                         next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]
+            elif dim_tok_pos is not None:
+                for _idx in range(tokens.size(0)):
+                    if (
+                        cur_pos < dim_tok_pos[_idx]
+                        or dim_tok_inserted[_idx] is True
+                    ):
+                        pass
+                    elif tokenizer.id_to_tok[next_token[_idx].item()][0] not in (
+                        "dur",
+                        "onset",
+                    ):
+                        # This only triggers if:
+                        # - dim_tok hasn't already been inserted
+                        # - the current position is >= dim_tok_pos
+                        # - the dim_tok will not interfere with a note
+                        next_token[_idx] = tokenizer.tok_to_id[tokenizer.dim_tok]
 
             # Update dim_tok_inserted
             for _idx in range(tokens.size(0)):
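For reference, a minimal sketch of how the new dim_tok_pos argument might be called. The leading model/tokenizer arguments and the prompt variables are assumptions for illustration (the parameters before prompts are not shown in this diff); only dim_tok_pos, cfg_gamma, and max_new_tokens come from the signature above. Each entry of dim_tok_pos is a per-prompt position: once generation reaches it, dim_tok is inserted at the first step where doing so would not split an ("onset", ...)/("dur", ...) pair.

# Hypothetical usage sketch (setup names assumed, not part of this commit)
seqs = greedy_sample(
    model,                         # assumed: a loaded model
    tokenizer,                     # assumed: an AbsTokenizer instance
    prompts=[prompt_a, prompt_b],  # assumed: two tokenized prompts
    max_new_tokens=512,
    cfg_gamma=1.0,                 # normalized to None by the new check, so guidance is skipped
    dim_tok_pos=[256, 384],        # earliest dim_tok position for each prompt
)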
aria/tokenizer/tokenizer.py (46 additions & 0 deletions)
@@ -5,6 +5,7 @@
 import functools
 import itertools
 import random
+import copy
 
 from collections import defaultdict
 from typing import Callable
@@ -359,6 +360,51 @@ def _format(self, prefix: list, unformatted_seq: list):
 
         return res
 
+    def calc_length_ms(self, seq: list, onset: bool = False):
+        """Calculates the time (ms) from the start of the sequence to the
+        end of the last note. If onset=True, returns the onset time of the
+        last note instead."""
+        assert type(seq) == list, "Must provide a list of decoded toks"
+        assert type(seq[0]) != int, "Toks must be decoded, not ids"
+
+        # Find the index of the last onset or dur token
+        seq = copy.deepcopy(seq)
+        for _idx in range(len(seq) - 1, -1, -1):
+            tok = seq[_idx]
+            if type(tok) is tuple and tok[0] in {"onset", "dur"}:
+                break
+            else:
+                seq.pop()
+
+        time_offset_ms = seq.count(self.time_tok) * self.abs_time_step
+        idx = len(seq) - 1
+        for tok in seq[::-1]:
+            if type(tok) is tuple and tok[0] == "dur":
+                assert seq[idx][0] == "dur", "Expected a dur token"
+                assert seq[idx - 1][0] == "onset", "Expected an onset token"
+
+                if onset is False:
+                    return time_offset_ms + seq[idx - 1][1] + seq[idx][1]
+                elif onset is True:
+                    return time_offset_ms + seq[idx - 1][1]  # Ignore dur
+
+            idx -= 1
+
+        # If it gets to this point, an error has occurred
+        raise Exception("Sequence contains no note tokens")
+
+    def truncate_by_time(self, tokenized_seq: list, trunc_time_ms: int):
+        """Truncates notes with onset_ms > trunc_time_ms."""
+        time_offset_ms = 0
+        for idx, tok in enumerate(tokenized_seq):
+            if tok == self.time_tok:
+                time_offset_ms += self.abs_time_step
+            elif type(tok) is tuple and tok[0] == "onset":
+                if time_offset_ms + tok[1] > trunc_time_ms:
+                    return tokenized_seq[: idx - 1]
+
+        return tokenized_seq
+
     def _tokenize_midi_dict(self, midi_dict: MidiDict):
         ticks_per_beat = midi_dict.ticks_per_beat
         midi_dict.remove_instruments(self.config["ignore_instruments"])
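To make the timing arithmetic concrete, here is a toy sketch. The note-token layout and the abs_time_step value of 5000 ms are assumptions for illustration (check the tokenizer config for the real step size), as is the import path. calc_length_ms only inspects time_tok occurrences and ("onset", ...)/("dur", ...) tuples: it multiplies the number of time_tok tokens by abs_time_step and adds the final note's onset, plus its duration unless onset=True.

from aria.tokenizer import AbsTokenizer  # import path assumed

tokenizer = AbsTokenizer()  # default construction assumed

seq = [
    ("piano", 60, 75),   # illustrative note token
    ("onset", 4000),     # 4000 ms into the first time segment
    ("dur", 1500),
    tokenizer.time_tok,  # advances the clock by abs_time_step (assume 5000 ms)
    ("piano", 64, 75),
    ("onset", 500),      # absolute onset = 5000 + 500 = 5500 ms
    ("dur", 1000),
]

tokenizer.calc_length_ms(seq)              # 1*5000 + 500 + 1000 = 6500
tokenizer.calc_length_ms(seq, onset=True)  # 1*5000 + 500 = 5500

# Truncate notes whose onset falls after 5000 ms. The slice cuts just before
# the note token that precedes the offending ("onset", ...) tuple, so the
# second note is dropped while the time_tok is kept:
tokenizer.truncate_by_time(seq, trunc_time_ms=5000)
# -> [("piano", 60, 75), ("onset", 4000), ("dur", 1500), tokenizer.time_tok]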
