Changing sample CLI design; Refactor devices; Trying out quantization (#70)

* MidiDataset can initialize with an iterator and only expand when necessary.

* reduce some memory overhead (we are starting to have >100k MidiDicts and may get more in the future)

* classmethod+property is better...

* remove functools import

* use separate workers to build dataset instead of process pool

* add jsonl.zst support; unit test; fix bug

* receive context length via the command line; it's more convenient than digging into the config file every time

* fix a minor output format mismatch when grad_checkpoint is true

* remove hardcoded cuda() as well as autocast for CPU inference; int8 quantization works (see the sketch after this list); fix a bug in gradient_checkpointing when combined with use_cache

* fixing device

* fixing device

* bitsandbytes is unnecessary now.

* Add a warning to force CPU when quantization is used in aria.run sample

* formatting; Also add black formatter to Makefile
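
A minimal, self-contained sketch of the FX graph-mode post-training quantization recipe behind the int8 item above (the module, sizes, and calibration input are illustrative stand-ins, not the repository's classes):

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Stand-in for one of the encoder's Linear submodules.
float_block = torch.nn.Sequential(torch.nn.Linear(512, 512)).eval()
example_input = torch.randn(1, 2048, 512)  # calibration-shaped activation

qconfig_mapping = get_default_qconfig_mapping()
prepared = prepare_fx(float_block, qconfig_mapping, example_inputs=(example_input,))
prepared(example_input)             # one calibration pass to populate the observers
int8_block = convert_fx(prepared)   # int8 module; runs on CPU

The commit applies the same prepare/convert pair to each encoder layer's mixed_qkv, att_proj_linear, ff_linear_1, and ff_linear_2 modules and forces CPU, since the default qconfig mapping targets CPU backends.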
honglu2875 authored Nov 30, 2023
1 parent 9099103 commit 55f5068
Showing 3 changed files with 154 additions and 21 deletions.
aria/model/model.py (1 addition, 2 deletions)
@@ -311,7 +311,7 @@ def forward(self, src: torch.Tensor, use_cache=False, past_kv=None):
# remove torch.compile from the train script as this is not currently
# supported.
# Implements gradient checkpoints on Encoder Layers.
if self.model_config.grad_checkpoint is True:
if self.model_config.grad_checkpoint is True and not use_cache:
for layer in self.encode_layers:

def create_custom_forward(module):
@@ -326,7 +326,6 @@ def custom_forward(*args):
preserve_rng_state=True,
use_reentrant=True,
)

else:
new_past_kv = []
past_kv = (
aria/run.py (141 additions, 13 deletions)
@@ -2,26 +2,93 @@

import argparse
import os
import re
import sys
import pathlib
import warnings


def _parse_sample_args():
argp = argparse.ArgumentParser(prog="aria sample")
argp.add_argument("model", help="name of model config file")
argp.add_argument("ckpt_path", help="path to model checkpoint")
argp.add_argument("midi_path", help="path to midi file")
argp.add_argument("-m", help="name of model config file")
argp.add_argument("-c", help="path to model checkpoint")
argp.add_argument("-p", help="path to midi file")
argp.add_argument(
"-var", help="number of variations", type=int, required=True
"-var",
help="number of variations",
type=int,
default=1,
)
argp.add_argument(
"-trunc", help="length to truncated prompt", type=int, required=True
"-trunc",
help="length to truncated prompt",
type=int,
default=200,
)
argp.add_argument("-e", action="store_true", help="enable force end")
argp.add_argument("-l", type=int, help="generation length")
argp.add_argument("-l", type=int, help="generation length", default=1024)
argp.add_argument("-q", action="store_true", help="quantize the model")

return argp.parse_args(sys.argv[2:])


def _get_model_name(name: str | None, state: dict):
if name is not None:
return name

print("Model name is not provided. Trying to infer from checkpoint...")
_defaults = {
16: "small",
32: "medium",
64: "large",
96: "xlarge",
}
try:
pattern = re.compile(r"encode_layers\.(\d+)\.")
layer_keys = [pattern.search(k) for k in state.keys()]
layer_keys = set(p.group(1) for p in layer_keys if p is not None)
for i in range(len(layer_keys)):
assert str(i) in layer_keys

if len(layer_keys) in _defaults:
print(f"Selecting model name: {_defaults[len(layer_keys)]}")
return _defaults[len(layer_keys)]
assert False
except:
raise ValueError("Model name is not provided and cannot be inferred.")
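
For illustration (not part of the diff): _get_model_name infers the config from how many distinct encode_layers.<i>. prefixes appear in the checkpoint's state dict, so a hypothetical state dict with 16 layers resolves to the "small" config:

# Hypothetical state-dict keys, used only to show the inference path.
state = {f"encode_layers.{i}.ff_linear_1.weight": None for i in range(16)}
assert _get_model_name(None, state) == "small"  # 16 layers map to "small"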


def _show_popup(prompt: str, files: list) -> str:
for i in range(len(files)):
print(f" [{i}] {files[i]}")

for tries in range(3): # 3 tries in case of fat fingers
try:
res = int(input(prompt + f" [0-{len(files) - 1}]: "))
assert 0 <= res < len(files)
return files[res]
except:
print("Invalid input. Try again...")

raise ValueError("Invalid input.")


def _get_ckpt_path(ckpt_path: str | None) -> str:
if ckpt_path is None:
ckpts = list(pathlib.Path(".").glob("*.bin"))
ckpt_path = _show_popup("Choose a checkpoint", ckpts)
return ckpt_path


def _get_midi_path(midi_path: str | None) -> str:
if midi_path is None:
midis = list(pathlib.Path(".").glob("*.mid")) + list(
pathlib.Path(".").glob("*.midi")
)
midi_path = _show_popup("Choose a midi-file", midis)
return midi_path


def sample(args):
"""Entrypoint for sampling"""

@@ -34,20 +101,80 @@ def sample(args):
from aria.data.midi import MidiDict
from aria.utils import midi_to_audio

assert cuda_is_available() is True, "CUDA device not available"
if not cuda_is_available():
print("CUDA device is not available. Using CPU instead.")
else:
greedy_sample = torch.autocast(device_type="cuda", dtype=torch.float16)(
greedy_sample
)
device = (
torch.device("cuda") if cuda_is_available() else torch.device("cpu")
)

ckpt_path = _get_ckpt_path(args.c) # let user input path if not provided
model_state = torch.load(ckpt_path, map_location=device)
model_name = _get_model_name(
args.m, model_state
) # infer model name if not provided

model_name = args.model
ckpt_path = args.ckpt_path
midi_path = args.midi_path
num_variations = args.var
truncate_len = args.trunc
force_end = args.e

tokenizer = TokenizerLazy(return_tensors=True)
model_config = ModelConfig(**load_model_config(model_name))
model_config.set_vocab_size(tokenizer.vocab_size)
model = TransformerLM(model_config).cuda()
model.load_state_dict(torch.load(ckpt_path))
model = TransformerLM(model_config).to(device)
model.load_state_dict(model_state)
if args.q:
if device.type != "cpu":
warnings.warn(
"Quantization is not supported on CUDA devices. Using CPU instead."
)
device = torch.device("cpu")

from torch.ao.quantization import get_default_qconfig_mapping
from torch.quantization.quantize_fx import prepare_fx, convert_fx

qconfig_mapping = get_default_qconfig_mapping()

def _quantize(module, key, input_shape):
inp = torch.randn(input_shape, dtype=torch.float, device=device)
m = prepare_fx(
getattr(module, key), qconfig_mapping, example_inputs=inp
)
m = convert_fx(m)
setattr(module, key, m)

for i in range(len(model.model.encode_layers)):
_quantize(
model.model.encode_layers[i],
"mixed_qkv",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"att_proj_linear",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"ff_linear_1",
input_shape=(1, 2048, model_config.n_heads),
)
_quantize(
model.model.encode_layers[i],
"ff_linear_2",
input_shape=(
1,
2048,
model_config.n_heads * model_config.ff_mult,
),
)

midi_path = _get_midi_path(
args.p
) # let user input midi path if not provided

if args.l and 0 < args.l < model.max_seq_len:
max_gen_len = args.l
@@ -70,6 +197,7 @@
model,
tokenizer,
prompts,
device=device,
force_end=force_end,
max_seq_len=model_config.max_seq_len,
max_gen_len=max_gen_len,
@@ -124,7 +252,7 @@ def _parse_tokenized_dataset_args():
argp.add_argument("load_path", help="path midi_dict dataset")
argp.add_argument("save_path", help="path to save dataset")
argp.add_argument("-s", help="also produce shuffled", action="store_true")
argp.add_argument("-l", help="max sequence length", type=int)
argp.add_argument("-l", help="max sequence length", type=int, default=2048)

return argp.parse_args(sys.argv[2:])
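
Taken together, the reworked "aria sample" parser above makes every flag optional. A hypothetical invocation (the entry point and values are illustrative assumptions; only the flag names and defaults come from the parser above):

python -m aria.run sample -m small -c model.bin -p prompt.mid -var 2 -trunc 200 -l 1024
python -m aria.run sample -q   # quantized CPU run; prompts interactively for a checkpoint and MIDI file

When -c or -p is omitted, _get_ckpt_path and _get_midi_path list matching files in the working directory and ask for a choice, and _get_model_name tries to infer the config from the checkpoint when -m is missing.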

aria/sample.py (12 additions, 6 deletions)
@@ -43,13 +43,13 @@ def _get_cfg_coeff(cfg_gamma, cfg_mode, cur_pos, start_pos, total_len):
# temp=0.85, top_p=0.9, cfg_gamma=1.4


@torch.autocast(device_type="cuda", dtype=torch.float16)
def greedy_sample(
model: TransformerLM,
tokenizer: Tokenizer,
prompts: List[list],
max_seq_len: int,
max_gen_len: int,
device: torch.device | None = None,
cfg_gamma: float | None = 1.4,
cfg_mode: str | None = None,
neg_prompts: List[list] | None = None,
@@ -67,6 +67,7 @@ prompts (List[list]): A list of prompts to sample as a batch.
prompts (List[list]): A list of prompts to sample as a batch.
max_seq_len (int): Maximum sequence length supported by the model.
max_gen_len (int): Maximum desired sequence length of the samples.
device (torch.device, optional): Device to use. Defaults to None.
cfg_gamma (float, optional): CFG gamma parameter. Defaults to 1.2.
This parameter *determines* whether parameters related to CFG are used.
None: No CFG or interpolation. `cfg_mode, neg_prompts, neg_prompt_len, alpha` are ignored.
@@ -88,6 +89,7 @@ List[list]: The list of samples, decoded by the tokenizer.
List[list]: The list of samples, decoded by the tokenizer.
"""
assert tokenizer.return_tensors is True, "tokenizer must return tensors."
device = device or torch.device("cuda")
model.eval()

pad_id = tokenizer.pad_id
@@ -121,24 +123,28 @@
[
torch.concat(
[
torch.full((neg_max_len - len(neg_seq),), pad_id),
tokenizer.encode(neg_seq),
torch.full(
(neg_max_len - len(neg_seq),), pad_id, device=device
),
tokenizer.encode(neg_seq).to(device),
]
)
for neg_seq in neg_prompts
],
axis=0,
).cuda()
)
neg_len = (
neg_min_len
if neg_prompt_len is None
else min(neg_min_len, neg_prompt_len)
)
neg_tokens = neg_prompt_tensors[:, :neg_len]

tokens = torch.full((bsz, total_len), pad_id).cuda()
tokens = torch.full((bsz, total_len), pad_id, device=device)
for idx, unencoded_seq in enumerate(prompts):
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq)
tokens[idx, : len(unencoded_seq)] = tokenizer.encode(unencoded_seq).to(
device
)

dim_tok_inserted = [False for _ in range(bsz)]
input_text_mask = tokens != pad_id
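
For reference, a minimal sketch of calling the updated greedy_sample from user code (import path and the surrounding model/tokenizer/prompt setup are assumed to match aria/run.py; values are illustrative):

import torch
from aria.sample import greedy_sample  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
samples = greedy_sample(
    model,                                 # TransformerLM already moved to `device`
    tokenizer,                             # TokenizerLazy(return_tensors=True)
    prompts,                               # List[list] of prompt sequences
    max_seq_len=model_config.max_seq_len,
    max_gen_len=512,
    device=device,                         # new argument: prompt/pad tensors are created here
)

Note that when device is left as None, the function still falls back to torch.device("cuda") (see the "device = device or ..." line above), so CPU callers should pass it explicitly.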