diff --git a/paddlemix/models/vits-svc/chkpt/.gitkeep b/paddlemix/models/vits-svc/chkpt/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/configs/base.yaml b/paddlemix/models/vits-svc/configs/base.yaml new file mode 100644 index 000000000..dbf59f59e --- /dev/null +++ b/paddlemix/models/vits-svc/configs/base.yaml @@ -0,0 +1,72 @@ +train: + model: "sovits" + seed: 1234 + epochs: 10000 + learning_rate: 5e-5 + betas: [0.8, 0.99] + lr_decay: 0.999875 + eps: 1e-9 + batch_size: 8 + accum_step: 2 + c_stft: 9 + c_mel: 1. + c_kl: 0.2 + port: 8001 + pretrain: "./vits_pretrain/sovits5.0.pretrain.pth" +############################# +data: + training_files: "files/train.txt" + validation_files: "files/valid.txt" + segment_size: 8000 # WARNING: base on hop_length + max_wav_value: 32768.0 + sampling_rate: 32000 + filter_length: 1024 + hop_length: 320 + win_length: 1024 + mel_channels: 100 + mel_fmin: 50.0 + mel_fmax: 16000.0 +############################# +vits: + ppg_dim: 1280 + vec_dim: 256 + spk_dim: 256 + gin_channels: 256 + inter_channels: 192 + hidden_channels: 192 + filter_channels: 640 +############################# +gen: + upsample_input: 192 + upsample_rates: [5,4,4,2,2] + upsample_kernel_sizes: [15,8,8,4,4] + upsample_initial_channel: 320 + resblock_kernel_sizes: [3,7,11] + resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]] +############################# +mpd: + periods: [2,3,5,7,11] + kernel_size: 5 + stride: 3 + use_spectral_norm: False + lReLU_slope: 0.2 +############################# +mrd: + resolutions: "[(1024, 120, 600), (2048, 240, 1200), (4096, 480, 2400), (512, 50, 240)]" # (filter_length, hop_length, win_length) + use_spectral_norm: False + lReLU_slope: 0.2 +############################# +log: + info_interval: 100 + eval_interval: 1 + save_interval: 5 + num_audio: 6 + pth_dir: 'chkpt' + log_dir: 'logs' + keep_ckpts: 0 +############################# +dist_config: + dist_backend: "nccl" + dist_url: "tcp://localhost:54321" + world_size: 1 + diff --git a/paddlemix/models/vits-svc/crepe/__init__.py b/paddlemix/models/vits-svc/crepe/__init__.py new file mode 100644 index 000000000..0fa38295d --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/__init__.py @@ -0,0 +1,8 @@ +from . import decode +from .core import * +from .model import Crepe +from . import convert +from . import filter +from . import load +# from . import loudness +# from . 
import threshold diff --git a/paddlemix/models/vits-svc/crepe/convert.py b/paddlemix/models/vits-svc/crepe/convert.py new file mode 100644 index 000000000..9cf8e916c --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/convert.py @@ -0,0 +1,58 @@ +import scipy +import paddle +import math +import crepe + + +############################################################################### +# Pitch unit conversions +############################################################################### + + +def bins_to_cents(bins): + """Converts pitch bins to cents""" + cents = crepe.CENTS_PER_BIN * bins + 1997.3794084376191 + + # Trade quantization error for noise + return dither(cents) + + +def bins_to_frequency(bins): + """Converts pitch bins to frequency in Hz""" + return cents_to_frequency(bins_to_cents(bins)) + + +def cents_to_bins(cents, quantize_fn=math.floor): + """Converts cents to pitch bins""" + bins = (cents - 1997.3794084376191) / crepe.CENTS_PER_BIN + return quantize_fn(bins) + + +def cents_to_frequency(cents): + """Converts cents to frequency in Hz""" + return 10 * 2 ** (cents / 1200) + + +def frequency_to_bins(frequency, quantize_fn=math.floor): + """Convert frequency in Hz to pitch bins""" + return cents_to_bins(frequency_to_cents(frequency), quantize_fn) + + +def frequency_to_cents(frequency): + """Convert frequency in Hz to cents""" + return 1200 * math.log2(frequency / 10.) + + +# ############################################################################### +# # Utilities +# ############################################################################### + + +def dither(cents): + """Dither the predicted pitch in cents to remove quantization error""" + noise = scipy.stats.triang.rvs(c=0.5, + loc=-crepe.CENTS_PER_BIN, + scale=2 * crepe.CENTS_PER_BIN, + size=cents.shape) + # return cents + cents.new_tensor(noise) + return cents + paddle.to_tensor(noise, dtype=cents.dtype, stop_gradient=cents.stop_gradient) diff --git a/paddlemix/models/vits-svc/crepe/core.py b/paddlemix/models/vits-svc/crepe/core.py new file mode 100644 index 000000000..7c2c55267 --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/core.py @@ -0,0 +1,742 @@ +import warnings +import math +import numpy as np +import resampy +import paddle +import tqdm + +import crepe + + +__all__ = ['CENTS_PER_BIN', + 'MAX_FMAX', + 'PITCH_BINS', + 'SAMPLE_RATE', + 'WINDOW_SIZE', + 'UNVOICED', + # 'embed', + # 'embed_from_file', + # 'embed_from_file_to_file', + # 'embed_from_files_to_files', + 'infer', + 'predict', + 'predict_from_file', + # 'predict_from_file_to_file', + # 'predict_from_files_to_files', + 'preprocess', + 'postprocess', + # 'resample' + ] + + +############################################################################### +# Constants +############################################################################### + + +CENTS_PER_BIN = 20 # cents +MAX_FMAX = 2006. 
# hz +PITCH_BINS = 360 +SAMPLE_RATE = 16000 # hz +WINDOW_SIZE = 1024 # samples +UNVOICED = np.nan + + +############################################################################### +# Crepe pitch prediction +############################################################################### + + +def predict(audio, + sample_rate, + hop_length=None, + fmin=50., + fmax=MAX_FMAX, + model='full', + decoder=crepe.decode.viterbi, + return_harmonicity=False, + return_periodicity=False, + batch_size=None, + device='cpu', + pad=True): + """Performs pitch estimation + + Arguments + audio (torch.tensor [shape=(1, time)]) + The audio signal + sample_rate (int) + The sampling rate in Hz + hop_length (int) + The hop_length in samples + fmin (float) + The minimum allowable frequency in Hz + fmax (float) + The maximum allowable frequency in Hz + model (string) + The model capacity. One of 'full' or 'tiny'. + decoder (function) + The decoder to use. See decode.py for decoders. + return_harmonicity (bool) [DEPRECATED] + Whether to also return the network confidence + return_periodicity (bool) + Whether to also return the network confidence + batch_size (int) + The number of frames per batch + device (string) + The device used to run inference + pad (bool) + Whether to zero-pad the audio + + Returns + pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))]) + (Optional) periodicity (torch.tensor + [shape=(1, 1 + int(time // hop_length))]) + """ + # Deprecate return_harmonicity + if return_harmonicity: + message = ( + 'The crepe return_harmonicity argument is deprecated and ' + 'will be removed in a future release. Please use ' + 'return_periodicity. Rationale: if network confidence measured ' + 'harmonics, the value would be low for non-harmonic, periodic ' + 'sounds (e.g., sine waves). 
But this is not observed.') + warnings.warn(message, DeprecationWarning) + return_periodicity = return_harmonicity + + results = [] + + # Postprocessing breaks gradients, so just don't compute them + with paddle.no_grad(): + + # Preprocess audio + generator = preprocess(audio, + sample_rate, + hop_length, + batch_size, + device, + pad) + for frames in generator: + + # Infer independent probabilities for each pitch bin + probabilities = infer(frames, model) + + # shape=(batch, 360, time / hop_length) + probabilities = probabilities.reshape( + [audio.shape[0], -1, PITCH_BINS]).transpose([0, 2, 1]) + + # Convert probabilities to F0 and periodicity + result = postprocess(probabilities, + fmin, + fmax, + decoder, + return_harmonicity, + return_periodicity) + + # # Place on same device as audio to allow very long inputs + # if isinstance(result, tuple): + # result = (result[0].to(audio.device), + # result[1].to(audio.device)) + # else: + # result = result.to(audio.device) + + results.append(result) + + # Split pitch and periodicity + if return_periodicity: + pitch, periodicity = zip(*results) + return paddle.concat(pitch, 1), paddle.concat(periodicity, 1) + + # Concatenate + return paddle.concat(results, 1) + + +def predict_from_file(audio_file, + hop_length=None, + fmin=50., + fmax=MAX_FMAX, + model='full', + decoder=crepe.decode.viterbi, + return_harmonicity=False, + return_periodicity=False, + batch_size=None, + device='cpu', + pad=True): + """Performs pitch estimation from file on disk + + Arguments + audio_file (string) + The file to perform pitch tracking on + hop_length (int) + The hop_length in samples + fmin (float) + The minimum allowable frequency in Hz + fmax (float) + The maximum allowable frequency in Hz + model (string) + The model capacity. One of 'full' or 'tiny'. + decoder (function) + The decoder to use. See decode.py for decoders. + return_harmonicity (bool) [DEPRECATED] + Whether to also return the network confidence + return_periodicity (bool) + Whether to also return the network confidence + batch_size (int) + The number of frames per batch + device (string) + The device used to run inference + pad (bool) + Whether to zero-pad the audio + + Returns + pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))]) + (Optional) periodicity (torch.tensor + [shape=(1, 1 + int(time // hop_length))]) + """ + # Load audio + audio, sample_rate = crepe.load.audio(audio_file) + + # Predict + return predict(audio, + sample_rate, + hop_length, + fmin, + fmax, + model, + decoder, + return_harmonicity, + return_periodicity, + batch_size, + device, + pad) + + +# def predict_from_file_to_file(audio_file, +# output_pitch_file, +# output_harmonicity_file=None, +# output_periodicity_file=None, +# hop_length=None, +# fmin=50., +# fmax=MAX_FMAX, +# model='full', +# decoder=crepe.decode.viterbi, +# batch_size=None, +# device='cpu', +# pad=True): +# """Performs pitch estimation from file on disk + +# Arguments +# audio_file (string) +# The file to perform pitch tracking on +# output_pitch_file (string) +# The file to save predicted pitch +# output_harmonicity_file (string or None) [DEPRECATED] +# The file to save predicted harmonicity +# output_periodicity_file (string or None) +# The file to save predicted periodicity +# hop_length (int) +# The hop_length in samples +# fmin (float) +# The minimum allowable frequency in Hz +# fmax (float) +# The maximum allowable frequency in Hz +# model (string) +# The model capacity. One of 'full' or 'tiny'. +# decoder (function) +# The decoder to use. 
See decode.py for decoders. +# batch_size (int) +# The number of frames per batch +# device (string) +# The device used to run inference +# pad (bool) +# Whether to zero-pad the audio +# """ +# # Deprecate output_harmonicity_file +# if output_harmonicity_file is not None: +# message = ( +# 'The crepe output_harmonicity_file argument is deprecated and ' +# 'will be removed in a future release. Please use ' +# 'output_periodicity_file. Rationale: if network confidence measured ' +# 'harmonic content, the value would be low for non-harmonic, periodic ' +# 'sounds (e.g., sine waves). But this is not observed.') +# warnings.warn(message, DeprecationWarning) +# output_periodicity_file = output_harmonicity_file + +# # Predict from file +# prediction = predict_from_file(audio_file, +# hop_length, +# fmin, +# fmax, +# model, +# decoder, +# False, +# output_periodicity_file is not None, +# batch_size, +# device, +# pad) + +# # Save to disk +# if output_periodicity_file is not None: +# torch.save(prediction[0].detach(), output_pitch_file) +# torch.save(prediction[1].detach(), output_periodicity_file) +# else: +# torch.save(prediction.detach(), output_pitch_file) + + +def predict_from_files_to_files(audio_files, + output_pitch_files, + output_harmonicity_files=None, + output_periodicity_files=None, + hop_length=None, + fmin=50., + fmax=MAX_FMAX, + model='full', + decoder=crepe.decode.viterbi, + batch_size=None, + device='cpu', + pad=True): + """Performs pitch estimation from files on disk without reloading model + + Arguments + audio_files (list[string]) + The files to perform pitch tracking on + output_pitch_files (list[string]) + The files to save predicted pitch + output_harmonicity_files (list[string] or None) [DEPRECATED] + The files to save predicted harmonicity + output_periodicity_files (list[string] or None) + The files to save predicted periodicity + hop_length (int) + The hop_length in samples + fmin (float) + The minimum allowable frequency in Hz + fmax (float) + The maximum allowable frequency in Hz + model (string) + The model capacity. One of 'full' or 'tiny'. + decoder (function) + The decoder to use. See decode.py for decoders. + batch_size (int) + The number of frames per batch + device (string) + The device used to run inference + pad (bool) + Whether to zero-pad the audio + """ + # Deprecate output_harmonicity_files + if output_harmonicity_files is not None: + message = ( + 'The crepe output_harmonicity_files argument is deprecated and ' + 'will be removed in a future release. Please use ' + 'output_periodicity_files. Rationale: if network confidence measured ' + 'harmonic content, the value would be low for non-harmonic, periodic ' + 'sounds (e.g., sine waves). 
But this is not observed.') + warnings.warn(message, DeprecationWarning) + output_periodicity_files = output_harmonicity_files + + if output_periodicity_files is None: + output_periodicity_files = len(audio_files) * [None] + + # Setup iterator + iterator = zip(audio_files, output_pitch_files, output_periodicity_files) + iterator = tqdm.tqdm(iterator, desc='crepe', dynamic_ncols=True) + for audio_file, output_pitch_file, output_periodicity_file in iterator: + + # Predict a file + predict_from_file_to_file(audio_file, + output_pitch_file, + None, + output_periodicity_file, + hop_length, + fmin, + fmax, + model, + decoder, + batch_size, + device, + pad) + +############################################################################### +# Crepe pitch embedding +############################################################################### + + +# def embed(audio, +# sample_rate, +# hop_length=None, +# model='full', +# batch_size=None, +# device='cpu', +# pad=True): +# """Embeds audio to the output of CREPE's fifth maxpool layer + +# Arguments +# audio (torch.tensor [shape=(1, time)]) +# The audio signals +# sample_rate (int) +# The sampling rate in Hz +# hop_length (int) +# The hop_length in samples +# model (string) +# The model capacity. One of 'full' or 'tiny'. +# batch_size (int) +# The number of frames per batch +# device (string) +# The device to run inference on +# pad (bool) +# Whether to zero-pad the audio + +# Returns +# embedding (torch.tensor [shape=(1, +# 1 + int(time // hop_length), 32, -1)]) +# """ +# results = [] + +# # Preprocess audio +# generator = preprocess(audio, +# sample_rate, +# hop_length, +# batch_size, +# device, +# pad) +# for frames in generator: + +# # Infer pitch embeddings +# embedding = infer(frames, model, embed=True) + +# # shape=(batch, time / hop_length, 32, embedding_size) +# result = embedding.reshape(audio.size(0), frames.size(0), 32, -1) + +# # Place on same device as audio. This allows for large inputs. +# results.append(result.to(audio.device)) + +# # Concatenate +# return torch.cat(results, 1) + + +def embed_from_file(audio_file, + hop_length=None, + model='full', + batch_size=None, + device='cpu', + pad=True): + """Embeds audio from disk to the output of CREPE's fifth maxpool layer + + Arguments + audio_file (string) + The wav file containing the audio to embed + hop_length (int) + The hop_length in samples + model (string) + The model capacity. One of 'full' or 'tiny'. + batch_size (int) + The number of frames per batch + device (string) + The device to run inference on + pad (bool) + Whether to zero-pad the audio + + Returns + embedding (torch.tensor [shape=(1, + 1 + int(time // hop_length), 32, -1)]) + """ + # Load audio + audio, sample_rate = crepe.load.audio(audio_file) + + # Embed + return embed(audio, + sample_rate, + hop_length, + model, + batch_size, + device, + pad) + + +# def embed_from_file_to_file(audio_file, +# output_file, +# hop_length=None, +# model='full', +# batch_size=None, +# device='cpu', +# pad=True): +# """Embeds audio from disk and saves to disk + +# Arguments +# audio_file (string) +# The wav file containing the audio to embed +# hop_length (int) +# The hop_length in samples +# output_file (string) +# The file to save the embedding +# model (string) +# The model capacity. One of 'full' or 'tiny'. 
+# batch_size (int) +# The number of frames per batch +# device (string) +# The device to run inference on +# pad (bool) +# Whether to zero-pad the audio +# """ +# # No use computing gradients if we're just saving to file +# with torch.no_grad(): + +# # Embed +# embedding = embed_from_file(audio_file, +# hop_length, +# model, +# batch_size, +# device, +# pad) + +# # Save to disk +# torch.save(embedding.detach(), output_file) + + +# def embed_from_files_to_files(audio_files, +# output_files, +# hop_length=None, +# model='full', +# batch_size=None, +# device='cpu', +# pad=True): +# """Embeds audio from disk and saves to disk without reloading model + +# Arguments +# audio_files (list[string]) +# The wav files containing the audio to embed +# output_files (list[string]) +# The files to save the embeddings +# hop_length (int) +# The hop_length in samples +# model (string) +# The model capacity. One of 'full' or 'tiny'. +# batch_size (int) +# The number of frames per batch +# device (string) +# The device to run inference on +# pad (bool) +# Whether to zero-pad the audio +# """ +# # Setup iterator +# iterator = zip(audio_files, output_files) +# iterator = tqdm.tqdm(iterator, desc='crepe', dynamic_ncols=True) +# for audio_file, output_file in iterator: + +# # Embed a file +# embed_from_file_to_file(audio_file, +# output_file, +# hop_length, +# model, +# batch_size, +# device, +# pad) + + +############################################################################### +# Components for step-by-step prediction +############################################################################### + + +def infer(frames, model='full', embed=False): + """Forward pass through the model + + Arguments + frames (torch.tensor [shape=(time / hop_length, 1024)]) + The network input + model (string) + The model capacity. One of 'full' or 'tiny'. 
+ embed (bool) + Whether to stop inference at the intermediate embedding layer + + Returns + logits (torch.tensor [shape=(1 + int(time // hop_length), 360)]) OR + embedding (torch.tensor [shape=(1 + int(time // hop_length), + embedding_size)]) + """ + # Load the model if necessary + if not hasattr(infer, 'model') or not hasattr(infer, 'capacity') or \ + (hasattr(infer, 'capacity') and infer.capacity != model): + crepe.load.model(frames.place, model) + + # Move model to correct device (no-op if devices are the same) + # infer.model = infer.model.to(frames.place) + + # Apply model + return infer.model(frames, embed=embed) + + +def postprocess(probabilities, + fmin=0., + fmax=MAX_FMAX, + decoder=crepe.decode.viterbi, + return_harmonicity=False, + return_periodicity=False): + """Convert model output to F0 and periodicity + + Arguments + probabilities (torch.tensor [shape=(1, 360, time / hop_length)]) + The probabilities for each pitch bin inferred by the network + fmin (float) + The minimum allowable frequency in Hz + fmax (float) + The maximum allowable frequency in Hz + viterbi (bool) + Whether to use viterbi decoding + return_harmonicity (bool) [DEPRECATED] + Whether to also return the network confidence + return_periodicity (bool) + Whether to also return the network confidence + + Returns + pitch (torch.tensor [shape=(1, 1 + int(time // hop_length))]) + periodicity (torch.tensor [shape=(1, 1 + int(time // hop_length))]) + """ + # Sampling is non-differentiable, so remove from graph + probabilities = probabilities.detach() + + # Convert frequency range to pitch bin range + minidx = crepe.convert.frequency_to_bins(fmin) + maxidx = crepe.convert.frequency_to_bins(fmax, math.ceil) + + # Remove frequencies outside of allowable range + probabilities[:, :minidx] = -float('inf') + probabilities[:, maxidx:] = -float('inf') + + # Perform argmax or viterbi sampling + bins, pitch = decoder(probabilities) + + # Deprecate return_harmonicity + if return_harmonicity: + message = ( + 'The crepe return_harmonicity argument is deprecated and ' + 'will be removed in a future release. Please use ' + 'return_periodicity. Rationale: if network confidence measured ' + 'harmonics, the value would be low for non-harmonic, periodic ' + 'sounds (e.g., sine waves). 
But this is not observed.') + warnings.warn(message, DeprecationWarning) + return_periodicity = return_harmonicity + + if not return_periodicity: + return pitch + + # Compute periodicity from probabilities and decoded pitch bins + return pitch, periodicity(probabilities, bins) + + +def preprocess(audio, + sample_rate, + hop_length=None, + batch_size=None, + device='cpu', + pad=True): + """Convert audio to model input + + Arguments + audio (torch.tensor [shape=(1, time)]) + The audio signals + sample_rate (int) + The sampling rate in Hz + hop_length (int) + The hop_length in samples + batch_size (int) + The number of frames per batch + device (string) + The device to run inference on + pad (bool) + Whether to zero-pad the audio + + Returns + frames (torch.tensor [shape=(1 + int(time // hop_length), 1024)]) + """ + # Default hop length of 10 ms + hop_length = sample_rate // 100 if hop_length is None else hop_length + + # Resample + if sample_rate != SAMPLE_RATE: + audio = resample(audio, sample_rate) + hop_length = int(hop_length * SAMPLE_RATE / sample_rate) + + # Get total number of frames + + # Maybe pad + if pad: + total_frames = 1 + int(audio.shape[1] // hop_length) + audio = paddle.nn.functional.pad( + audio[None], + (WINDOW_SIZE // 2, WINDOW_SIZE // 2), + data_format="NCL").squeeze(0) + else: + total_frames = 1 + int((audio.size(1) - WINDOW_SIZE) // hop_length) + + # Default to running all frames in a single batch + batch_size = total_frames if batch_size is None else batch_size + + # Generate batches + for i in range(0, total_frames, batch_size): + + # Batch indices + start = max(0, i * hop_length) + end = min(audio.shape[1], + (i + batch_size - 1) * hop_length + WINDOW_SIZE) + + # Chunk + frames = paddle.nn.functional.unfold( + audio[:, None, None, start:end], + kernel_sizes=(1, WINDOW_SIZE), + strides=(1, hop_length)) + + # shape=(1 + int(time / hop_length, 1024) + frames = frames.transpose([0, 2, 1]).reshape([-1, WINDOW_SIZE]) + + # Place on device + # frames = frames.to(device) + + # Mean-center + frames -= frames.mean(axis=1, keepdim=True) + + # Scale + # Note: during silent frames, this produces very large values. But + # this seems to be what the network expects. 
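+        # paddle.maximum broadcasts the scalar 1e-10 floor against the
+        # (n_frames, 1) per-frame std, so silent frames are divided by the
+        # floor instead of by zero.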
+ frames /= paddle.maximum(paddle.to_tensor(1e-10), + frames.std(axis=1, keepdim=True)) + + yield frames + + +############################################################################### +# Utilities +############################################################################### + + +def periodicity(probabilities, bins): + """Computes the periodicity from the network output and pitch bins""" + # shape=(batch * time / hop_length, 360) + # probs_stacked = probabilities.transpose(1, 2).reshape(-1, PITCH_BINS) + probs_stacked = probabilities.transpose([0, 2, 1]).reshape([-1, PITCH_BINS]) + + # shape=(batch * time / hop_length, 1) + # bins_stacked = bins.reshape(-1, 1).to(torch.int64) + bins_stacked = bins.reshape([-1, 1]) + + # Use maximum logit over pitch bins as periodicity + # periodicity = probs_stacked.gather(1, bins_stacked) + periodicity = probs_stacked.take_along_axis(bins_stacked, axis=1) + + # shape=(batch, time / hop_length) + return periodicity.reshape([probabilities.shape[0], probabilities.shape[2]]) + + +# def resample(audio, sample_rate): +# """Resample audio""" +# # Store device for later placement +# device = audio.device + +# # Convert to numpy +# audio = audio.detach().cpu().numpy().squeeze(0) + +# # Resample +# # We have to use resampy if we want numbers to match Crepe +# audio = resampy.resample(audio, sample_rate, SAMPLE_RATE) + +# # Convert to pytorch +# return torch.tensor(audio, device=device).unsqueeze(0) diff --git a/paddlemix/models/vits-svc/crepe/decode.py b/paddlemix/models/vits-svc/crepe/decode.py new file mode 100644 index 000000000..63420208d --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/decode.py @@ -0,0 +1,81 @@ +import librosa +import numpy as np +import paddle + +import crepe + + +############################################################################### +# Probability sequence decoding methods +############################################################################### + + +def argmax(logits): + """Sample observations by taking the argmax""" + bins = logits.argmax(dim=1) + + # Convert to frequency in Hz + return bins, crepe.convert.bins_to_frequency(bins) + + +def weighted_argmax(logits): + """Sample observations using weighted sum near the argmax""" + # Find center of analysis window + bins = logits.argmax(dim=1) + + # Find bounds of analysis window + start = torch.max(torch.tensor(0, device=logits.device), bins - 4) + end = torch.min(torch.tensor(logits.size(1), device=logits.device), bins + 5) + + # Mask out everything outside of window + for batch in range(logits.size(0)): + for time in range(logits.size(2)): + logits[batch, :start[batch, time], time] = -float('inf') + logits[batch, end[batch, time]:, time] = -float('inf') + + # Construct weights + if not hasattr(weighted_argmax, 'weights'): + weights = crepe.convert.bins_to_cents(torch.arange(360)) + weighted_argmax.weights = weights[None, :, None] + + # Ensure devices are the same (no-op if they are) + weighted_argmax.weights = weighted_argmax.weights.to(logits.device) + + # Convert to probabilities + with torch.no_grad(): + probs = torch.sigmoid(logits) + + # Apply weights + cents = (weighted_argmax.weights * probs).sum(dim=1) / probs.sum(dim=1) + + # Convert to frequency in Hz + return bins, crepe.convert.cents_to_frequency(cents) + + +def viterbi(logits): + """Sample observations using viterbi decoding""" + # Create viterbi transition matrix + if not hasattr(viterbi, 'transition'): + xx, yy = np.meshgrid(range(360), range(360)) + transition = np.maximum(12 - abs(xx - yy), 0) + 
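+        # transitions farther than 12 bins (240 cents at 20 cents per bin) get
+        # zero weight; the row normalisation below turns each row into a
+        # probability distribution for librosa.sequence.viterbi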
transition = transition / transition.sum(axis=1, keepdims=True) + viterbi.transition = transition + + # Normalize logits + with paddle.no_grad(): + # probs = torch.nn.functional.softmax(logits, dim=1) + probs = paddle.nn.functional.softmax(logits, axis=1) + + # Convert to numpy + sequences = probs.cpu().numpy() + + # Perform viterbi decoding + bins = np.array([ + librosa.sequence.viterbi(sequence, viterbi.transition).astype(np.int64) + for sequence in sequences]) + + # Convert to pytorch + bins = paddle.to_tensor(bins) + + # Convert to frequency in Hz + return bins, crepe.convert.bins_to_frequency(bins) diff --git a/paddlemix/models/vits-svc/crepe/filter.py b/paddlemix/models/vits-svc/crepe/filter.py new file mode 100644 index 000000000..b8f040b15 --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/filter.py @@ -0,0 +1,336 @@ +import numpy as np +import paddle +from paddle.nn import functional as F + +############################################################################### +# Sequence filters +############################################################################### + +def mean(signals, win_length=9): + """Averave filtering for signals containing nan values + + Arguments + signals (torch.tensor (shape=(batch, time))) + The signals to filter + win_length + The size of the analysis window + + Returns + filtered (torch.tensor (shape=(batch, time))) + """ + + assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" + signals = signals.unsqueeze(1) + + # Apply the mask by setting masked elements to zero, or make NaNs zero + mask = ~paddle.isnan(signals) + masked_x = paddle.where(condition=mask, x=signals, y=paddle.zeros_like( + x=signals)) + + # Create a ones kernel with the same number of channels as the input tensor + ones_kernel = paddle.ones([signals.shape[1], 1, win_length]) + + # Perform sum pooling + sum_pooled = paddle.nn.functional.conv1d(x=masked_x, weight=ones_kernel, + stride=1, padding=win_length // 2) + + # Count the non-masked (valid) elements in each pooling window + valid_count = paddle.nn.functional.conv1d(x=mask.astype(dtype='float32' + ), weight=ones_kernel, stride=1, padding=win_length // 2) + + valid_count = valid_count.clip(min=1) # Avoid division by zero + + # Perform masked average pooling + avg_pooled = sum_pooled / valid_count + + # Fill zero values with NaNs + avg_pooled[avg_pooled == 0] = float("nan") + + return avg_pooled.squeeze(1) + + +def mean_torch(signals, win_length=9): + """Averave filtering for signals containing nan values + + Arguments + signals (torch.tensor (shape=(batch, time))) + The signals to filter + win_length + The size of the analysis window + + Returns + filtered (torch.tensor (shape=(batch, time))) + """ + import torch + import torch.nn.functional as F + + assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" + signals = signals.unsqueeze(1) + + # Apply the mask by setting masked elements to zero, or make NaNs zero + mask = ~torch.isnan(signals) + masked_x = torch.where(mask, signals, torch.zeros_like(signals)) + + # Create a ones kernel with the same number of channels as the input tensor + ones_kernel = torch.ones(signals.size(1), 1, win_length, device=signals.device) + + # Perform sum pooling + sum_pooled = F.conv1d( + masked_x, + ones_kernel, + stride=1, + padding=win_length // 2, + ) + + # Count the non-masked (valid) elements in each pooling window + valid_count = F.conv1d( + mask.float(), + ones_kernel, + stride=1, + padding=win_length // 2, + ) + valid_count = 
valid_count.clamp(min=1) # Avoid division by zero + + # Perform masked average pooling + avg_pooled = sum_pooled / valid_count + + # Fill zero values with NaNs + avg_pooled[avg_pooled == 0] = float("nan") + + return avg_pooled.squeeze(1) + + +if __name__ == "__main__": + + import torch + import numpy as np + + signals = np.random.randn(1, 3591).astype("float32")*1000 + + signals_paddle = paddle.to_tensor(signals) + signals_torch = torch.from_numpy(signals) + win_length = 5 + + y_paddle = mean(signals_paddle, win_length) + y_torch = mean_torch(signals_torch, win_length) + + print( y_paddle.mean().item() - y_torch.mean().item() ) + print( y_paddle.std().item() - y_torch.std().item() ) + + + +def median(signals, win_length): + """Median filtering for signals containing nan values + + Arguments + signals (torch.tensor (shape=(batch, time))) + The signals to filter + win_length + The size of the analysis window + + Returns + filtered (torch.tensor (shape=(batch, time))) + """ + + assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" + signals = signals.unsqueeze(axis=1) + + mask = ~paddle.isnan(x=signals) + masked_x = paddle.where(condition=mask, x=signals, y=paddle.zeros_like( + x=signals)) + + padding = win_length // 2 + + x = F.pad(masked_x, (padding, padding), mode="reflect", data_format="NCL") + mask = F.pad(mask.astype(dtype='float32'), (padding, padding), mode="constant", value=0, data_format="NCL") + + x = x.unfold(axis=2, size=win_length, step=1) + mask = mask.unfold(axis=2, size=win_length, step=1) + + x = x.reshape(tuple(x.shape)[:3] + (-1,)) + mask = mask.reshape(tuple(mask.shape)[:3] + (-1,)) + + # Combine the mask with the input tensor + # x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x) + x_masked = paddle.where(condition=mask.astype(dtype='bool'), x=x.astype + (dtype='float64'), y=float('inf')) + + # Sort the masked tensor along the last dimension + # x_sorted, _ = torch.sort(x_masked, dim=-1) + x_sorted = paddle.sort(x=x_masked, axis=-1) + + # Compute the count of non-masked (valid) values + valid_count = mask.sum(axis=-1) + + # Calculate the index of the median value for each pooling window + median_idx = ((valid_count - 1) // 2).clip(min=0) + + # Gather the median values using the calculated indices + # median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) + median_pooled = x_sorted.take_along_axis(axis=-1, indices=median_idx. 
+ unsqueeze(axis=-1).astype(dtype='int64')).squeeze(axis=-1) + + # Fill infinite values with NaNs + # median_pooled[torch.isinf(median_pooled)] = float("nan") + median_pooled[paddle.isinf(x=median_pooled)] = float('nan') + + return median_pooled.squeeze(axis=1) + + +def median_torch(signals, win_length): + """Median filtering for signals containing nan values + + Arguments + signals (torch.tensor (shape=(batch, time))) + The signals to filter + win_length + The size of the analysis window + + Returns + filtered (torch.tensor (shape=(batch, time))) + """ + import torch + import torch.nn.functional as F + + assert signals.dim() == 2, "Input tensor must have 2 dimensions (batch_size, width)" + signals = signals.unsqueeze(1) + + mask = ~torch.isnan(signals) + masked_x = torch.where(mask, signals, torch.zeros_like(signals)) + padding = win_length // 2 + + x = F.pad(masked_x, (padding, padding), mode="reflect") + mask = F.pad(mask.float(), (padding, padding), mode="constant", value=0) + + x = x.unfold(2, win_length, 1) + mask = mask.unfold(2, win_length, 1) + + x = x.contiguous().view(x.size()[:3] + (-1,)) + mask = mask.contiguous().view(mask.size()[:3] + (-1,)) + + # Combine the mask with the input tensor + x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x) + + # Sort the masked tensor along the last dimension + x_sorted, _ = torch.sort(x_masked, dim=-1) + + # Compute the count of non-masked (valid) values + valid_count = mask.sum(dim=-1) + + # Calculate the index of the median value for each pooling window + median_idx = ((valid_count - 1) // 2).clamp(min=0) + + # Gather the median values using the calculated indices + median_pooled = x_sorted.gather(-1, median_idx.unsqueeze(-1).long()).squeeze(-1) + + # Fill infinite values with NaNs + median_pooled[torch.isinf(median_pooled)] = float("nan") + + return median_pooled.squeeze(1) + + +if __name__ == "__main__": + + import torch + import numpy as np + + signals = np.random.randn(1, 3591)*1000 + + signals_paddle = paddle.to_tensor(signals) + signals_torch = torch.from_numpy(signals) + win_length = 7 + + y_paddle = median(signals_paddle, win_length) + y_torch = median_torch(signals_torch, win_length) + + print( y_paddle.mean().item() - y_torch.mean().item() ) + print( y_paddle.std().item() - y_torch.std().item() ) + + +############################################################################### +# Utilities +############################################################################### + + +def nanfilter(signals, win_length, filter_fn): + """Filters a sequence, ignoring nan values + + Arguments + signals (torch.tensor (shape=(batch, time))) + The signals to filter + win_length + The size of the analysis window + filter_fn (function) + The function to use for filtering + + Returns + filtered (torch.tensor (shape=(batch, time))) + """ + # Output buffer + filtered = torch.empty_like(signals) + + # Loop over frames + for i in range(signals.size(1)): + + # Get analysis window bounds + start = max(0, i - win_length // 2) + end = min(signals.size(1), i + win_length // 2 + 1) + + # Apply filter to window + filtered[:, i] = filter_fn(signals[:, start:end]) + + return filtered + + +def nanmean(signals): + """Computes the mean, ignoring nans + + Arguments + signals (torch.tensor [shape=(batch, time)]) + The signals to filter + + Returns + filtered (torch.tensor [shape=(batch, time)]) + """ + signals = signals.clone() + + # Find nans + nans = torch.isnan(signals) + + # Set nans to 0. + signals[nans] = 0. 
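+    # the zeroed nans add nothing to the numerator; the denominator below
+    # counts only the valid (non-nan) samples per row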
+ + # Compute average + return signals.sum(dim=1) / (~nans).float().sum(dim=1) + + +def nanmedian(signals): + """Computes the median, ignoring nans + + Arguments + signals (torch.tensor [shape=(batch, time)]) + The signals to filter + + Returns + filtered (torch.tensor [shape=(batch, time)]) + """ + # Find nans + nans = torch.isnan(signals) + + # Compute median for each slice + medians = [nanmedian1d(signal[~nan]) for signal, nan in zip(signals, nans)] + + # Stack results + return torch.tensor(medians, dtype=signals.dtype, device=signals.device) + + +def nanmedian1d(signal): + """Computes the median. If signal is empty, returns torch.nan + + Arguments + signal (torch.tensor [shape=(time,)]) + + Returns + median (torch.tensor [shape=(1,)]) + """ + return torch.median(signal) if signal.numel() else np.nan diff --git a/paddlemix/models/vits-svc/crepe/load.py b/paddlemix/models/vits-svc/crepe/load.py new file mode 100644 index 000000000..e68c0bc0e --- /dev/null +++ b/paddlemix/models/vits-svc/crepe/load.py @@ -0,0 +1,41 @@ +import os + +import numpy as np +import paddle +import crepe +from scipy.io import wavfile + + +def audio(filename): + """Load audio from disk""" + sample_rate, audio = wavfile.read(filename) + + # Convert to float32 + if audio.dtype == np.int16: + audio = audio.astype(np.float32) / np.iinfo(np.int16).max + + # PyTorch is not compatible with non-writeable arrays, so we make a copy + return torch.tensor(np.copy(audio))[None], sample_rate + + +def model(device, capacity='full'): + """Preloads model from disk""" + # Bind model and capacity + crepe.infer.capacity = capacity + crepe.infer.model = crepe.Crepe(capacity) + + # Load weights + file = os.path.join(os.path.dirname(__file__), 'assets', f'{capacity}.pdparam') + # crepe.infer.model.load_state_dict( + # torch.load(file, map_location=device)) + + crepe.infer.model.load_dict( + paddle.load(file)) + + # Place on device + # crepe.infer.model = crepe.infer.model.to(device) + + # Eval mode + crepe.infer.model.eval() + + diff --git a/paddlemix/models/vits-svc/hubert/LICENSE.txt b/paddlemix/models/vits-svc/hubert/LICENSE.txt new file mode 100644 index 000000000..6eb2af050 --- /dev/null +++ b/paddlemix/models/vits-svc/hubert/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Benjamin van Niekerk + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
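For reference, a minimal usage sketch of the crepe package added above. It assumes a hypothetical example.wav that is mono and already at the 16 kHz rate the model expects (the resample helper is commented out in this port), pretrained weights under crepe/assets/full.pdparam, and an illustrative 0.1 periodicity threshold:

import numpy as np
import paddle
from scipy.io import wavfile

import crepe

# Load a 16 kHz mono wav and scale int16 samples to [-1, 1]; shape=(1, time).
sample_rate, wav = wavfile.read("example.wav")          # hypothetical file
audio = paddle.to_tensor(wav.astype(np.float32) / 32768.0)[None]

with paddle.no_grad():
    pitch, periodicity = crepe.predict(
        audio, sample_rate,
        hop_length=sample_rate // 100,      # 10 ms hop
        fmin=50., fmax=1000.,
        model="full",
        decoder=crepe.decode.viterbi,
        return_periodicity=True,
        batch_size=512)

# Treat low-confidence frames as unvoiced, then smooth with the nan-aware filter.
unvoiced = paddle.full_like(pitch, float("nan"))
pitch = paddle.where(periodicity > 0.1, pitch, unvoiced)
pitch = crepe.filter.median(pitch, win_length=5)

The conversions in convert.py map bin b to 20 * b + 1997.38 cents and cents c to 10 * 2 ** (c / 1200) Hz, so the 360 bins cover roughly 31.7 Hz to 2005.5 Hz, which matches MAX_FMAX.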
diff --git a/paddlemix/models/vits-svc/hubert/__init__.py b/paddlemix/models/vits-svc/hubert/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/hubert/hubert_model.py b/paddlemix/models/vits-svc/hubert/hubert_model.py new file mode 100644 index 000000000..800869299 --- /dev/null +++ b/paddlemix/models/vits-svc/hubert/hubert_model.py @@ -0,0 +1,1114 @@ +import copy +import random +from typing import Optional, Tuple + +import os +import re +import numpy as np + +import paddle + +import torch +import torch.nn as nn + + + +# # 这部分用于对齐 torch 和 paddle 的 MultiheadAttention +# if __name__ == "__main__": + +# np.random.seed(1107) + +# x = np.random.randn(32, 255, 768) * 10000 +# x = x.astype("float32") + +# x_paddle = paddle.to_tensor(x) +# x_torch = torch.from_numpy(x).cuda() + +# n_heads = 12 +# d_model = 768 +# dropout = 0 + +# # -------- torch -------- +# self_attn_torch = torch.nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True).cuda() +# # -------- paddle -------- +# self_attn_paddle = paddle.nn.MultiHeadAttention(d_model, n_heads, dropout=dropout) + + +# # ----------------------------- +# attn_torch_state_dict = self_attn_torch.state_dict() +# attn_paddle_state_dict = self_attn_paddle.state_dict() + +# en_param_torch = set(attn_torch_state_dict.keys()) +# en_param_paddle = set(attn_paddle_state_dict.keys()) + +# # torch 参数传给 Paddle +# for key_torch, value_torch in attn_torch_state_dict.items(): + +# # print(key_torch) + +# if 'out_proj.weight' == key_torch: +# attn_paddle_state_dict[key_torch] = paddle.to_tensor(value_torch.cpu().numpy()).T +# continue + +# if key_torch in en_param_paddle: +# assert attn_paddle_state_dict[key_torch].shape == list(value_torch.shape) +# attn_paddle_state_dict[key_torch] = paddle.to_tensor(value_torch.cpu().numpy()) +# continue + +# print(f"{key_torch} -> sth happened") + + +# # 参数矩阵 +# q, k, v = torch.chunk(attn_torch_state_dict['in_proj_weight'], 3) +# some_deal = lambda x: paddle.to_tensor(x.cpu().numpy()).T # <-------- 转置? 
+# q, k, v = some_deal(q), some_deal(k), some_deal(v) + +# attn_paddle_state_dict['q_proj.weight'] = q +# attn_paddle_state_dict['k_proj.weight'] = k +# attn_paddle_state_dict['v_proj.weight'] = v + +# # 偏置 +# q, k, v = torch.chunk(attn_torch_state_dict['in_proj_bias'], 3) +# some_deal = lambda x: paddle.to_tensor(x.cpu().numpy()) +# q, k, v = some_deal(q), some_deal(k), some_deal(v) + +# attn_paddle_state_dict['q_proj.bias'] = q +# attn_paddle_state_dict['k_proj.bias'] = k +# attn_paddle_state_dict['v_proj.bias'] = v + + +# # 加载参数 +# self_attn_paddle.load_dict(attn_paddle_state_dict) + +# # torch +# tgt2_torch = self_attn_torch(x_torch, x_torch, x_torch)[0] + +# # paddle +# tgt2_paddle = self_attn_paddle(x_paddle, x_paddle, x_paddle) + +# print( +# tgt2_torch.mean().item() - tgt2_paddle.mean().item(), +# tgt2_torch.std().item() - tgt2_paddle.std().item(), +# ) + + +def MultiheadAttention_torch2paddle(pd_model, tc_model, prefix=""): + + self_attn_torch, self_attn_paddle = tc_model, pd_model + + # 多头注意力机制,torch 模型转 paddle 模型 + attn_torch_state_dict = self_attn_torch.state_dict() + attn_paddle_state_dict = self_attn_paddle.state_dict() + + en_param_torch = set(attn_torch_state_dict.keys()) + en_param_paddle = set(attn_paddle_state_dict.keys()) + + # torch 参数传给 Paddle + for key_torch, value_torch in attn_torch_state_dict.items(): + + # print(key_torch) + + if prefix + 'out_proj.weight' == key_torch: + # Linear 层参数要转置 + attn_paddle_state_dict[key_torch] = paddle.to_tensor(value_torch.cpu().numpy()).T + continue + + if key_torch in en_param_paddle: + assert attn_paddle_state_dict[key_torch].shape == list(value_torch.shape) + attn_paddle_state_dict[key_torch] = paddle.to_tensor(value_torch.cpu().numpy()) + continue + + print(f"{key_torch} -> sth happened") + + + # 参数矩阵 + q, k, v = torch.chunk(attn_torch_state_dict[prefix + 'in_proj_weight'], 3) + some_deal = lambda x: paddle.to_tensor(x.cpu().numpy()).T # <-------- 转置? 
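+    # the transpose is required because torch stores projection weights as
+    # (out_features, in_features) while paddle.nn.Linear keeps
+    # (in_features, out_features); the bias chunks below are copied unchanged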
+ q, k, v = some_deal(q), some_deal(k), some_deal(v) + + attn_paddle_state_dict[prefix + 'q_proj.weight'] = q + attn_paddle_state_dict[prefix + 'k_proj.weight'] = k + attn_paddle_state_dict[prefix + 'v_proj.weight'] = v + + # 偏置 + q, k, v = torch.chunk(attn_torch_state_dict[prefix + 'in_proj_bias'], 3) + some_deal = lambda x: paddle.to_tensor(x.cpu().numpy()) + q, k, v = some_deal(q), some_deal(k), some_deal(v) + + attn_paddle_state_dict[prefix + 'q_proj.bias'] = q + attn_paddle_state_dict[prefix + 'k_proj.bias'] = k + attn_paddle_state_dict[prefix + 'v_proj.bias'] = v + + + # 加载参数 + self_attn_paddle.load_dict(attn_paddle_state_dict) + + return self_attn_paddle + + +# # 测试 MultiheadAttention_torch2paddle 函数 +# if __name__ == "__main__": + +# np.random.seed(1107) + +# x = np.random.randn(32, 255, 768) * 3 +# x = x.astype("float32") + +# x_paddle = paddle.to_tensor(x) +# x_torch = torch.from_numpy(x).cuda() + +# n_heads = 12 +# d_model = 768 +# dropout = 0 + +# # -------- torch -------- +# self_attn_torch = torch.nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True).cuda() # <--- 注意这里是 batch first +# # -------- paddle -------- +# self_attn_paddle = paddle.nn.MultiHeadAttention(d_model, n_heads, dropout=dropout) + +# # torch 的参数给 paddle 模型 +# self_attn_paddle = MultiheadAttention_torch2paddle(self_attn_paddle, self_attn_torch) + + +# # torch +# tgt2_torch = self_attn_torch(x_torch, x_torch, x_torch)[0] + +# # paddle +# tgt2_paddle = self_attn_paddle(x_paddle, x_paddle, x_paddle) + + +# print( +# tgt2_torch.mean().item() - tgt2_paddle.mean().item(), +# tgt2_torch.std().item() - tgt2_paddle.std().item(), +# ) + + + +class FeatureExtractor(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.conv0 = paddle.nn.Conv1D(1, 512, 10, 5, bias_attr=False) + self.norm0 = paddle.nn.GroupNorm(512, 512) + self.conv1 = paddle.nn.Conv1D(512, 512, 3, 2, bias_attr=False) + self.conv2 = paddle.nn.Conv1D(512, 512, 3, 2, bias_attr=False) + self.conv3 = paddle.nn.Conv1D(512, 512, 3, 2, bias_attr=False) + self.conv4 = paddle.nn.Conv1D(512, 512, 3, 2, bias_attr=False) + self.conv5 = paddle.nn.Conv1D(512, 512, 2, 2, bias_attr=False) + self.conv6 = paddle.nn.Conv1D(512, 512, 2, 2, bias_attr=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = paddle.nn.functional.gelu(self.norm0(self.conv0(x))) + x = paddle.nn.functional.gelu(self.conv1(x)) + x = paddle.nn.functional.gelu(self.conv2(x)) + x = paddle.nn.functional.gelu(self.conv3(x)) + x = paddle.nn.functional.gelu(self.conv4(x)) + x = paddle.nn.functional.gelu(self.conv5(x)) + x = paddle.nn.functional.gelu(self.conv6(x)) + return x + + + +class FeatureExtractor_torch(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv0 = torch.nn.Conv1d(1, 512, 10, 5, bias=False) + self.norm0 = torch.nn.GroupNorm(512, 512) + self.conv1 = torch.nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv2 = torch.nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv3 = torch.nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv4 = torch.nn.Conv1d(512, 512, 3, 2, bias=False) + self.conv5 = torch.nn.Conv1d(512, 512, 2, 2, bias=False) + self.conv6 = torch.nn.Conv1d(512, 512, 2, 2, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = torch.nn.functional.gelu(self.norm0(self.conv0(x))) + x = torch.nn.functional.gelu(self.conv1(x)) + x = torch.nn.functional.gelu(self.conv2(x)) + x = torch.nn.functional.gelu(self.conv3(x)) + x = torch.nn.functional.gelu(self.conv4(x)) + x = torch.nn.functional.gelu(self.conv5(x)) + x = 
torch.nn.functional.gelu(self.conv6(x)) + return x + + + +# if __name__ == "__main__": +# # 目标, 将 torch 参数传递给 paddle 模型 +# torch_fe = FeatureExtractor_torch().cuda() +# paddle_fe = FeatureExtractor() + + +# np.random.seed(1107) +# inputs = np.random.rand(1, 1, 574480).astype("float32") + +# tc_inp = torch.from_numpy(inputs).cuda() +# pd_inp = paddle.to_tensor(inputs) + + +# paddle_fe_state_dict = paddle_fe.state_dict() +# # 目测参数都一样, 直接转换就可以 +# if set(torch_fe.state_dict().keys()) == set(paddle_fe.state_dict().keys()): + +# for torch_key, torch_value in torch_fe.state_dict().items(): + +# paddle_fe_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + +# else: +# print("WTF!?") + +# paddle_fe.load_dict(paddle_fe_state_dict) + +# # 运行模型 +# y_tc = torch_fe(tc_inp) +# y_pd = paddle_fe(pd_inp) + +# print( +# abs((y_tc.cpu().detach().numpy() +# - +# y_pd.numpy())).max().item(), +# ) + + + +class FeatureProjection(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.norm = paddle.nn.LayerNorm(512) + self.projection = paddle.nn.Linear(512, 768) + self.dropout = paddle.nn.Dropout(0.1) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + +class FeatureProjection_torch(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.LayerNorm(512) + self.projection = torch.nn.Linear(512, 768) + self.dropout = torch.nn.Dropout(0.1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.projection(x) + x = self.dropout(x) + return x + + + +# if __name__ == "__main__": +# # 目标, 将 torch 参数传递给 paddle 模型 + +# torch_fp = FeatureProjection_torch().cuda() +# paddle_fp = FeatureProjection() + +# # 开 eval , 有 dropout +# torch_fp.eval() +# paddle_fp.eval() + + +# np.random.seed(1107) +# inputs = np.random.rand(1, 1795, 512).astype("float32") + +# tc_inp = torch.from_numpy(inputs).cuda() +# pd_inp = paddle.to_tensor(inputs) + + +# paddle_fp_state_dict = paddle_fp.state_dict() +# # 目测参数都一样, 直接转换就可以 +# if set(torch_fp.state_dict().keys()) == set(paddle_fp.state_dict().keys()): + +# for torch_key, torch_value in torch_fp.state_dict().items(): + +# if "projection.weight" == torch_key: +# paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ).T +# else: +# assert paddle_fp_state_dict[torch_key].shape == list(torch_value.shape) +# paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + +# else: +# print("WTF!?") + + +# paddle_fp.load_dict(paddle_fp_state_dict) + +# # 运行模型 +# y_tc = torch_fp(tc_inp) +# y_pd = paddle_fp(pd_inp) + +# print( +# abs((y_tc.cpu().detach().numpy() +# - +# y_pd.numpy())).max().item(), +# ) + + + +class PositionalConvEmbedding(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.conv = paddle.nn.Conv1D( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = paddle.nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + x = self.conv(x.transpose([0, 2, 1])) + x = paddle.nn.functional.gelu(x[:, :, :-1]) + return x.transpose([0, 2, 1]) + + +class PositionalConvEmbedding_torch(nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 768, + 768, + kernel_size=128, + padding=128 // 2, + groups=16, + ) + self.conv = torch.nn.utils.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: 
+ x = self.conv(x.transpose(1, 2)) + x = torch.nn.functional.gelu(x[:, :, :-1]) + return x.transpose(1, 2) + + + +# if __name__ == "__main__": +# # 目标, 将 torch 参数传递给 paddle 模型 + +# torch_fp = PositionalConvEmbedding_torch().cuda() +# paddle_fp = PositionalConvEmbedding() + +# # 开 eval , 有 dropout +# torch_fp.eval() +# paddle_fp.eval() + + +# np.random.seed(1107) +# inputs = np.random.rand(16, 1795, 768).astype("float32") + +# tc_inp = torch.from_numpy(inputs).cuda() +# pd_inp = paddle.to_tensor(inputs) + +# paddle_fp_state_dict = paddle_fp.state_dict() +# # 目测参数都一样, 直接转换就可以 +# if set(torch_fp.state_dict().keys()) == set(paddle_fp.state_dict().keys()): + +# for torch_key, torch_value in torch_fp.state_dict().items(): + +# # assert paddle_fp_state_dict[torch_key].shape == list(torch_value.shape) +# assert paddle_fp_state_dict[torch_key].size == torch_value.numel() + +# origin_shape = paddle_fp_state_dict[torch_key].shape + +# paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ).reshape(origin_shape) + +# else: +# print("WTF!?") + + +# paddle_fp.load_dict(paddle_fp_state_dict) + +# # 运行模型 +# y_tc = torch_fp(tc_inp) +# y_pd = paddle_fp(pd_inp) + +# print( +# "PositionalConvEmbedding", +# abs((y_tc.cpu().detach().numpy() +# - +# y_pd.numpy())).max().item(), +# ) + + + +class TransformerEncoder(paddle.nn.Layer): + def __init__( + self, encoder_layer: paddle.nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder, self).__init__() + self.layers = paddle.nn.LayerList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: paddle.Tensor, + mask: paddle.Tensor = None, + src_key_padding_mask: paddle.Tensor = None, + output_layer: Optional[int] = None, + ) -> paddle.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, + # src_key_padding_mask=src_key_padding_mask # <-------- paddle 没有这个参数 + ) + return output + + +class TransformerEncoder_torch(torch.nn.Module): + def __init__( + self, encoder_layer: torch.nn.TransformerEncoderLayer, num_layers: int + ) -> None: + super(TransformerEncoder_torch, self).__init__() + self.layers = torch.nn.ModuleList( + [copy.deepcopy(encoder_layer) for _ in range(num_layers)] + ) + self.num_layers = num_layers + + def forward( + self, + src: torch.Tensor, + mask: torch.Tensor = None, + src_key_padding_mask: torch.Tensor = None, + output_layer: Optional[int] = None, + ) -> torch.Tensor: + output = src + for layer in self.layers[:output_layer]: + output = layer( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + return output + + +# if __name__ == "__main__": + +# torch_encoder = TransformerEncoder_torch( +# torch.nn.TransformerEncoderLayer( +# 768, 12, 3072, activation="gelu", batch_first=True +# ), +# 12, +# ).cuda() + +# paddle_encoder = TransformerEncoder( +# paddle.nn.TransformerEncoderLayer( +# 768, 12, 3072, activation="gelu" +# ), +# 12, +# ) + +# torch_encoder.eval() +# paddle_encoder.eval() + +# np.random.seed(1107) +# inputs = np.random.rand(2, 1795, 768).astype("float32") # <--- + +# tc_inp = torch.from_numpy(inputs).cuda() +# pd_inp = paddle.to_tensor(inputs) + +# paddle_fp_state_dict = paddle_encoder.state_dict() + +# # torch 参数 +# param_torch = set(torch_encoder.state_dict().keys()) +# param_paddle = set(paddle_encoder.state_dict().keys()) + +# # 所有 torch 有, 但 paddle 没有的参数 +# # 这个主要是 multihead attention 那块儿参数的问题 +# param_torch - param_paddle + +# 
pattern = r"layers\.\d+\.linear\d+\.weight" + +# # 先解决共有参数的问题 +# for torch_key in param_torch & param_paddle: + +# torch_value = torch_encoder.state_dict()[torch_key] +# assert paddle_fp_state_dict[torch_key].size == torch_value.numel() + +# match = re.match(pattern, torch_key) +# if match: +# # 匹配到了, 需要转置 +# # print(torch_key) +# paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy().T ) +# else: +# # 没有匹配到了 +# assert paddle_fp_state_dict[torch_key].shape == list(torch_value.shape) +# paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + +# assert len(param_paddle - param_torch) == len(param_torch - param_paddle) * 3 + +# # 参数加载一遍 +# paddle_encoder.load_dict(paddle_fp_state_dict) + +# # 接下来就是 attention 部分的参数转换 +# for idx in range( len(paddle_encoder.layers) ): + +# pd_model = paddle_encoder.layers[idx].self_attn +# tc_model = torch_encoder.layers[idx].self_attn + +# prefix = f"layers.{idx}.self_attn." + +# print("BEFORE ", pd_model.q_proj.weight.data.mean().item()) +# paddle_encoder.layers[idx].self_attn = MultiheadAttention_torch2paddle(pd_model, tc_model, prefix="") +# print("AFTER ", pd_model.q_proj.weight.data.mean().item()) + +# # param_torch = set(torch_encoder.state_dict().keys()) +# # param_paddle = set(paddle_encoder.state_dict().keys()) + +# # 运行模型 +# y_tc = torch_encoder(tc_inp, output_layer=None) +# y_pd = paddle_encoder(pd_inp, output_layer=None) + +# print( +# "TransformerEncoder", +# abs((y_tc.cpu().detach().numpy() +# - +# y_pd.numpy())).max().item(), +# ) + + +def TransformerEncoder_torch2paddle(torch_encoder, paddle_encoder): + + paddle_fp_state_dict = paddle_encoder.state_dict() + + # torch 参数 + param_torch = set(torch_encoder.state_dict().keys()) + param_paddle = set(paddle_encoder.state_dict().keys()) + + # 所有 torch 有, 但 paddle 没有的参数 + # 这个主要是 multihead attention 那块儿参数的问题 + param_torch - param_paddle + + pattern = r"layers\.\d+\.linear\d+\.weight" + + # 先解决共有参数的问题 + for torch_key in param_torch & param_paddle: + + torch_value = torch_encoder.state_dict()[torch_key] + assert paddle_fp_state_dict[torch_key].size == torch_value.numel() + + match = re.match(pattern, torch_key) + if match: + # 匹配到了, 需要转置 + # print(torch_key) + paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy().T ) + else: + # 没有匹配到了 + assert paddle_fp_state_dict[torch_key].shape == list(torch_value.shape) + paddle_fp_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + + assert len(param_paddle - param_torch) == len(param_torch - param_paddle) * 3 + + # 参数加载一遍 + paddle_encoder.load_dict(paddle_fp_state_dict) + + # 接下来就是 attention 部分的参数转换 + for idx in range( len(paddle_encoder.layers) ): + + pd_model = paddle_encoder.layers[idx].self_attn + tc_model = torch_encoder.layers[idx].self_attn + + prefix = f"layers.{idx}.self_attn." 
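+        # each layer's self_attn submodule is converted through its own
+        # state_dict, whose keys are unprefixed ('q_proj.weight', ...), hence
+        # prefix=""; the BEFORE/AFTER prints are a sanity check that the
+        # weights were actually replaced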
+ + print("BEFORE ", pd_model.q_proj.weight.data.mean().item()) + paddle_encoder.layers[idx].self_attn = MultiheadAttention_torch2paddle(pd_model, tc_model, prefix="") + print("AFTER ", pd_model.q_proj.weight.data.mean().item()) + + # param_torch = set(torch_encoder.state_dict().keys()) + # param_paddle = set(paddle_encoder.state_dict().keys()) + + return paddle_encoder + + + +# # 测试 TransformerEncoder_torch2paddle +# if __name__ == "__main__": + +# np.random.seed(1107) +# inputs = np.random.rand(2, 1795, 768).astype("float32") # <--- + +# tc_inp = torch.from_numpy(inputs).to("cuda:1") +# pd_inp = paddle.to_tensor(inputs) + +# torch_encoder = TransformerEncoder_torch( +# torch.nn.TransformerEncoderLayer( +# 768, 12, 3072, activation="gelu", batch_first=True +# ), +# 12, +# ).to("cuda:1") + +# paddle_encoder = TransformerEncoder( +# paddle.nn.TransformerEncoderLayer( +# 768, 12, 3072, activation="gelu" +# ), +# 12, +# ) + +# torch_encoder.eval() +# paddle_encoder.eval() + +# paddle_encoder = TransformerEncoder_torch2paddle(torch_encoder, paddle_encoder) + + +# # 运行模型 +# y_tc = torch_encoder(tc_inp, output_layer=None) +# y_pd = paddle_encoder(pd_inp, output_layer=None) + +# print( +# "TransformerEncoder", +# abs((y_tc.cpu().detach().numpy() +# - +# y_pd.numpy())).max().item(), +# ) + + + +def _compute_mask_torch( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + device: torch.device, + min_masks: int = 0, +) -> torch.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones( + (batch_size, sequence_length - (mask_length - 1)), device=device + ) + + # get random indices to mask + mask_indices = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + mask_indices = ( + mask_indices.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.scatter(1, mask_idxs, True) + + return mask + +# 推理不用 +def _compute_mask(shape: Tuple[int, int], mask_prob: float, mask_length: + int, min_masks: int=0) ->paddle.Tensor: + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError('`mask_length` has to be bigger than 0.') + + if mask_length > sequence_length: + raise ValueError( + f'`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`' + ) + + # compute number of masked 
spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + + random.random()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + mask = paddle.zeros(shape=(batch_size, sequence_length), dtype='bool') + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = paddle.ones(shape=(batch_size, sequence_length - ( + mask_length - 1))) + + # get random indices to mask + mask_indices = paddle.multinomial(x=uniform_dist, num_samples= + num_masked_spans) + + # expand masked indices to masked spans + mask_indices = mask_indices.unsqueeze(axis=-1).expand(shape=(batch_size, + num_masked_spans, mask_length)).reshape([batch_size, + num_masked_spans * mask_length]) + offsets = paddle.arange(end=mask_length)[None, None, :].expand(( + batch_size, num_masked_spans, mask_length)).reshape([batch_size, + num_masked_spans * mask_length]) + mask_idxs = mask_indices + offsets + + # scatter indices to mask + mask = mask.astype(int).put_along_axis(axis=1, indices=mask_idxs, values=1, broadcast=False).astype(bool) + + return mask + + +# if __name__ == "__main__": + +# torch_param = ((1, 11), 0.8, 10, 'cuda:0', 2) +# paddle_param = ((1, 11), 0.8, 10, 2) + +# torch_out = _compute_mask_torch(*torch_param) +# paddle_out = _compute_mask(*paddle_param) + +# # --------- multinomial 这里是随机的 --------- +# # print( +# # (torch_out.cpu().numpy() == paddle_out.numpy()).all() +# # ) + + + + +class Hubert(paddle.nn.Layer): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor() + self.feature_projection = FeatureProjection() + self.positional_embedding = PositionalConvEmbedding() + self.norm = paddle.nn.LayerNorm(768) + self.dropout = paddle.nn.Dropout(0.1) + self.encoder = TransformerEncoder( + paddle.nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", + # batch_first=True # <-------- paddle 默认 batch first + ), + 12, + ) + self.proj = paddle.nn.Linear(768, 256) + self.masked_spec_embed = paddle.create_parameter( + shape=[768], + dtype='float32', + default_initializer=paddle.nn.initializer.Assign( + paddle.empty(shape=[768], dtype='float32').uniform_())) + + self.label_embedding = paddle.nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: paddle.Tensor, layer: Optional[int] = None + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose([0, 2, 1])) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + # def logits(self, x: paddle.Tensor) -> paddle.Tensor: + # logits = paddle.nn.functional.cosine_similarity( + # x.unsqueeze(2), + # self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + # axis=-1, + # ) + # return logits / 0.1 + + # def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + # x, mask = self.encode(x) + # x = self.proj(x) + # logits = self.logits(x) + # return logits, mask + + 
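+
+# Minimal usage sketch for the paddle Hubert above, kept as a comment. The shapes
+# assume the 16 kHz / hop-320 convention used by prepare/preprocess_hubert.py, and
+# the random input is purely illustrative:
+#
+#   model = Hubert(mask=False)
+#   model.eval()
+#   wav = paddle.rand([1, 1, 16000])      # [batch, 1, samples] of 16 kHz audio
+#   with paddle.no_grad():
+#       x, _ = model.encode(wav)          # roughly [1, samples // 320, 768]
+#       units = model.proj(x)             # [1, frames, 256] soft units
+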
+class Hubert_torch(nn.Module): + def __init__(self, num_label_embeddings: int = 100, mask: bool = True): + super().__init__() + self._mask = mask + self.feature_extractor = FeatureExtractor_torch() + self.feature_projection = FeatureProjection_torch() + self.positional_embedding = PositionalConvEmbedding_torch() + self.norm = torch.nn.LayerNorm(768) + self.dropout = torch.nn.Dropout(0.1) + self.encoder = TransformerEncoder_torch( + torch.nn.TransformerEncoderLayer( + 768, 12, 3072, activation="gelu", batch_first=True + ), + 12, + ) + self.proj = torch.nn.Linear(768, 256) + + self.masked_spec_embed = torch.nn.Parameter(torch.FloatTensor(768).uniform_()) + self.label_embedding = torch.nn.Embedding(num_label_embeddings, 256) + + def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mask = None + if self.training and self._mask: + mask = _compute_mask_torch((x.size(0), x.size(1)), 0.8, 10, x.device, 2) + x[mask] = self.masked_spec_embed.to(x.dtype) + return x, mask + + def encode( + self, x: torch.Tensor, layer: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = self.feature_extractor(x) + x = self.feature_projection(x.transpose(1, 2)) + x, mask = self.mask(x) + x = x + self.positional_embedding(x) + x = self.dropout(self.norm(x)) + x = self.encoder(x, output_layer=layer) + return x, mask + + # def logits(self, x: torch.Tensor) -> torch.Tensor: + # logits = torch.cosine_similarity( + # x.unsqueeze(2), + # self.label_embedding.weight.unsqueeze(0).unsqueeze(0), + # dim=-1, + # ) + # return logits / 0.1 + + # def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # x, mask = self.encode(x) + # x = self.proj(x) + # logits = self.logits(x) + # return logits, mask + + + +class HubertSoft_torch(Hubert_torch): + def __init__(self): + super().__init__() + + @torch.inference_mode() + def units(self, wav: torch.Tensor) -> torch.Tensor: + wav = torch.nn.functional.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) + x, _ = self.encode(wav) + return self.proj(x) + + +class HubertSoft(Hubert): + def __init__(self): + super().__init__() + + @paddle.no_grad() + def units(self, wav: paddle.Tensor) -> paddle.Tensor: + wav = paddle.nn.functional.pad(wav, ((400 - 320) // 2, (400 - 320) // 2), data_format="NCL") + x, _ = self.encode(wav) + return self.proj(x) + + +# # -------------------------------- 测试 HubertSoft -------------------------------- +# if __name__ == "__main__": + +# x = np.random.rand(1, 1, 574400).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# m_tc = HubertSoft_torch().cuda() +# m_pd = HubertSoft() + +# m_tc.eval() # for dropout +# m_pd.eval() + + +# # ------------ 参数集合 ------------ +# m_tc_params = set(m_tc.state_dict().keys()) +# m_pd_params = set(m_pd.state_dict().keys()) + +# pattern = r"encoder.layers\.\d+\.linear\d+\.weight" + +# m_pd_state_dict = m_pd.state_dict() + +# # 先解决共有参数的问题 +# for torch_key in m_tc_params & m_pd_params: + +# # print(torch_key) + +# torch_value = m_tc.state_dict()[torch_key] +# assert m_pd.state_dict()[torch_key].size == torch_value.numel() + +# match = re.match(pattern, torch_key) +# if match or torch_key in ["feature_projection.projection.weight", "proj.weight"]: +# # 匹配到了, 需要转置 +# # print(torch_key) +# m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy().T ) +# else: + +# if torch_key.endswith('.weight_g'): +# m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ).reshape( +# m_pd_state_dict[torch_key].shape +# ) 
+ +# else: +# # 没有匹配到了 +# assert m_pd_state_dict[torch_key].shape == list(torch_value.shape) +# m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + +# # 参数加载一遍 +# m_pd.load_dict(m_pd_state_dict) + +# assert len(m_pd_params - m_tc_params) == len(m_tc_params - m_pd_params) * 3 + +# m_pd.encoder = TransformerEncoder_torch2paddle(m_tc.encoder, m_pd.encoder) + +# y_tc = m_tc.units(x_tc) +# y_pd = m_pd.units(x_pd) + +# print( +# abs( +# y_tc.detach().cpu().numpy() +# - +# y_pd.numpy() +# ).max().item() +# ) + + + +def hubert_soft_torch2paddle(m_tc, m_pd): + + m_tc.eval() # for dropout + m_pd.eval() + + # ------------ 参数集合 ------------ + m_tc_params = set(m_tc.state_dict().keys()) + m_pd_params = set(m_pd.state_dict().keys()) + + pattern = r"encoder.layers\.\d+\.linear\d+\.weight" + + m_pd_state_dict = m_pd.state_dict() + + # 先解决共有参数的问题 + for torch_key in m_tc_params & m_pd_params: + + # print(torch_key) + + torch_value = m_tc.state_dict()[torch_key] + assert m_pd.state_dict()[torch_key].size == torch_value.numel() + + match = re.match(pattern, torch_key) + if match or torch_key in ["feature_projection.projection.weight", "proj.weight"]: + # 匹配到了, 需要转置 + # print(torch_key) + m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy().T ) + else: + + if torch_key.endswith('.weight_g'): + m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ).reshape( + m_pd_state_dict[torch_key].shape + ) + + else: + # 没有匹配到了 + assert m_pd_state_dict[torch_key].shape == list(torch_value.shape) + m_pd_state_dict[torch_key] = paddle.to_tensor( torch_value.detach().cpu().numpy() ) + + # 参数加载一遍 + m_pd.load_dict(m_pd_state_dict) + + assert len(m_pd_params - m_tc_params) == len(m_tc_params - m_pd_params) * 3 + + m_pd.encoder = TransformerEncoder_torch2paddle(m_tc.encoder, m_pd.encoder) + + return m_pd + + +# # -------------------------------- 测试 hubert_soft_torch2paddle -------------------------------- +# if __name__ == "__main__": + +# x = np.random.rand(1, 1, 574400).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# m_tc = HubertSoft_torch().cuda() +# m_pd = HubertSoft() + +# m_tc.eval() # for dropout +# m_pd.eval() + +# m_pd = hubert_soft_torch2paddle(m_tc, m_pd) + +# y_tc = m_tc.units(x_tc) +# y_pd = m_pd.units(x_pd) + +# print( +# "hubert_soft_torch2paddle", +# abs( +# y_tc.detach().cpu().numpy() +# - +# y_pd.numpy() +# ).max().item() +# ) + + + + +def hubert_soft_torch( + path: str, +) -> HubertSoft_torch: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. + Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft_torch() + checkpoint = torch.load(path) + + from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present + + consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_state_dict(checkpoint) + hubert.eval() + return hubert + + + +def hubert_soft( + path: str, +) -> HubertSoft: + r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 
+ Args: + path (str): path of a pretrained model + """ + hubert = HubertSoft() + checkpoint = paddle.load(path) + # consume_prefix_in_state_dict_if_present(checkpoint, "module.") + hubert.load_dict(checkpoint) + hubert.eval() + return hubert + + +# if __name__ == "__main__": + +# model_path = "~/Desktop/whisper-vits-svc/hubert_pretrain/hubert-soft-0d54a1f4.pt" +# model_path = os.path.expanduser(model_path) +# tc_model = hubert_soft_torch(model_path) + +# pd_model = HubertSoft() +# pd_model = hubert_soft_torch2paddle(tc_model, pd_model) + +# paddle.save( +# pd_model.state_dict(), +# "hubert_pretrain/hubert-soft.pdparam" +# ) + +# model_path = "hubert_pretrain/hubert-soft.pdparam" +# hubert_soft(model_path) \ No newline at end of file diff --git a/paddlemix/models/vits-svc/prepare/preprocess_a.py b/paddlemix/models/vits-svc/prepare/preprocess_a.py new file mode 100644 index 000000000..87d03b5ba --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_a.py @@ -0,0 +1,58 @@ +import os +import librosa +import argparse +import numpy as np +from tqdm import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed +from scipy.io import wavfile + + +def resample_wave(wav_in, wav_out, sample_rate): + wav, _ = librosa.load(wav_in, sr=sample_rate) + wav = wav / np.abs(wav).max() * 0.6 + wav = wav / max(0.01, np.max(np.abs(wav))) * 32767 * 0.6 + wavfile.write(wav_out, sample_rate, wav.astype(np.int16)) + + +def process_file(file, wavPath, spks, outPath, sr): + if file.endswith(".wav"): + file = file[:-4] + resample_wave(f"{wavPath}/{spks}/{file}.wav", f"{outPath}/{spks}/{file}.wav", sr) + + +def process_files_with_thread_pool(wavPath, spks, outPath, sr, thread_num=None): + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + + with ThreadPoolExecutor(max_workers=thread_num) as executor: + futures = {executor.submit(process_file, file, wavPath, spks, outPath, sr): file for file in files} + + for future in tqdm(as_completed(futures), total=len(futures), desc=f'Processing {sr} {spks}'): + future.result() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-o", "--out", help="out", dest="out", required=True) + parser.add_argument("-s", "--sr", help="sample rate", dest="sr", type=int, required=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + + args = parser.parse_args() + print(args.wav) + print(args.out) + print(args.sr) + + os.makedirs(args.out, exist_ok=True) + wavPath = args.wav + outPath = args.out + + assert args.sr == 16000 or args.sr == 32000 + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{outPath}/{spks}", exist_ok=True) + if args.thread_count == 0: + process_num = os.cpu_count() // 2 + 1 + else: + process_num = args.thread_count + process_files_with_thread_pool(wavPath, spks, outPath, args.sr, process_num) diff --git a/paddlemix/models/vits-svc/prepare/preprocess_crepe.py b/paddlemix/models/vits-svc/prepare/preprocess_crepe.py new file mode 100644 index 000000000..4a2669607 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_crepe.py @@ -0,0 +1,94 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import numpy as np +import librosa +import paddle +import crepe +import argparse +from tqdm import tqdm + +def 
paddle_randn_like(t): + return paddle.randn(shape=t.shape, dtype=t.dtype) + +def compute_f0(filename, save, device): + + audio, sr = librosa.load(filename, sr=16000) + assert sr == 16000 + # Load audio + audio = paddle.to_tensor(audio)[None] + audio = audio + paddle_randn_like(audio) * 0.001 + + # Here we'll use a 10 millisecond hop length + hop_length = 160 + # Provide a sensible frequency range for your domain (upper limit is 2006 Hz) + # This would be a reasonable range for speech + fmin = 50 + fmax = 1000 + # Select a model capacity--one of "tiny" or "full" + model = "full" + # Pick a batch size that doesn't cause memory errors on your gpu + batch_size = 512 + # Compute pitch using first gpu + pitch, periodicity = crepe.predict( + audio, + sr, + hop_length, + fmin, + fmax, + model, + batch_size=batch_size, + device=device, + return_periodicity=True, + ) + # CREPE was not trained on silent audio. some error on silent need filter.pitPath + periodicity = crepe.filter.median(periodicity, 7) + pitch = crepe.filter.mean(pitch, 5) + pitch[periodicity < 0.5] = 0 + pitch = pitch.squeeze(0) + np.save(save, pitch, allow_pickle=False) + + +if __name__ == "__main__": + + # torch npy 路径 + tc_npy = "~/Desktop/whisper-vits-svc/data_svc/pitch/421_all/000002.pit.npy" + tc_npy = os.path.expanduser(tc_npy) + + # paddle npy 路径 + pd_npy = "~/Desktop/PaddleMIX/paddlemix/models/vits-svc/data_svc/pitch/421_all/000002.pit.npy" + pd_npy = os.path.expanduser(pd_npy) + + tc_arr = np.load(tc_npy) + pd_arr = np.load(pd_npy) + + print( + abs(tc_arr - pd_arr).max(), + abs(tc_arr - pd_arr).mean(), + tc_arr.std() - pd_arr.std(), + tc_arr.mean() - pd_arr.mean(), + ) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", default="data_svc/waves-16k") + parser.add_argument("-p", "--pit", help="pit", dest="pit", default="data_svc/pitch") + + args = parser.parse_args() + print(args.wav) + print(args.pit) + + os.makedirs(args.pit, exist_ok=True) + wavPath = args.wav + pitPath = args.pit + + device = None + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{pitPath}/{spks}", exist_ok=True) + + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + for file in tqdm(files, desc=f'Processing crepe {spks}'): + file = file[:-4] + compute_f0(f"{wavPath}/{spks}/{file}.wav", f"{pitPath}/{spks}/{file}.pit", device) diff --git a/paddlemix/models/vits-svc/prepare/preprocess_hubert.py b/paddlemix/models/vits-svc/prepare/preprocess_hubert.py new file mode 100644 index 000000000..bbf0c1b55 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_hubert.py @@ -0,0 +1,79 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import numpy as np +import argparse +import paddle +import librosa + +from tqdm import tqdm +from hubert import hubert_model + + +def load_audio(file: str, sr: int = 16000): + x, sr = librosa.load(file, sr=sr) + return x + + +def load_model(path, device=None): + model = hubert_model.hubert_soft(path) + model.eval() + # model.half() + # model.to(device) + return model + + +def pred_vec(model, wavPath, vecPath): + feats = load_audio(wavPath) + feats = paddle.to_tensor(feats) + # feats = feats[None, None, :].half() + feats = feats[None, None, :] + with paddle.no_grad(): + with paddle.amp.auto_cast(): + vec = model.units(feats).squeeze().numpy() + # print(vec.shape) # [length, dim=256] hop=320 + np.save(vecPath, vec, 
allow_pickle=False) + + +if __name__ == "__main__": + + # # torch npy 路径 + # tc_npy = "~/Desktop/whisper-vits-svc/data_svc/hubert/421_all/000002.vec.npy" + # tc_npy = os.path.expanduser(tc_npy) + + # # paddle npy 路径 + # pd_npy = "~/Desktop/PaddleMIX/paddlemix/models/vits-svc/data_svc/hubert/421_all/000002.vec.npy" + # pd_npy = os.path.expanduser(pd_npy) + + # tc_arr = np.load(tc_npy) + # pd_arr = np.load(pd_npy) + + # print( + # abs(tc_arr - pd_arr).max(), + # abs(tc_arr - pd_arr).mean(), + # tc_arr.std() - pd_arr.std(), + # tc_arr.mean() - pd_arr.mean(), + # ) + + + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", default="data_svc/waves-16k/") + parser.add_argument("-v", "--vec", help="vec", dest="vec", default="data_svc/hubert") + + args = parser.parse_args() + print(args.wav) + print(args.vec) + os.makedirs(args.vec, exist_ok=True) + + wavPath = args.wav + vecPath = args.vec + + hubert = load_model( os.path.join("hubert_pretrain", "hubert-soft.pdparam") ) + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{vecPath}/{spks}", exist_ok=True) + + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + for file in tqdm(files, desc=f'Processing vec {spks}'): + file = file[:-4] + pred_vec(hubert, f"{wavPath}/{spks}/{file}.wav", f"{vecPath}/{spks}/{file}.vec") diff --git a/paddlemix/models/vits-svc/prepare/preprocess_ppg.py b/paddlemix/models/vits-svc/prepare/preprocess_ppg.py new file mode 100644 index 000000000..e58f853e9 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_ppg.py @@ -0,0 +1,183 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import numpy as np +import argparse +import paddle +import torch +import random +from tqdm import tqdm +from whisper.model import Whisper, ModelDimensions, Whisper_torch, AudioEncoder_torch2paddle +from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram_torch, log_mel_spectrogram + + +checkpoint_dims = { + 'n_mels': 80, + 'n_vocab': 51865, + 'n_audio_ctx': 1500, + 'n_audio_state': 1280, + 'n_audio_head': 20, + 'n_audio_layer': 32, + 'n_text_ctx': 448, + 'n_text_state': 1280, + 'n_text_head': 20, + 'n_text_layer': 32, +} + + +def load_model_torch(path) -> Whisper_torch: + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(path, map_location="cpu") + dims = ModelDimensions(**checkpoint_dims) + print(dims) + model = Whisper_torch(dims) + # del model.decoder + cut = len(model.encoder.blocks) // 4 + cut = -1 * cut + del model.encoder.blocks[cut:] + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + model.eval() + # model.half() + model.to(device) + return model + +def load_model(path) -> Whisper: + # device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = paddle.load(path) + + # dims = ModelDimensions(**checkpoint["dims"]) + dims = ModelDimensions(**checkpoint_dims) + + print(dims) + model = Whisper(dims) + # del model.decoder + cut = len(model.encoder.blocks) // 4 + cut = -1 * cut + del model.encoder.blocks[cut:] + model.set_state_dict(checkpoint) + model.eval() + # model.half() + # model.to(device) + return model + + +# -------------------------------------------------------------- + +def pred_ppg(whisper: Whisper, wavPath, ppgPath): + # wavPath = "~/Desktop/whisper-vits-svc/data_svc/waves-16k/435_all/000000.wav" + wavPath = os.path.expanduser(wavPath) + audio = load_audio(wavPath) + audln = 
audio.shape[0] + ppgln = audln // 320 + audio = pad_or_trim(audio) + mel = log_mel_spectrogram(audio) + # return mel.cpu().numpy() + with paddle.no_grad(): + ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().numpy() + ppg = ppg[:ppgln,] # [length, dim=1280] + # print(ppg.shape) + np.save(ppgPath, ppg, allow_pickle=False) + + return ppg + + +# def pred_ppg_torch(whisper: Whisper_torch, wavPath, ppgPath): +# audio = load_audio(wavPath) +# audln = audio.shape[0] +# ppgln = audln // 320 +# audio = pad_or_trim(audio) +# mel = log_mel_spectrogram_torch(audio).cuda() +# # return mel.cpu().numpy() +# with torch.no_grad(): +# ppg = whisper.encoder(mel.unsqueeze(0)).squeeze().data.cpu().float().numpy() +# ppg = ppg[:ppgln,] # [length, dim=1280] +# # print(ppg.shape) +# # np.save(ppgPath, ppg, allow_pickle=False) + +# return ppg + +# -------------------------------------------------------------- + +if __name__ == "__main__": + + # path_tc = "~/Desktop/whisper-vits-svc/data_svc/whisper/434_all/000002.ppg.npy" + # path_tc = os.path.expanduser(path_tc) + # y_tc = np.load(path_tc) + + # path_pd = "~/Desktop/PaddleMIX/paddlemix/models/vits-svc/data_svc/whisper/434_all/000002.ppg.npy" + # path_pd = os.path.expanduser(path_pd) + # y_pd = np.load(path_pd) + + # print( + # "MAX:", abs(y_tc - y_pd).max(), + # "MEAN", abs(y_tc - y_pd).mean(), + # ) + + # ---------------------------------------------------------- + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", default="data_svc/waves-16k") + parser.add_argument("-p", "--ppg", help="ppg", dest="ppg", default="data_svc/whisper") + args = parser.parse_args() + print(args.wav) + print(args.ppg) + + os.makedirs(args.ppg, exist_ok=True) + wavPath = args.wav + ppgPath = args.ppg + + whisper = load_model(os.path.join("whisper_pretrain", "large-v2.pdparam")) + + # ------------ torch 旧模型转化 ------------ + # whisper_torch = load_model_torch(os.path.join("whisper_pretrain", "large-v2.pt")) + # AudioEncoder_torch2paddle(whisper_torch.encoder, whisper.encoder) + + # paddle.save( + # whisper.state_dict(), "whisper_pretrain/large-v2.pdparam" + # ) + + # ------------ 测试 paddle 和 torch 模型 ------------ + # import numpy as np + # x = np.random.rand(1, 80, 520).astype("float32") + # x_tc = torch.from_numpy(x).cuda() + # x_pd = paddle.to_tensor(x) + + # --------------------------- + # dims = ModelDimensions(**checkpoint_dims) + + # whisper_torch = Whisper_torch(dims).cuda() + # whisper = Whisper(dims) + # --------------------------- + + # 转化模型参数 + # AudioEncoder_torch2paddle(whisper_torch.encoder, whisper.encoder) + + # y_tc = whisper_torch.encoder( x_tc ).detach().cpu().numpy() + # y_pd = whisper.encoder( x_pd ).detach().cpu().numpy() + + # print( + # abs(y_tc - y_pd).max() + # ) + # ---------------------------------------------------- + + spkPaths = os.listdir(wavPath) + # random.shuffle(spkPaths) # why shuffle + spkPaths = sorted(spkPaths) + + for spks in spkPaths: + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{ppgPath}/{spks}", exist_ok=True) + + files = [f for f in os.listdir(f"./{wavPath}/{spks}") if f.endswith(".wav")] + for file in tqdm(files, desc=f'Processing ppg {spks}'): + if file.endswith(".wav"): + # print(file) + file = file[:-4] + path_wav = f"{wavPath}/{spks}/{file}.wav" + path_ppg = f"{ppgPath}/{spks}/{file}.ppg" + # if os.path.isfile(f"{path_ppg}.npy"): + # continue + # y_tc = pred_ppg_torch(whisper_torch, path_wav, path_ppg) + y_pd = pred_ppg(whisper, path_wav, path_ppg) + + # print( + # 
abs(y_tc - y_pd).max() + # ) diff --git a/paddlemix/models/vits-svc/prepare/preprocess_speaker.py b/paddlemix/models/vits-svc/prepare/preprocess_speaker.py new file mode 100644 index 000000000..9167ab486 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_speaker.py @@ -0,0 +1,120 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import paddle +import numpy as np +import argparse + +from tqdm import tqdm +from functools import partial +from argparse import RawTextHelpFormatter +from multiprocessing.pool import ThreadPool + +from speaker.models.lstm import LSTMSpeakerEncoder +from speaker.config import SpeakerEncoderConfig +from speaker.utils.audio import AudioProcessor +from speaker.infer import read_json + + +def get_spk_wavs(dataset_path, output_path): + wav_files = [] + os.makedirs(f"./{output_path}", exist_ok=True) + for spks in os.listdir(dataset_path): + if os.path.isdir(f"./{dataset_path}/{spks}"): + os.makedirs(f"./{output_path}/{spks}", exist_ok=True) + for file in os.listdir(f"./{dataset_path}/{spks}"): + if file.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}/{file}") + elif spks.endswith(".wav"): + wav_files.append(f"./{dataset_path}/{spks}") + return wav_files + + +def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder): + waveform = speaker_encoder_ap.load_wav( + wav_file, sr=speaker_encoder_ap.sample_rate + ) + spec = speaker_encoder_ap.melspectrogram(waveform) + # spec = torch.from_numpy(spec.T) + # if args.use_cuda: + # spec = spec.cuda() + spec = paddle.to_tensor(spec.T) + spec = spec.unsqueeze(0) + embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() + embed = embed.squeeze() + embed_path = wav_file.replace(dataset_path, output_path) + embed_path = embed_path.replace(".wav", ".spk") + np.save(embed_path, embed, allow_pickle=False) + + +def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency): + bound_process_wav = partial(process_wav, dataset_path=dataset_path, output_path=output_path, args=args, speaker_encoder_ap=speaker_encoder_ap, speaker_encoder=speaker_encoder) + + with ThreadPool(concurrency) as pool: + list(tqdm(pool.imap(bound_process_wav, wav_files), total=len(wav_files))) + + +if __name__ == "__main__": + + pd_file = "~/Desktop/PaddleMIX/paddlemix/models/vits-svc/data_svc/speaker/421_all/000003.spk.npy" + pd_file = os.path.expanduser(pd_file) + + tc_file = "~/Desktop/whisper-vits-svc/data_svc/speaker/421_all/000003.spk.npy" + tc_file = os.path.expanduser(pd_file) + + pd_npy = np.load(pd_file) + tc_npy = np.load(tc_file) + + print( + abs(pd_npy - tc_npy).max() + ) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="""Compute embedding vectors for each wav file in a dataset.""", + formatter_class=RawTextHelpFormatter, + ) + parser.add_argument("dataset_path", type=str, help="Path to dataset waves.") + parser.add_argument( + "output_path", type=str, help="path for output speaker/speaker_wavs.npy." 
+ ) + parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + args = parser.parse_args() + dataset_path = args.dataset_path + output_path = args.output_path + thread_count = args.thread_count + # model + args.model_path = os.path.join("speaker_pretrain", "best_model.pdparam") + args.config_path = os.path.join("speaker_pretrain", "config.json") + # config + config_dict = read_json(args.config_path) + + # model + config = SpeakerEncoderConfig(config_dict) + config.from_dict(config_dict) + + speaker_encoder = LSTMSpeakerEncoder( + config.model_params["input_dim"], + config.model_params["proj_dim"], + config.model_params["lstm_dim"], + config.model_params["num_lstm_layers"], + ) + + speaker_encoder.load_checkpoint(args.model_path, eval=True) + + # preprocess + speaker_encoder_ap = AudioProcessor(**config.audio) + # normalize the input audio level and trim silences + speaker_encoder_ap.do_sound_norm = True + speaker_encoder_ap.do_trim_silence = True + + wav_files = get_spk_wavs(dataset_path, output_path) + + if thread_count == 0: + process_num = os.cpu_count() + else: + process_num = thread_count + + extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num) \ No newline at end of file diff --git a/paddlemix/models/vits-svc/prepare/preprocess_speaker_ave.py b/paddlemix/models/vits-svc/prepare/preprocess_speaker_ave.py new file mode 100644 index 000000000..de464c3d4 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_speaker_ave.py @@ -0,0 +1,70 @@ +import os +import paddle +import argparse +import numpy as np +from tqdm import tqdm + +if __name__ == '__main__': + + pd_file = "~/Desktop/PaddleMIX/paddlemix/models/vits-svc/data_svc/singer/422_all.spk.npy" + pd_file = os.path.expanduser(pd_file) + + tc_file = "~/Desktop/whisper-vits-svc/data_svc/singer/422_all.spk.npy" + tc_file = os.path.expanduser(pd_file) + + pd_npy = np.load(pd_file) + tc_npy = np.load(tc_file) + + print( + abs(pd_npy - tc_npy).max() + ) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--dataset_speaker", type=str, default="data_svc/speaker/") + parser.add_argument("--dataset_singer", type=str, default="data_svc/singer") + + data_speaker = parser.parse_args().dataset_speaker + data_singer = parser.parse_args().dataset_singer + + os.makedirs(data_singer, exist_ok=True) + + for speaker in os.listdir(data_speaker): + subfile_num = 0 + speaker_ave = 0 + + for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"average {speaker}"): + if not file.endswith(".npy"): + continue + source_embed = np.load(os.path.join(data_speaker, speaker, file)) + source_embed = source_embed.astype(np.float32) + speaker_ave = speaker_ave + source_embed + subfile_num = subfile_num + 1 + if subfile_num == 0: + continue + speaker_ave = speaker_ave / subfile_num + + np.save(os.path.join(data_singer, f"{speaker}.spk.npy"), + speaker_ave, allow_pickle=False) + + # rewrite timbre code by average, if similarity is larger than cmp_val + rewrite_timbre_code = False + if not rewrite_timbre_code: + continue + cmp_src = paddle.to_tensor(speaker_ave, dtype="float32") + cmp_num = 0 + cmp_val = 0.85 + for file in tqdm(os.listdir(os.path.join(data_speaker, speaker)), desc=f"rewrite {speaker}"): + if not file.endswith(".npy"): + continue + cmp_tmp = 
np.load(os.path.join(data_speaker, speaker, file)) + cmp_tmp = cmp_tmp.astype(np.float32) + cmp_tmp = paddle.to_tensor(cmp_tmp, dtype="float32") + cmp_cos = paddle.nn.functional.cosine_similarity(cmp_src, cmp_tmp, axis=0) + if (cmp_cos > cmp_val): + cmp_num += 1 + np.save(os.path.join(data_speaker, speaker, file), + speaker_ave, allow_pickle=False) + print(f"rewrite timbre for {speaker} with :", cmp_num) diff --git a/paddlemix/models/vits-svc/prepare/preprocess_spec.py b/paddlemix/models/vits-svc/prepare/preprocess_spec.py new file mode 100644 index 000000000..51f8ec38a --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_spec.py @@ -0,0 +1,82 @@ +import sys,os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import torch +import paddle +import argparse +import multiprocessing +from concurrent.futures import ThreadPoolExecutor +from tqdm import tqdm +from vits import spectrogram +from vits import utils +from omegaconf import OmegaConf + + +def compute_spec_paddle(hps, filename, specname): + audio, sampling_rate = utils.load_wav_to_paddle(filename) + assert sampling_rate == hps.sampling_rate, f"{sampling_rate} is not {hps.sampling_rate}" + audio_norm = audio / hps.max_wav_value + audio_norm = audio_norm.unsqueeze(0) + n_fft = hps.filter_length + sampling_rate = hps.sampling_rate + hop_size = hps.hop_length + win_size = hps.win_length + spec = spectrogram.spectrogram_paddle( + audio_norm, n_fft, sampling_rate, hop_size, win_size, center=False) + spec = paddle.squeeze(spec, 0) + + # print(f"paddle {spec.mean().item()} {spec.std().item()}") + paddle.save(spec, specname) + +# def compute_spec_torch(hps, filename, specname): +# audio, sampling_rate = utils.load_wav_to_torch(filename) +# assert sampling_rate == hps.sampling_rate, f"{sampling_rate} is not {hps.sampling_rate}" +# audio_norm = audio / hps.max_wav_value +# audio_norm = audio_norm.unsqueeze(0) +# n_fft = hps.filter_length +# sampling_rate = hps.sampling_rate +# hop_size = hps.hop_length +# win_size = hps.win_length +# spec = spectrogram.spectrogram_torch( +# audio_norm, n_fft, sampling_rate, hop_size, win_size, center=False) +# spec = torch.squeeze(spec, 0) + +# # print(f"torch {spec.mean().item()} {spec.std().item()}") +# torch.save(spec, specname) + +def process_file(file): + if file.endswith(".wav"): + file = file[:-4] + compute_spec_paddle(hps.data, f"{wavPath}/{spks}/{file}.wav", f"{spePath}/{spks}/{file}.pt") + # compute_spec_torch(hps.data, f"{wavPath}/{spks}/{file}.wav", f"{spePath}/{spks}/{file}.pt") + + + +def process_files_with_thread_pool(wavPath, spks, thread_num): + files = os.listdir(f"./{wavPath}/{spks}") + with ThreadPoolExecutor(max_workers=thread_num) as executor: + list(tqdm(executor.map(process_file, files), total=len(files), desc=f'Processing spec {spks}')) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-w", "--wav", help="wav", dest="wav", required=True) + parser.add_argument("-s", "--spe", help="spe", dest="spe", required=True) + parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1) + + args = parser.parse_args() + print(args.wav) + print(args.spe) + + os.makedirs(args.spe, exist_ok=True) + wavPath = args.wav + spePath = args.spe + hps = OmegaConf.load("./configs/base.yaml") + + for spks in os.listdir(wavPath): + if os.path.isdir(f"./{wavPath}/{spks}"): + os.makedirs(f"./{spePath}/{spks}", exist_ok=True) + if args.thread_count == 0: + 
process_num = os.cpu_count() // 2 + 1 + else: + process_num = args.thread_count + process_files_with_thread_pool(wavPath, spks, process_num) diff --git a/paddlemix/models/vits-svc/prepare/preprocess_train.py b/paddlemix/models/vits-svc/prepare/preprocess_train.py new file mode 100644 index 000000000..985738ec3 --- /dev/null +++ b/paddlemix/models/vits-svc/prepare/preprocess_train.py @@ -0,0 +1,68 @@ +import os +import random + + +def print_error(info): + print(f"\033[31m File isn't existed: {info}\033[0m") + + +IndexBySinger = False +if __name__ == "__main__": + os.makedirs("./files/", exist_ok=True) + + rootPath = "./data_svc/waves-32k/" + all_items = [] + for spks in os.listdir(f"./{rootPath}"): + if not os.path.isdir(f"./{rootPath}/{spks}"): + continue + print(f"./{rootPath}/{spks}") + for file in os.listdir(f"./{rootPath}/{spks}"): + if file.endswith(".wav"): + file = file[:-4] + + if (IndexBySinger == False): + path_spk = f"./data_svc/speaker/{spks}/{file}.spk.npy" + else: + path_spk = f"./data_svc/singer/{spks}.spk.npy" + + path_wave = f"./data_svc/waves-32k/{spks}/{file}.wav" + path_spec = f"./data_svc/specs/{spks}/{file}.pt" + path_pitch = f"./data_svc/pitch/{spks}/{file}.pit.npy" + path_hubert = f"./data_svc/hubert/{spks}/{file}.vec.npy" + path_whisper = f"./data_svc/whisper/{spks}/{file}.ppg.npy" + has_error = 0 + if not os.path.isfile(path_spk): + print_error(path_spk) + has_error = 1 + if not os.path.isfile(path_wave): + print_error(path_wave) + has_error = 1 + if not os.path.isfile(path_spec): + print_error(path_spec) + has_error = 1 + if not os.path.isfile(path_pitch): + print_error(path_pitch) + has_error = 1 + if not os.path.isfile(path_hubert): + print_error(path_hubert) + has_error = 1 + if not os.path.isfile(path_whisper): + print_error(path_whisper) + has_error = 1 + if has_error == 0: + all_items.append( + f"{path_wave}|{path_spec}|{path_pitch}|{path_hubert}|{path_whisper}|{path_spk}") + + random.shuffle(all_items) + valids = all_items[:10] + valids.sort() + trains = all_items[10:] + # trains.sort() + fw = open("./files/valid.txt", "w", encoding="utf-8") + for strs in valids: + print(strs, file=fw) + fw.close() + fw = open("./files/train.txt", "w", encoding="utf-8") + for strs in trains: + print(strs, file=fw) + fw.close() diff --git a/paddlemix/models/vits-svc/speaker/README.md b/paddlemix/models/vits-svc/speaker/README.md new file mode 100644 index 000000000..b6f541f88 --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/README.md @@ -0,0 +1,18 @@ +### Speaker Encoder + +This is an implementation of https://arxiv.org/abs/1710.10467. This model can be used for voice and speaker embedding. + +With the code here you can generate d-vectors for both multi-speaker and single-speaker TTS datasets, then visualise and explore them along with the associated audio files in an interactive chart. + +Below is an example showing embedding results of various speakers. You can generate the same plot with the provided notebook as demonstrated in [this video](https://youtu.be/KW3oO7JVa7Q). + +![](umap.png) + +Download a pretrained model from [Released Models](https://github.com/mozilla/TTS/wiki/Released-Models) page. + +To run the code, you need to follow the same flow as in TTS. + +- Define 'config.json' for your needs. Note that, audio parameters should match your TTS model. 
+- Example training call ```python speaker_encoder/train.py --config_path speaker_encoder/config.json --data_path ~/Data/Libri-TTS/train-clean-360``` +- Generate embedding vectors ```python speaker_encoder/compute_embeddings.py --use_cuda true /model/path/best_model.pth.tar model/config/path/config.json dataset/path/ output_path``` . This code parses all .wav files at the given dataset path and generates the same folder structure under the output path with the generated embedding files. +- Watch training on Tensorboard as in TTS diff --git a/paddlemix/models/vits-svc/speaker/__init__.py b/paddlemix/models/vits-svc/speaker/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/speaker/config.py b/paddlemix/models/vits-svc/speaker/config.py new file mode 100644 index 000000000..0ef27b28e --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/config.py @@ -0,0 +1,64 @@ +from dataclasses import asdict, dataclass, field +from typing import Dict, List + +from .utils.coqpit import MISSING +from .utils.shared_configs import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig +# from .utils.shared_configs import BaseTrainingConfig + +@dataclass +class SpeakerEncoderConfig(BaseTrainingConfig): + """Defines parameters for Speaker Encoder model.""" + + model: str = "speaker_encoder" + audio: BaseAudioConfig = field(default_factory=BaseAudioConfig) + datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()]) + # model params + model_params: Dict = field( + default_factory=lambda: { + "model_name": "lstm", + "input_dim": 80, + "proj_dim": 256, + "lstm_dim": 768, + "num_lstm_layers": 3, + "use_lstm_with_projection": True, + } + ) + + audio_augmentation: Dict = field(default_factory=lambda: {}) + + storage: Dict = field( + default_factory=lambda: { + "sample_from_storage_p": 0.66, # the probability with which we'll sample from the DataSet in-memory storage + "storage_size": 15, # the size of the in-memory storage with respect to a single batch + } + ) + + # training params + max_train_step: int = 1000000 # end training when number of training steps reaches this value. + loss: str = "angleproto" + grad_clip: float = 3.0 + lr: float = 0.0001 + lr_decay: bool = False + warmup_steps: int = 4000 + wd: float = 1e-6 + + # logging params + tb_model_param_stats: bool = False + steps_plot_stats: int = 10 + checkpoint: bool = True + save_step: int = 1000 + print_step: int = 20 + + # data loader + num_speakers_in_batch: int = MISSING + num_utters_per_speaker: int = MISSING + num_loader_workers: int = MISSING + skip_speakers: bool = False + voice_len: float = 1.6 + + def check_values(self): + super().check_values() + c = asdict(self) + assert ( + c["model_params"]["input_dim"] == self.audio.num_mels + ), " [!] model input dimendion must be equal to melspectrogram dimension." 
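+
+# Minimal usage sketch, kept as a comment; it mirrors how this config is consumed
+# in prepare/preprocess_speaker.py, where `config_dict` is loaded with
+# speaker.infer.read_json("speaker_pretrain/config.json"):
+#
+#   config = SpeakerEncoderConfig(config_dict)
+#   config.from_dict(config_dict)
+#   encoder = LSTMSpeakerEncoder(
+#       config.model_params["input_dim"],        # 80 mel channels
+#       config.model_params["proj_dim"],         # 256-d speaker embedding
+#       config.model_params["lstm_dim"],         # 768 LSTM units
+#       config.model_params["num_lstm_layers"],  # 3 stacked LSTM+projection layers
+#   )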
diff --git a/paddlemix/models/vits-svc/speaker/infer.py b/paddlemix/models/vits-svc/speaker/infer.py new file mode 100644 index 000000000..77a049577 --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/infer.py @@ -0,0 +1,108 @@ +import re +import json +import fsspec +# # import torch +# import numpy as np +# import argparse + +# from argparse import RawTextHelpFormatter +# from .models.lstm import LSTMSpeakerEncoder +# from .config import SpeakerEncoderConfig +# from .utils.audio import AudioProcessor + + +def read_json(json_path): + config_dict = {} + try: + with fsspec.open(json_path, "r", encoding="utf-8") as f: + data = json.load(f) + except json.decoder.JSONDecodeError: + # backwards compat. + data = read_json_with_comments(json_path) + config_dict.update(data) + return config_dict + + +def read_json_with_comments(json_path): + """for backward compat.""" + # fallback to json + with fsspec.open(json_path, "r", encoding="utf-8") as f: + input_str = f.read() + # handle comments + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) + data = json.loads(input_str) + return data + + +# if __name__ == "__main__": + +# parser = argparse.ArgumentParser( +# description="""Compute embedding vectors for each wav file in a dataset.""", +# formatter_class=RawTextHelpFormatter, +# ) +# parser.add_argument("model_path", type=str, help="Path to model checkpoint file.") +# parser.add_argument( +# "config_path", +# type=str, +# help="Path to model config file.", +# ) + +# parser.add_argument("-s", "--source", help="input wave", dest="source") +# parser.add_argument( +# "-t", "--target", help="output 256d speaker embeddimg", dest="target" +# ) + +# parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True) +# parser.add_argument("--eval", type=bool, help="compute eval.", default=True) + +# args = parser.parse_args() +# source_file = args.source +# target_file = args.target + +# # config +# config_dict = read_json(args.config_path) +# # print(config_dict) + +# # model +# config = SpeakerEncoderConfig(config_dict) +# config.from_dict(config_dict) + +# speaker_encoder = LSTMSpeakerEncoder( +# config.model_params["input_dim"], +# config.model_params["proj_dim"], +# config.model_params["lstm_dim"], +# config.model_params["num_lstm_layers"], +# ) + +# speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda) + +# # preprocess +# speaker_encoder_ap = AudioProcessor(**config.audio) +# # normalize the input audio level and trim silences +# speaker_encoder_ap.do_sound_norm = True +# speaker_encoder_ap.do_trim_silence = True + +# # compute speaker embeddings + +# # extract the embedding +# waveform = speaker_encoder_ap.load_wav( +# source_file, sr=speaker_encoder_ap.sample_rate +# ) +# spec = speaker_encoder_ap.melspectrogram(waveform) +# spec = torch.from_numpy(spec.T) +# if args.use_cuda: +# spec = spec.cuda() +# spec = spec.unsqueeze(0) +# embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy() +# embed = embed.squeeze() +# # print(embed) +# # print(embed.size) +# np.save(target_file, embed, allow_pickle=False) + + +# if hasattr(speaker_encoder, 'module'): +# state_dict = speaker_encoder.module.state_dict() +# else: +# state_dict = speaker_encoder.state_dict() +# torch.save({'model': state_dict}, "model_small.pth") diff --git a/paddlemix/models/vits-svc/speaker/models/__init__.py b/paddlemix/models/vits-svc/speaker/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/paddlemix/models/vits-svc/speaker/models/lstm.py b/paddlemix/models/vits-svc/speaker/models/lstm.py new file mode 100644 index 000000000..34607f67e --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/models/lstm.py @@ -0,0 +1,443 @@ +import numpy as np +import torch +import paddle + +import sys; sys.path.append("~/Desktop/PaddleMIX/paddlemix/models/vits-svc") +from speaker.utils.io import load_fsspec_torch + + + +class LSTMWithProjection_torch(torch.nn.Module): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True) + self.linear = torch.nn.Linear(hidden_size, proj_size, bias=False) + + def forward(self, x): + # self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + + +class LSTMWithProjection(paddle.nn.Layer): + def __init__(self, input_size, hidden_size, proj_size): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.proj_size = proj_size + self.lstm = paddle.nn.LSTM(input_size, hidden_size) # batch_first=True + self.linear = paddle.nn.Linear(hidden_size, proj_size, bias_attr=False) + + def forward(self, x): + # self.lstm.flatten_parameters() + o, (_, _) = self.lstm(x) + return self.linear(o) + + + +def LSTMWithProjection_torch2paddle(lstm_paddle, lstm_torch): + + + # pd_model_state_dict = lstm_paddle.state_dict() + pd_model_state_dict = {} + tc_model_state_dict = lstm_torch.state_dict() + + # print( + # pd_model_state_dict['lstm.weight_ih_l0'] is pd_model_state_dict['lstm.0.cell.weight_ih'], + # pd_model_state_dict['lstm.weight_hh_l0'] is pd_model_state_dict['lstm.0.cell.weight_hh'], + # pd_model_state_dict['lstm.bias_ih_l0'] is pd_model_state_dict['lstm.0.cell.bias_ih'], + # pd_model_state_dict['lstm.bias_hh_l0'] is pd_model_state_dict['lstm.0.cell.bias_hh'] + # ) + + pd_model_state_dict['lstm.weight_ih_l0'] = paddle.to_tensor( + tc_model_state_dict['lstm.weight_ih_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.weight_hh_l0'] = paddle.to_tensor( + tc_model_state_dict['lstm.weight_hh_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.bias_ih_l0'] = paddle.to_tensor( + tc_model_state_dict['lstm.bias_ih_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.bias_hh_l0'] = paddle.to_tensor( + tc_model_state_dict['lstm.bias_hh_l0'].detach().cpu().numpy() + ) + + # # ------------------------------------------- + pd_model_state_dict['lstm.0.cell.weight_ih'] = paddle.to_tensor( + tc_model_state_dict['lstm.weight_ih_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.0.cell.weight_hh'] = paddle.to_tensor( + tc_model_state_dict['lstm.weight_hh_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.0.cell.bias_ih'] = paddle.to_tensor( + tc_model_state_dict['lstm.bias_ih_l0'].detach().cpu().numpy() + ) + pd_model_state_dict['lstm.0.cell.bias_hh'] = paddle.to_tensor( + tc_model_state_dict['lstm.bias_hh_l0'].detach().cpu().numpy() + ) + + lstm_paddle.load_dict(pd_model_state_dict) + + lstm_paddle.linear.weight.set_value( + paddle.to_tensor( lstm_torch.linear.weight.data.cpu().numpy().T ) + ) + + return lstm_paddle + + +# if __name__ == "__main__": + +# # ---------- 测试结果 ---------- +# input_size, hidden_size, proj_size = 80, 768, 256 + +# # lstm 模型 +# lstm_paddle = LSTMWithProjection(input_size, hidden_size, proj_size) +# lstm_torch = LSTMWithProjection_torch(input_size, hidden_size, proj_size).cuda() + 
+# # lstm 参数传递 +# lstm_paddle = LSTMWithProjection_torch2paddle(lstm_paddle, lstm_torch) + +# # 输入参数 +# x = np.random.rand(10, 250, 80).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# lstm_paddle.lstm.could_use_cudnn = False + +# y_pd, (_, _) = lstm_paddle.lstm(x_pd) +# y_tc, (_, _) = lstm_torch.lstm(x_tc) + +# y_pd = y_pd.numpy() +# y_tc = y_tc.detach().cpu().numpy() + +# print( +# abs( +# y_pd - y_tc +# ).max() +# ) + +# y_pd = lstm_paddle(x_pd) +# y_tc = lstm_torch(x_tc) + +# y_pd = y_pd.numpy() +# y_tc = y_tc.detach().cpu().numpy() + +# print( +# abs( +# y_pd - y_tc +# ).max(), +# f"mean: {y_pd.mean() - y_tc.mean()}", +# f"std : {y_pd.std() - y_tc.std()}", +# ) + + + + +# class LSTMWithoutProjection(torch.nn.Module): +# def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): +# super().__init__() +# self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) +# self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) +# self.relu = nn.ReLU() + +# def forward(self, x): +# _, (hidden, _) = self.lstm(x) +# return self.relu(self.linear(hidden[-1])) + + +class LSTMSpeakerEncoder_torch(torch.nn.Module): + def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection_torch(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection_torch(proj_dim, lstm_dim, proj_dim)) + self.layers = torch.nn.Sequential(*layers) + else: + # self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + pass + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if "bias" in name: + torch.nn.init.constant_(param, 0.0) + elif "weight" in name: + torch.nn.init.xavier_normal_(param) + + # def forward(self, x): + # # TODO: implement state passing for lstms + # d = self.layers(x) + # if self.use_lstm_with_projection: + # d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + # else: + # d = torch.nn.functional.normalize(d, p=2, dim=1) + # return d + + @torch.no_grad() + def inference(self, x): # + print("torch", x.mean().item(), x.std().item()) + d = self.layers.forward(x) + if self.use_lstm_with_projection: + d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + else: + d = torch.nn.functional.normalize(d, p=2, dim=1) + return d + + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): # + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.inference(frames_batch) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + # def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): + # """ + # Generate embeddings for a batch of utterances + # x: BxTxD + # """ + # num_overlap = num_frames * overlap + # max_len = x.shape[1] + # embed = None + # num_iters = seq_lens / (num_frames - 
num_overlap) + # cur_iter = 0 + # for offset in range(0, max_len, num_frames - num_overlap): + # cur_iter += 1 + # end_offset = min(x.shape[1], offset + num_frames) + # frames = x[:, offset:end_offset] + # if embed is None: + # embed = self.inference(frames) + # else: + # embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) + # return embed / num_iters + + # pylint: disable=unused-argument, redefined-builtin + def load_checkpoint(self, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec_torch(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training + + +class LSTMSpeakerEncoder(paddle.nn.Layer): + def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): + super().__init__() + self.use_lstm_with_projection = use_lstm_with_projection + layers = [] + # choise LSTM layer + if use_lstm_with_projection: + layers.append(LSTMWithProjection(input_dim, lstm_dim, proj_dim)) + for _ in range(num_lstm_layers - 1): + layers.append(LSTMWithProjection(proj_dim, lstm_dim, proj_dim)) + self.layers = paddle.nn.Sequential(*layers) + else: + # self.layers = LSTMWithoutProjection(input_dim, lstm_dim, proj_dim, num_lstm_layers) + raise NotImplementedError() + + self._init_layers() + + def _init_layers(self): + for name, param in self.layers.named_parameters(): + if 'bias' in name: + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(param) + elif 'weight' in name: + init_XavierNormal = paddle.nn.initializer.XavierNormal() + init_XavierNormal(param) + + + + # def forward(self, x): + # # TODO: implement state passing for lstms + # d = self.layers(x) + # if self.use_lstm_with_projection: + # d = torch.nn.functional.normalize(d[:, -1], p=2, dim=1) + # else: + # d = torch.nn.functional.normalize(d, p=2, dim=1) + # return d + + @paddle.no_grad() + def inference(self, x): # + print("paddle", x.mean().item(), x.std().item()) + d = self.layers.forward(x) + if self.use_lstm_with_projection: + d = paddle.nn.functional.normalize(d[:, -1], p=2, axis=1) + else: + d = paddle.nn.functional.normalize(d, p=2, axis=1) + return d + + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): # + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + # frames_batch = torch.cat(frames_batch, dim=0) + frames_batch = paddle.concat(frames_batch, axis=0) + embeddings = self.inference(frames_batch) + + if return_mean: + # embeddings = torch.mean(embeddings, dim=0, keepdim=True) + embeddings = paddle.mean(embeddings, axis=0, keepdim=True) + + return embeddings + + # def batch_compute_embedding(self, x, seq_lens, num_frames=160, overlap=0.5): + # """ + # Generate embeddings for a batch of utterances + # x: BxTxD + # """ + # num_overlap = num_frames * overlap + # max_len = x.shape[1] + # embed = None + # num_iters = seq_lens / (num_frames - num_overlap) + # cur_iter = 0 + # for offset in range(0, max_len, num_frames - num_overlap): + # cur_iter += 1 + # end_offset = min(x.shape[1], offset + num_frames) + # frames = x[:, offset:end_offset] + 
# if embed is None: + # embed = self.inference(frames) + # else: + # embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) + # return embed / num_iters + + # pylint: disable=unused-argument, redefined-builtin + + def load_checkpoint(self, checkpoint_path: str, eval: bool = False): + # state = load_fsspec(checkpoint_path) + ckpt = paddle.load( checkpoint_path ) + self.set_state_dict( ckpt ) + if eval: + self.eval() + assert not self.training + + # TODO: https://github.com/PaddlePaddle/Paddle/issues/64989 + for pd_layer in self.layers: + pd_layer.lstm.could_use_cudnn = False + + +# if __name__ == "__main__": + +# tc_model = LSTMSpeakerEncoder_torch(80, 256, 768, 3).cuda() +# pd_model = LSTMSpeakerEncoder(80, 256, 768, 3) + + +# model_path = "speaker_pretrain/best_model.pth.tar" +# tc_model.load_checkpoint(model_path, eval=True, use_cuda=True) + +# # model_path = "speaker_pretrain/best_model.pdparam" +# # pd_model.load_checkpoint(model_path, eval=True) + +# x = np.random.randn(1, 212, 80).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + + +# for pd_layer, tc_layer in zip(pd_model.layers, tc_model.layers): +# pd_layer.lstm.could_use_cudnn = False +# LSTMWithProjection_torch2paddle(pd_layer, tc_layer) + + +# y_tc = tc_model.compute_embedding(x_tc).detach().cpu().numpy() +# y_pd = pd_model.compute_embedding(x_pd).detach().cpu().numpy() + +# print( +# abs(y_tc - y_pd).max(), +# f"{y_tc.mean().item()} {y_pd.mean().item()}", +# f"{y_tc.std().item()} {y_pd.std().item()}", +# ) + +# paddle.save( +# pd_model.state_dict(), +# "speaker_pretrain/best_model.pdparam" +# ) + + +if __name__ == "__main__": + + tc_model = LSTMSpeakerEncoder_torch(80, 256, 768, 3).cuda() + pd_model = LSTMSpeakerEncoder(80, 256, 768, 3) + + + model_path = "speaker_pretrain/best_model.pth.tar" + tc_model.load_checkpoint(model_path, eval=True, use_cuda=True) + + model_path = "speaker_pretrain/best_model.pdparam" + pd_model.load_checkpoint(model_path, eval=True) + + x = np.random.randn(1, 212, 80).astype("float32") + x_tc = torch.from_numpy(x).cuda() + x_pd = paddle.to_tensor(x) + + + # for pd_layer, tc_layer in zip(pd_model.layers, tc_model.layers): + # # TODO: https://github.com/PaddlePaddle/Paddle/issues/64989 + # pd_layer.lstm.could_use_cudnn = False + # # LSTMWithProjection_torch2paddle(pd_layer, tc_layer) + + + y_tc = tc_model.compute_embedding(x_tc).detach().cpu().numpy() + y_pd = pd_model.compute_embedding(x_pd).detach().cpu().numpy() + + print( + abs(y_tc - y_pd).max(), + f"\nmean: {y_tc.mean().item()} {y_pd.mean().item()}", + f"\nstd: {y_tc.std().item()} {y_pd.std().item()}", + ) + + # paddle.save( + # pd_model.state_dict(), + # "speaker_pretrain/best_model.pdparam" + # ) + + print( + pd_model_state_dict['lstm.weight_ih_l0'] is pd_model_state_dict['lstm.0.cell.weight_ih'], + pd_model_state_dict['lstm.weight_hh_l0'] is pd_model_state_dict['lstm.0.cell.weight_hh'], + pd_model_state_dict['lstm.bias_ih_l0'] is pd_model_state_dict['lstm.0.cell.bias_ih'], + pd_model_state_dict['lstm.bias_hh_l0'] is pd_model_state_dict['lstm.0.cell.bias_hh'] + ) \ No newline at end of file diff --git a/paddlemix/models/vits-svc/speaker/models/resnet.py b/paddlemix/models/vits-svc/speaker/models/resnet.py new file mode 100644 index 000000000..fcc850d7b --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/models/resnet.py @@ -0,0 +1,212 @@ +import numpy as np +import torch +from torch import nn + +from TTS.utils.io import load_fsspec + + +class SELayer(nn.Module): + def 
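# Note: the final print() in the __main__ block above reads `pd_model_state_dict`,
# which is never assigned in that block; it is presumably meant to be a state dict
# taken from `pd_model` (e.g. pd_model.state_dict() or one LSTMWithProjection layer's
# state dict). The check appears to probe whether Paddle's LSTM exposes the same
# parameter under both a flat name (weight_ih_l0) and a per-cell name
# (0.cell.weight_ih). A standalone way to inspect which keys a bare paddle.nn.LSTM
# exports on your Paddle version (key names may differ across versions):
import paddle

lstm = paddle.nn.LSTM(input_size=80, hidden_size=768, num_layers=1)
for key, value in sorted(lstm.state_dict().items()):
    print(key, tuple(value.shape))
# The `could_use_cudnn = False` flag set in load_checkpoint() above appears to force
# the non-cuDNN LSTM kernel; see the linked Paddle issue for the motivation.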
__init__(self, channel, reduction=8): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel), + nn.Sigmoid(), + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +class SEBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): + super(SEBasicBlock, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.se = SELayer(planes, reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.relu(out) + out = self.bn1(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + return out + + +class ResNetSpeakerEncoder(nn.Module): + """Implementation of the model H/ASP without batch normalization in speaker embedding. This model was proposed in: https://arxiv.org/abs/2009.14153 + Adapted from: https://github.com/clovaai/voxceleb_trainer + """ + + # pylint: disable=W0102 + def __init__( + self, + input_dim=64, + proj_dim=512, + layers=[3, 4, 6, 3], + num_filters=[32, 64, 128, 256], + encoder_type="ASP", + log_input=False, + ): + super(ResNetSpeakerEncoder, self).__init__() + + self.encoder_type = encoder_type + self.input_dim = input_dim + self.log_input = log_input + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) + self.relu = nn.ReLU(inplace=True) + self.bn1 = nn.BatchNorm2d(num_filters[0]) + + self.inplanes = num_filters[0] + self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) + self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) + self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) + self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) + + self.instancenorm = nn.InstanceNorm1d(input_dim) + + outmap_size = int(self.input_dim / 8) + + self.attention = nn.Sequential( + nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), + nn.ReLU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), + nn.Softmax(dim=2), + ) + + if self.encoder_type == "SAP": + out_dim = num_filters[3] * outmap_size + elif self.encoder_type == "ASP": + out_dim = num_filters[3] * outmap_size * 2 + else: + raise ValueError("Undefined encoder") + + self.fc = nn.Linear(out_dim, proj_dim) + + self._init_layers() + + def _init_layers(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def create_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + 
) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + # pylint: disable=R0201 + def new_parameter(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + def forward(self, x, l2_norm=False): + x = x.transpose(1, 2) + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if self.log_input: + x = (x + 1e-6).log() + x = self.instancenorm(x).unsqueeze(1) + + x = self.conv1(x) + x = self.relu(x) + x = self.bn1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = x.reshape(x.size()[0], -1, x.size()[-1]) + + w = self.attention(x) + + if self.encoder_type == "SAP": + x = torch.sum(x * w, dim=2) + elif self.encoder_type == "ASP": + mu = torch.sum(x * w, dim=2) + sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-5)) + x = torch.cat((mu, sg), 1) + + x = x.view(x.size()[0], -1) + x = self.fc(x) + + if l2_norm: + x = torch.nn.functional.normalize(x, p=2, dim=1) + return x + + @torch.no_grad() + def compute_embedding(self, x, num_frames=250, num_eval=10, return_mean=True): + """ + Generate embeddings for a batch of utterances + x: 1xTxD + """ + max_len = x.shape[1] + + if max_len < num_frames: + num_frames = max_len + + offsets = np.linspace(0, max_len - num_frames, num=num_eval) + + frames_batch = [] + for offset in offsets: + offset = int(offset) + end_offset = int(offset + num_frames) + frames = x[:, offset:end_offset] + frames_batch.append(frames) + + frames_batch = torch.cat(frames_batch, dim=0) + embeddings = self.forward(frames_batch, l2_norm=True) + + if return_mean: + embeddings = torch.mean(embeddings, dim=0, keepdim=True) + + return embeddings + + def load_checkpoint(self, config: dict, checkpoint_path: str, eval: bool = False, use_cuda: bool = False): + state = load_fsspec(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + if use_cuda: + self.cuda() + if eval: + self.eval() + assert not self.training diff --git a/paddlemix/models/vits-svc/speaker/utils/__init__.py b/paddlemix/models/vits-svc/speaker/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/speaker/utils/audio.py b/paddlemix/models/vits-svc/speaker/utils/audio.py new file mode 100644 index 000000000..244583072 --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/utils/audio.py @@ -0,0 +1,822 @@ +from typing import Dict, Tuple + +import librosa +import numpy as np +import pyworld as pw +import scipy.io.wavfile +import scipy.signal +import soundfile as sf +import torch +from torch import nn + +class StandardScaler: + """StandardScaler for mean-scale normalization with the given mean and scale values.""" + + def __init__(self, mean: np.ndarray = None, scale: np.ndarray = None) -> None: + self.mean_ = mean + self.scale_ = scale + + def set_stats(self, mean, scale): + self.mean_ = mean + self.scale_ = scale + + def reset_stats(self): + delattr(self, "mean_") + delattr(self, "scale_") + + def transform(self, X): + X = np.asarray(X) + X -= self.mean_ + X /= self.scale_ + return X + + def inverse_transform(self, X): + X = np.asarray(X) + X *= self.scale_ + X += self.mean_ + return X + +# class TorchSTFT(nn.Module): # pylint: disable=abstract-method +# """Some of the audio processing funtions using Torch for faster batch processing. 
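# Two hedged, self-contained sketches of the building blocks used in resnet.py above
# (shapes only; these are illustrative snippets, not the repository's exact modules).
import torch
from torch import nn

# (1) Squeeze-and-Excitation gating as in SELayer: global-average-pool each channel,
#     squeeze through a bottleneck MLP ending in a sigmoid, then rescale channel-wise.
b, c, h, w = 2, 32, 8, 40
x = torch.randn(b, c, h, w)
squeeze = nn.AdaptiveAvgPool2d(1)(x).view(b, c)                              # [B, C]
gate = nn.Sequential(nn.Linear(c, c // 8), nn.ReLU(), nn.Linear(c // 8, c), nn.Sigmoid())
y = x * gate(squeeze).view(b, c, 1, 1)                                       # [B, C, H, W]
print(y.shape)

# (2) Attentive statistics pooling as in the "ASP" branch of forward(): attention
#     weights over time give a weighted mean and std, concatenated before self.fc,
#     which doubles the channel dimension.
b, c, t = 2, 160, 50
feats = torch.randn(b, c, t)
w_att = torch.softmax(torch.randn(b, c, t), dim=2)        # stand-in for self.attention(feats)
mu = torch.sum(feats * w_att, dim=2)
sg = torch.sqrt((torch.sum(feats ** 2 * w_att, dim=2) - mu ** 2).clamp(min=1e-5))
print(torch.cat((mu, sg), dim=1).shape)                   # torch.Size([2, 320])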
+ +# TODO: Merge this with audio.py +# """ + +# def __init__( +# self, +# n_fft, +# hop_length, +# win_length, +# pad_wav=False, +# window="hann_window", +# sample_rate=None, +# mel_fmin=0, +# mel_fmax=None, +# n_mels=80, +# use_mel=False, +# do_amp_to_db=False, +# spec_gain=1.0, +# ): +# super().__init__() +# self.n_fft = n_fft +# self.hop_length = hop_length +# self.win_length = win_length +# self.pad_wav = pad_wav +# self.sample_rate = sample_rate +# self.mel_fmin = mel_fmin +# self.mel_fmax = mel_fmax +# self.n_mels = n_mels +# self.use_mel = use_mel +# self.do_amp_to_db = do_amp_to_db +# self.spec_gain = spec_gain +# self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) +# self.mel_basis = None +# if use_mel: +# self._build_mel_basis() + +# def __call__(self, x): +# """Compute spectrogram frames by torch based stft. + +# Args: +# x (Tensor): input waveform + +# Returns: +# Tensor: spectrogram frames. + +# Shapes: +# x: [B x T] or [:math:`[B, 1, T]`] +# """ +# if x.ndim == 2: +# x = x.unsqueeze(1) +# if self.pad_wav: +# padding = int((self.n_fft - self.hop_length) / 2) +# x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") +# # B x D x T x 2 +# o = torch.stft( +# x.squeeze(1), +# self.n_fft, +# self.hop_length, +# self.win_length, +# self.window, +# center=True, +# pad_mode="reflect", # compatible with audio.py +# normalized=False, +# onesided=True, +# return_complex=False, +# ) +# M = o[:, :, :, 0] +# P = o[:, :, :, 1] +# S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) +# if self.use_mel: +# S = torch.matmul(self.mel_basis.to(x), S) +# if self.do_amp_to_db: +# S = self._amp_to_db(S, spec_gain=self.spec_gain) +# return S + +# def _build_mel_basis(self): +# mel_basis = librosa.filters.mel( +# sr=self.sample_rate, n_fft=self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax +# ) +# self.mel_basis = torch.from_numpy(mel_basis).float() + +# @staticmethod +# def _amp_to_db(x, spec_gain=1.0): +# return torch.log(torch.clamp(x, min=1e-5) * spec_gain) + +# @staticmethod +# def _db_to_amp(x, spec_gain=1.0): +# return torch.exp(x) / spec_gain + + +# pylint: disable=too-many-public-methods +class AudioProcessor(object): + """Audio Processor for TTS used by all the data pipelines. + + Note: + All the class arguments are set to default values to enable a flexible initialization + of the class with the model config. They are not meaningful for all the arguments. + + Args: + sample_rate (int, optional): + target audio sampling rate. Defaults to None. + + resample (bool, optional): + enable/disable resampling of the audio clips when the target sampling rate does not match the original sampling rate. Defaults to False. + + num_mels (int, optional): + number of melspectrogram dimensions. Defaults to None. + + log_func (int, optional): + log exponent used for converting spectrogram aplitude to DB. + + min_level_db (int, optional): + minimum db threshold for the computed melspectrograms. Defaults to None. + + frame_shift_ms (int, optional): + milliseconds of frames between STFT columns. Defaults to None. + + frame_length_ms (int, optional): + milliseconds of STFT window length. Defaults to None. + + hop_length (int, optional): + number of frames between STFT columns. Used if ```frame_shift_ms``` is None. Defaults to None. + + win_length (int, optional): + STFT window length. Used if ```frame_length_ms``` is None. Defaults to None. + + ref_level_db (int, optional): + reference DB level to avoid background noise. 
In general <20DB corresponds to the air noise. Defaults to None. + + fft_size (int, optional): + FFT window size for STFT. Defaults to 1024. + + power (int, optional): + Exponent value applied to the spectrogram before GriffinLim. Defaults to None. + + preemphasis (float, optional): + Preemphasis coefficient. Preemphasis is disabled if == 0.0. Defaults to 0.0. + + signal_norm (bool, optional): + enable/disable signal normalization. Defaults to None. + + symmetric_norm (bool, optional): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else [0, k], Defaults to None. + + max_norm (float, optional): + ```k``` defining the normalization range. Defaults to None. + + mel_fmin (int, optional): + minimum filter frequency for computing melspectrograms. Defaults to None. + + mel_fmax (int, optional): + maximum filter frequency for computing melspectrograms.. Defaults to None. + + spec_gain (int, optional): + gain applied when converting amplitude to DB. Defaults to 20. + + stft_pad_mode (str, optional): + Padding mode for STFT. Defaults to 'reflect'. + + clip_norm (bool, optional): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + griffin_lim_iters (int, optional): + Number of GriffinLim iterations. Defaults to None. + + do_trim_silence (bool, optional): + enable/disable silence trimming when loading the audio signal. Defaults to False. + + trim_db (int, optional): + DB threshold used for silence trimming. Defaults to 60. + + do_sound_norm (bool, optional): + enable/disable signal normalization. Defaults to False. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + stats_path (str, optional): + Path to the computed stats file. Defaults to None. + + verbose (bool, optional): + enable/disable logging. Defaults to True. 
+ + """ + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + do_amp_to_db_linear=True, + do_amp_to_db_mel=True, + stats_path=None, + verbose=True, + **_, + ): + + # setup class attributed + self.sample_rate = sample_rate + self.resample = resample + self.num_mels = num_mels + self.log_func = log_func + self.min_level_db = min_level_db or 0 + self.frame_shift_ms = frame_shift_ms + self.frame_length_ms = frame_length_ms + self.ref_level_db = ref_level_db + self.fft_size = fft_size + self.power = power + self.preemphasis = preemphasis + self.griffin_lim_iters = griffin_lim_iters + self.signal_norm = signal_norm + self.symmetric_norm = symmetric_norm + self.mel_fmin = mel_fmin or 0 + self.mel_fmax = mel_fmax + self.spec_gain = float(spec_gain) + self.stft_pad_mode = stft_pad_mode + self.max_norm = 1.0 if max_norm is None else float(max_norm) + self.clip_norm = clip_norm + self.do_trim_silence = do_trim_silence + self.trim_db = trim_db + self.do_sound_norm = do_sound_norm + self.do_amp_to_db_linear = do_amp_to_db_linear + self.do_amp_to_db_mel = do_amp_to_db_mel + self.stats_path = stats_path + # setup exp_func for db to amp conversion + if log_func == "np.log": + self.base = np.e + elif log_func == "np.log10": + self.base = 10 + else: + raise ValueError(" [!] unknown `log_func` value.") + # setup stft parameters + if hop_length is None: + # compute stft parameters from given time values + self.hop_length, self.win_length = self._stft_parameters() + else: + # use stft parameters from config file + self.hop_length = hop_length + self.win_length = win_length + assert min_level_db != 0.0, " [!] min_level_db is 0" + assert self.win_length <= self.fft_size, " [!] win_length cannot be larger than fft_size" + members = vars(self) + if verbose: + print(" > Setting up Audio Processor...") + for key, value in members.items(): + print(" | > {}:{}".format(key, value)) + # create spectrogram utils + self.mel_basis = self._build_mel_basis() + self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + # setup scaler + if stats_path and signal_norm: + mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) + self.setup_scaler(mel_mean, mel_std, linear_mean, linear_std) + self.signal_norm = True + self.max_norm = None + self.clip_norm = None + self.symmetric_norm = None + + ### setting up the parameters ### + def _build_mel_basis( + self, + ) -> np.ndarray: + """Build melspectrogram basis. + + Returns: + np.ndarray: melspectrogram basis. + """ + if self.mel_fmax is not None: + assert self.mel_fmax <= self.sample_rate // 2 + return librosa.filters.mel( + sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) + + def _stft_parameters( + self, + ) -> Tuple[int, int]: + """Compute the real STFT parameters from the time values. + + Returns: + Tuple[int, int]: hop length and window length for STFT. + """ + factor = self.frame_length_ms / self.frame_shift_ms + assert (factor).is_integer(), " [!] 
frame_shift_ms should divide frame_length_ms" + hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) + win_length = int(hop_length * factor) + return hop_length, win_length + + ### normalization ### + def normalize(self, S: np.ndarray) -> np.ndarray: + """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` + + Args: + S (np.ndarray): Spectrogram to normalize. + + Raises: + RuntimeError: Mean and variance is computed from incompatible parameters. + + Returns: + np.ndarray: Normalized spectrogram. + """ + # pylint: disable=no-else-return + S = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S.shape[0] == self.num_mels: + return self.mel_scaler.transform(S.T).T + elif S.shape[0] == self.fft_size / 2: + return self.linear_scaler.transform(S.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + # range normalization + S -= self.ref_level_db # discard certain range of DB assuming it is air noise + S_norm = (S - self.min_level_db) / (-self.min_level_db) + if self.symmetric_norm: + S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm + if self.clip_norm: + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + return S_norm + else: + S_norm = self.max_norm * S_norm + if self.clip_norm: + S_norm = np.clip(S_norm, 0, self.max_norm) + return S_norm + else: + return S + + def denormalize(self, S: np.ndarray) -> np.ndarray: + """Denormalize spectrogram values. + + Args: + S (np.ndarray): Spectrogram to denormalize. + + Raises: + RuntimeError: Mean and variance are incompatible. + + Returns: + np.ndarray: Denormalized spectrogram. + """ + # pylint: disable=no-else-return + S_denorm = S.copy() + if self.signal_norm: + # mean-var scaling + if hasattr(self, "mel_scaler"): + if S_denorm.shape[0] == self.num_mels: + return self.mel_scaler.inverse_transform(S_denorm.T).T + elif S_denorm.shape[0] == self.fft_size / 2: + return self.linear_scaler.inverse_transform(S_denorm.T).T + else: + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") + if self.symmetric_norm: + if self.clip_norm: + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) + S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db + return S_denorm + self.ref_level_db + else: + if self.clip_norm: + S_denorm = np.clip(S_denorm, 0, self.max_norm) + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db + return S_denorm + self.ref_level_db + else: + return S_denorm + + ### Mean-STD scaling ### + def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]: + """Loading mean and variance statistics from a `npy` file. + + Args: + stats_path (str): Path to the `npy` file containing + + Returns: + Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to + compute them. 
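# A hedged NumPy sketch of the two conversions implemented above, with illustrative
# numbers (not the repository's config values):
# (1) _stft_parameters(): hop/win lengths derived from millisecond frame settings.
# (2) normalize()/denormalize() in the non-scaler branch: shift by ref_level_db,
#     scale by -min_level_db, optionally map symmetrically into [-max_norm, max_norm].
import numpy as np

sample_rate, frame_shift_ms, frame_length_ms = 22050, 12.5, 50.0
factor = frame_length_ms / frame_shift_ms                     # must be an integer (4.0)
hop_length = int(frame_shift_ms / 1000.0 * sample_rate)       # 275
win_length = int(hop_length * factor)                         # 1100
print(hop_length, win_length)

ref_level_db, min_level_db, max_norm = 20, -100, 4.0
S = np.linspace(-100.0, 0.0, 11)                              # toy dB spectrogram
S_norm = (S - ref_level_db - min_level_db) / (-min_level_db)
S_norm = np.clip(2 * max_norm * S_norm - max_norm, -max_norm, max_norm)   # symmetric
S_back = ((S_norm + max_norm) * -min_level_db / (2 * max_norm)) + min_level_db + ref_level_db
print(np.allclose(S_back, np.clip(S, ref_level_db + min_level_db, ref_level_db)))   # True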
+ """ + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] + # check all audio parameters used for computing stats + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] + for key in stats_config.keys(): + if key in skip_parameters: + continue + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + return mel_mean, mel_std, linear_mean, linear_std, stats_config + + # pylint: disable=attribute-defined-outside-init + def setup_scaler( + self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray + ) -> None: + """Initialize scaler objects used in mean-std normalization. + + Args: + mel_mean (np.ndarray): Mean for melspectrograms. + mel_std (np.ndarray): STD for melspectrograms. + linear_mean (np.ndarray): Mean for full scale spectrograms. + linear_std (np.ndarray): STD for full scale spectrograms. + """ + self.mel_scaler = StandardScaler() + self.mel_scaler.set_stats(mel_mean, mel_std) + self.linear_scaler = StandardScaler() + self.linear_scaler.set_stats(linear_mean, linear_std) + + ### DB and AMP conversion ### + # pylint: disable=no-self-use + def _amp_to_db(self, x: np.ndarray) -> np.ndarray: + """Convert amplitude values to decibels. + + Args: + x (np.ndarray): Amplitude spectrogram. + + Returns: + np.ndarray: Decibels spectrogram. + """ + return self.spec_gain * _log(np.maximum(1e-5, x), self.base) + + # pylint: disable=no-self-use + def _db_to_amp(self, x: np.ndarray) -> np.ndarray: + """Convert decibels spectrogram to amplitude spectrogram. + + Args: + x (np.ndarray): Decibels spectrogram. + + Returns: + np.ndarray: Amplitude spectrogram. + """ + return _exp(x / self.spec_gain, self.base) + + ### Preemphasis ### + def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. + + Args: + x (np.ndarray): Audio signal. + + Raises: + RuntimeError: Preemphasis coeff is set to 0. + + Returns: + np.ndarray: Decorrelated audio signal. + """ + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + + def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: + """Reverse pre-emphasis.""" + if self.preemphasis == 0: + raise RuntimeError(" [!] Preemphasis is set 0.0.") + return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + + ### SPECTROGRAMs ### + def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: + """Project a full scale spectrogram to a melspectrogram. + + Args: + spectrogram (np.ndarray): Full scale spectrogram. + + Returns: + np.ndarray: Melspectrogram + """ + return np.dot(self.mel_basis, spectrogram) + + def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to full scale spectrogram.""" + return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) + + def spectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a spectrogram from a waveform. + + Args: + y (np.ndarray): Waveform. + + Returns: + np.ndarray: Spectrogram. 
+ """ + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_linear: + S = self._amp_to_db(np.abs(D)) + else: + S = np.abs(D) + return self.normalize(S).astype(np.float32) + + def melspectrogram(self, y: np.ndarray) -> np.ndarray: + """Compute a melspectrogram from a waveform.""" + if self.preemphasis != 0: + D = self._stft(self.apply_preemphasis(y)) + else: + D = self._stft(y) + if self.do_amp_to_db_mel: + S = self._amp_to_db(self._linear_to_mel(np.abs(D))) + else: + S = self._linear_to_mel(np.abs(D)) + return self.normalize(S).astype(np.float32) + + def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: + """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" + S = self.denormalize(spectrogram) + S = self._db_to_amp(S) + # Reconstruct phase + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: + """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" + D = self.denormalize(mel_spectrogram) + S = self._db_to_amp(D) + S = self._mel_to_linear(S) # Convert back to linear + if self.preemphasis != 0: + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) + + def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: + """Convert a full scale linear spectrogram output of a network to a melspectrogram. + + Args: + linear_spec (np.ndarray): Normalized full scale linear spectrogram. + + Returns: + np.ndarray: Normalized melspectrogram. + """ + S = self.denormalize(linear_spec) + S = self._db_to_amp(S) + S = self._linear_to_mel(np.abs(S)) + S = self._amp_to_db(S) + mel = self.normalize(S) + return mel + + ### STFT and ISTFT ### + def _stft(self, y: np.ndarray) -> np.ndarray: + """Librosa STFT wrapper. + + Args: + y (np.ndarray): Audio signal. + + Returns: + np.ndarray: Complex number array. + """ + return librosa.stft( + y=y, + n_fft=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + window="hann", + center=True, + ) + + def _istft(self, y: np.ndarray) -> np.ndarray: + """Librosa iSTFT wrapper.""" + return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) + + def _griffin_lim(self, S): + angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) + S_complex = np.abs(S).astype(np.complex) + y = self._istft(S_complex * angles) + if not np.isfinite(y).all(): + print(" [!] Waveform is not finite everywhere. Skipping the GL.") + return np.array([0.0]) + for _ in range(self.griffin_lim_iters): + angles = np.exp(1j * np.angle(self._stft(y))) + y = self._istft(S_complex * angles) + return y + + def compute_stft_paddings(self, x, pad_sides=1): + """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding + (first and final frames)""" + assert pad_sides in (1, 2) + pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] + if pad_sides == 1: + return 0, pad + return pad // 2, pad // 2 + pad % 2 + + def compute_f0(self, x: np.ndarray) -> np.ndarray: + """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. + + Args: + x (np.ndarray): Waveform. + + Returns: + np.ndarray: Pitch. 
+ + Examples: + >>> WAV_FILE = filename = librosa.util.example_audio_file() + >>> from TTS.config import BaseAudioConfig + >>> from TTS.utils.audio import AudioProcessor + >>> conf = BaseAudioConfig(mel_fmax=8000) + >>> ap = AudioProcessor(**conf) + >>> wav = ap.load_wav(WAV_FILE, sr=22050)[:5 * 22050] + >>> pitch = ap.compute_f0(wav) + """ + f0, t = pw.dio( + x.astype(np.double), + fs=self.sample_rate, + f0_ceil=self.mel_fmax, + frame_period=1000 * self.hop_length / self.sample_rate, + ) + f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate) + # pad = int((self.win_length / self.hop_length) / 2) + # f0 = [0.0] * pad + f0 + [0.0] * pad + # f0 = np.pad(f0, (pad, pad), mode="constant", constant_values=0) + # f0 = np.array(f0, dtype=np.float32) + + # f01, _, _ = librosa.pyin( + # x, + # fmin=65 if self.mel_fmin == 0 else self.mel_fmin, + # fmax=self.mel_fmax, + # frame_length=self.win_length, + # sr=self.sample_rate, + # fill_na=0.0, + # ) + + # spec = self.melspectrogram(x) + return f0 + + ### Audio Processing ### + def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int: + """Find the last point without silence at the end of a audio signal. + + Args: + wav (np.ndarray): Audio signal. + threshold_db (int, optional): Silence threshold in decibels. Defaults to -40. + min_silence_sec (float, optional): Ignore silences that are shorter then this in secs. Defaults to 0.8. + + Returns: + int: Last point without silence. + """ + window_length = int(self.sample_rate * min_silence_sec) + hop_length = int(window_length / 4) + threshold = self._db_to_amp(threshold_db) + for x in range(hop_length, len(wav) - window_length, hop_length): + if np.max(wav[x : x + window_length]) < threshold: + return x + hop_length + return len(wav) + + def trim_silence(self, wav): + """Trim silent parts with a threshold and 0.01 sec margin""" + margin = int(self.sample_rate * 0.01) + wav = wav[margin:-margin] + return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ + 0 + ] + + @staticmethod + def sound_norm(x: np.ndarray) -> np.ndarray: + """Normalize the volume of an audio signal. + + Args: + x (np.ndarray): Raw waveform. + + Returns: + np.ndarray: Volume normalized waveform. + """ + return x / abs(x).max() * 0.95 + + ### save and load ### + def load_wav(self, filename: str, sr: int = None) -> np.ndarray: + """Read a wav file using Librosa and optionally resample, silence trim, volume normalize. + + Args: + filename (str): Path to the wav file. + sr (int, optional): Sampling rate for resampling. Defaults to None. + + Returns: + np.ndarray: Loaded waveform. + """ + if self.resample: + x, sr = librosa.load(filename, sr=self.sample_rate) + elif sr is None: + x, sr = sf.read(filename) + assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) + else: + x, sr = librosa.load(filename, sr=sr) + if self.do_trim_silence: + try: + x = self.trim_silence(x) + except ValueError: + print(f" [!] File cannot be trimmed for silence - {filename}") + if self.do_sound_norm: + x = self.sound_norm(x) + return x + + def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None: + """Save a waveform to a file using Scipy. + + Args: + wav (np.ndarray): Waveform to save. + path (str): Path to a output file. + sr (int, optional): Sampling rate used for saving to the file. Defaults to None. 
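# A hedged sketch of the pyworld pitch path used by compute_f0() above: dio gives a
# coarse F0 track and stonemask refines it. On a clean 220 Hz tone the voiced frames
# should read close to 220. Parameters below are illustrative only.
import numpy as np
import pyworld as pw

sr, hop = 22050, 256
t = np.arange(2 * sr) / sr
wav = (0.4 * np.sin(2 * np.pi * 220.0 * t)).astype(np.double)

f0, times = pw.dio(wav, fs=sr, f0_ceil=8000.0, frame_period=1000 * hop / sr)
f0 = pw.stonemask(wav, f0, times, sr)
voiced = f0[f0 > 0]
print(round(float(np.median(voiced)), 1) if voiced.size else "unvoiced")   # roughly 220.0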
+ """ + wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav)))) + scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16)) + + @staticmethod + def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: + mu = 2 ** qc - 1 + # wav_abs = np.minimum(np.abs(wav), 1.0) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) + # Quantize signal to the specified number of levels. + signal = (signal + 1) / 2 * mu + 0.5 + return np.floor( + signal, + ) + + @staticmethod + def mulaw_decode(wav, qc): + """Recovers waveform from quantized values.""" + mu = 2 ** qc - 1 + x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) + return x + + @staticmethod + def encode_16bits(x): + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) + + @staticmethod + def quantize(x: np.ndarray, bits: int) -> np.ndarray: + """Quantize a waveform to a given number of bits. + + Args: + x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. + bits (int): Number of quantization bits. + + Returns: + np.ndarray: Quantized waveform. + """ + return (x + 1.0) * (2 ** bits - 1) / 2 + + @staticmethod + def dequantize(x, bits): + """Dequantize a waveform from the given number of bits.""" + return 2 * x / (2 ** bits - 1) - 1 + + +def _log(x, base): + if base == 10: + return np.log10(x) + return np.log(x) + + +def _exp(x, base): + if base == 10: + return np.power(10, x) + return np.exp(x) diff --git a/paddlemix/models/vits-svc/speaker/utils/coqpit.py b/paddlemix/models/vits-svc/speaker/utils/coqpit.py new file mode 100644 index 000000000..f4f781c57 --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/utils/coqpit.py @@ -0,0 +1,954 @@ +import argparse +import functools +import json +import operator +import os +from collections.abc import MutableMapping +from dataclasses import MISSING as _MISSING +from dataclasses import Field, asdict, dataclass, fields, is_dataclass, replace +from pathlib import Path +from pprint import pprint +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union, get_type_hints + +T = TypeVar("T") +MISSING: Any = "???" + + +class _NoDefault(Generic[T]): + pass + + +NoDefaultVar = Union[_NoDefault[T], T] +no_default: NoDefaultVar = _NoDefault() + + +def is_primitive_type(arg_type: Any) -> bool: + """Check if the input type is one of `int, float, str, bool`. + + Args: + arg_type (typing.Any): input type to check. + + Returns: + bool: True if input type is one of `int, float, str, bool`. + """ + try: + return isinstance(arg_type(), (int, float, str, bool)) + except (AttributeError, TypeError): + return False + + +def is_list(arg_type: Any) -> bool: + """Check if the input type is `list` + + Args: + arg_type (typing.Any): input type. + + Returns: + bool: True if input type is `list` + """ + try: + return arg_type is list or arg_type is List or arg_type.__origin__ is list or arg_type.__origin__ is List + except AttributeError: + return False + + +def is_dict(arg_type: Any) -> bool: + """Check if the input type is `dict` + + Args: + arg_type (typing.Any): input type. + + Returns: + bool: True if input type is `dict` + """ + try: + return arg_type is dict or arg_type is Dict or arg_type.__origin__ is dict + except AttributeError: + return False + + +def is_union(arg_type: Any) -> bool: + """Check if the input type is `Union`. + + Args: + arg_type (typing.Any): input type. 
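# A hedged NumPy sketch of the mu-law companding and linear quantization helpers
# above. mulaw_encode maps [-1, 1] onto 2**qc integer codes; mulaw_decode expands a
# signal already mapped back into [-1, 1] (i.e. pass 2*code/mu - 1, not the raw code).
import numpy as np

qc = 8
mu = 2 ** qc - 1
wav = np.linspace(-0.9, 0.9, 7)

companded = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
codes = np.floor((companded + 1) / 2 * mu + 0.5)               # integer codes in [0, mu]
signal = 2 * codes / mu - 1                                    # back into [-1, 1]
decoded = np.sign(signal) / mu * ((1 + mu) ** np.abs(signal) - 1)
print(codes.astype(int))
print(np.round(decoded, 3))                                    # close to wav, within mu-law error

# Linear quantize/dequantize round trip from the same class:
bits = 10
q = (wav + 1.0) * (2 ** bits - 1) / 2
print(np.allclose(2 * q / (2 ** bits - 1) - 1, wav))           # True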
+ + Returns: + bool: True if input type is `Union` + """ + try: + return safe_issubclass(arg_type.__origin__, Union) + except AttributeError: + return False + + +def safe_issubclass(cls, classinfo) -> bool: + """Check if the input type is a subclass of the given class. + + Args: + cls (type): input type. + classinfo (type): parent class. + + Returns: + bool: True if the input type is a subclass of the given class + """ + try: + r = issubclass(cls, classinfo) + except Exception: # pylint: disable=broad-except + return cls is classinfo + else: + return r + + +# def _coqpit_json_default(obj: Any) -> Any: +# if isinstance(obj, Path): +# return str(obj) +# raise TypeError(f"Can't encode object of type {type(obj).__name__}") + + +def _default_value(x: Field): + """Return the default value of the input Field. + + Args: + x (Field): input Field. + + Returns: + object: default value of the input Field. + """ + if x.default not in (MISSING, _MISSING): + return x.default + if x.default_factory not in (MISSING, _MISSING): + return x.default_factory() + return x.default + + +# def _is_optional_field(field) -> bool: +# """Check if the input field is optional. + +# Args: +# field (Field): input Field to check. + +# Returns: +# bool: True if the input field is optional. +# """ +# # return isinstance(field.type, _GenericAlias) and type(None) in getattr(field.type, "__args__") +# return type(None) in getattr(field.type, "__args__") + + +# def my_get_type_hints( +# cls, +# ): +# """Custom `get_type_hints` dealing with https://github.com/python/typing/issues/737 + +# Returns: +# [dataclass]: dataclass to get the type hints of its fields. +# """ +# r_dict = {} +# for base in cls.__class__.__bases__: +# if base == object: +# break +# r_dict.update(my_get_type_hints(base)) +# r_dict.update(get_type_hints(cls)) +# return r_dict + + +def _serialize(x): + """Pick the right serialization for the datatype of the given input. + + Args: + x (object): input object. + + Returns: + object: serialized object. + """ + if isinstance(x, Path): + return str(x) + if isinstance(x, dict): + return {k: _serialize(v) for k, v in x.items()} + if isinstance(x, list): + return [_serialize(xi) for xi in x] + if isinstance(x, Serializable) or issubclass(type(x), Serializable): + return x.serialize() + if isinstance(x, type) and issubclass(x, Serializable): + return x.serialize(x) + return x + + +def _deserialize_dict(x: Dict) -> Dict: + """Deserialize dict. + + Args: + x (Dict): value to deserialized. + + Returns: + Dict: deserialized dictionary. + """ + out_dict = {} + for k, v in x.items(): + if v is None: # if {'key':None} + out_dict[k] = None + else: + out_dict[k] = _deserialize(v, type(v)) + return out_dict + + +def _deserialize_list(x: List, field_type: Type) -> List: + """Deserialize values for List typed fields. + + Args: + x (List): value to be deserialized + field_type (Type): field type. + + Raises: + ValueError: Coqpit does not support multi type-hinted lists. + + Returns: + [List]: deserialized list. + """ + field_args = None + if hasattr(field_type, "__args__") and field_type.__args__: + field_args = field_type.__args__ + elif hasattr(field_type, "__parameters__") and field_type.__parameters__: + # bandaid for python 3.6 + field_args = field_type.__parameters__ + if field_args: + if len(field_args) > 1: + raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'") + field_arg = field_args[0] + # if field type is TypeVar set the current type by the value's type. 
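# The type predicates above (is_list/is_dict/is_union) rely on the `__origin__`
# attribute of typing generics rather than typing.get_origin(). A quick standalone
# illustration of what they see on Python 3.7+:
from typing import Dict, List, Optional, Union

print(List[int].__origin__ is list)            # True -> is_list
print(Dict[str, int].__origin__ is dict)       # True -> is_dict
print(Union[int, str].__origin__ is Union)     # True -> is_union
print(Optional[int] == Union[int, None])       # Optional is just a two-arg Union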
+ if isinstance(field_arg, TypeVar): + field_arg = type(x) + return [_deserialize(xi, field_arg) for xi in x] + return x + + +def _deserialize_union(x: Any, field_type: Type) -> Any: + """Deserialize values for Union typed fields + + Args: + x (Any): value to be deserialized. + field_type (Type): field type. + + Returns: + [Any]: desrialized value. + """ + for arg in field_type.__args__: + # stop after first matching type in Union + try: + x = _deserialize(x, arg) + break + except ValueError: + pass + return x + + +def _deserialize_primitive_types(x: Union[int, float, str, bool], field_type: Type) -> Union[int, float, str, bool]: + """Deserialize python primitive types (float, int, str, bool). + It handles `inf` values exclusively and keeps them float against int fields since int does not support inf values. + + Args: + x (Union[int, float, str, bool]): value to be deserialized. + field_type (Type): field type. + + Returns: + Union[int, float, str, bool]: deserialized value. + """ + + if isinstance(x, (str, bool)): + return x + if isinstance(x, (int, float)): + if x == float("inf") or x == float("-inf"): + # if value type is inf return regardless. + return x + x = field_type(x) + return x + # TODO: Raise an error when x does not match the types. + return None + + +def _deserialize(x: Any, field_type: Any) -> Any: + """Pick the right desrialization for the given object and the corresponding field type. + + Args: + x (object): object to be deserialized. + field_type (type): expected type after deserialization. + + Returns: + object: deserialized object + + """ + # pylint: disable=too-many-return-statements + if is_dict(field_type): + return _deserialize_dict(x) + if is_list(field_type): + return _deserialize_list(x, field_type) + if is_union(field_type): + return _deserialize_union(x, field_type) + if issubclass(field_type, Serializable): + return field_type.deserialize_immutable(x) + if is_primitive_type(field_type): + return _deserialize_primitive_types(x, field_type) + raise ValueError(f" [!] 
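# A hedged sketch of the special case handled by _deserialize_primitive_types()
# above: "inf" values are returned as float even when the annotated field type is
# int, because int(float("inf")) raises OverflowError. Illustrative helper only.
def coerce_primitive(x, field_type):
    if isinstance(x, (str, bool)):
        return x
    if isinstance(x, (int, float)):
        if x in (float("inf"), float("-inf")):
            return x
        return field_type(x)
    return None

print(coerce_primitive(3.7, int), coerce_primitive(float("inf"), int), coerce_primitive("a", int))
# -> 3 inf a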
'{type(x)}' value type of '{x}' does not match '{field_type}' field type.") + + +# # Recursive setattr (supports dotted attr names) +# def rsetattr(obj, attr, val): +# def _setitem(obj, attr, val): +# return operator.setitem(obj, int(attr), val) + +# pre, _, post = attr.rpartition(".") +# setfunc = _setitem if post.isnumeric() else setattr + +# return setfunc(rgetattr(obj, pre) if pre else obj, post, val) + + +# # Recursive getattr (supports dotted attr names) +# def rgetattr(obj, attr, *args): +# def _getitem(obj, attr): +# return operator.getitem(obj, int(attr), *args) + +# def _getattr(obj, attr): +# getfunc = _getitem if attr.isnumeric() else getattr +# return getfunc(obj, attr, *args) + +# return functools.reduce(_getattr, [obj] + attr.split(".")) + + +# # Recursive setitem (supports dotted attr names) +# def rsetitem(obj, attr, val): +# pre, _, post = attr.rpartition(".") +# return operator.setitem(rgetitem(obj, pre) if pre else obj, post, val) + + +# # Recursive getitem (supports dotted attr names) +# def rgetitem(obj, attr, *args): +# def _getitem(obj, attr): +# return operator.getitem(obj, int(attr) if attr.isnumeric() else attr, *args) + +# return functools.reduce(_getitem, [obj] + attr.split(".")) + + +@dataclass +class Serializable: + """Gives serialization ability to any inheriting dataclass.""" + + def __post_init__(self): + self._validate_contracts() + for key, value in self.__dict__.items(): + if value is no_default: + raise TypeError(f"__init__ missing 1 required argument: '{key}'") + + def _validate_contracts(self): + dataclass_fields = fields(self) + + for field in dataclass_fields: + + value = getattr(self, field.name) + + if value is None: + if not _is_optional_field(field): + raise TypeError(f"{field.name} is not optional") + + contract = field.metadata.get("contract", None) + + if contract is not None: + if value is not None and not contract(value): + raise ValueError(f"break the contract for {field.name}, {self.__class__.__name__}") + + def validate(self): + """validate if object can serialize / deserialize correctly.""" + self._validate_contracts() + if self != self.__class__.deserialize( # pylint: disable=no-value-for-parameter + json.loads(json.dumps(self.serialize())) + ): + raise ValueError("could not be deserialized with same value") + + def to_dict(self) -> dict: + """Transform serializable object to dict.""" + cls_fields = fields(self) + o = {} + for cls_field in cls_fields: + o[cls_field.name] = getattr(self, cls_field.name) + return o + + def serialize(self) -> dict: + """Serialize object to be json serializable representation.""" + if not is_dataclass(self): + raise TypeError("need to be decorated as dataclass") + + dataclass_fields = fields(self) + + o = {} + + for field in dataclass_fields: + value = getattr(self, field.name) + value = _serialize(value) + o[field.name] = value + return o + + def deserialize(self, data: dict) -> "Serializable": + """Parse input dictionary and desrialize its fields to a dataclass. + + Returns: + self: deserialized `self`. + """ + if not isinstance(data, dict): + raise ValueError() + data = data.copy() + init_kwargs = {} + for field in fields(self): + # if field.name == 'dataset_config': + if field.name not in data: + if field.name in vars(self): + init_kwargs[field.name] = vars(self)[field.name] + continue + raise ValueError(f' [!] 
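# Note: _validate_contracts() above calls _is_optional_field(), which is commented
# out earlier in this file, so validating an instance with a None-valued field would
# raise NameError as committed. A hedged, standalone sketch of the serialize -> json
# -> deserialize round trip that the Serializable base class provides (plain
# dataclasses here, not the repository's classes):
import json
from dataclasses import asdict, dataclass, fields

@dataclass
class ToyConfig:
    sample_rate: int = 22050
    num_mels: int = 80

    def serialize(self) -> dict:
        return asdict(self)

    def deserialize(self, data: dict) -> "ToyConfig":
        for f in fields(self):
            if f.name in data:
                setattr(self, f.name, data[f.name])
        return self

cfg = ToyConfig()
restored = ToyConfig().deserialize(json.loads(json.dumps(cfg.serialize())))
print(restored == cfg)    # True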
Missing required field "{field.name}"') + value = data.get(field.name, _default_value(field)) + if value is None: + init_kwargs[field.name] = value + continue + if value == MISSING: + raise ValueError(f"deserialized with unknown value for {field.name} in {self.__name__}") + value = _deserialize(value, field.type) + init_kwargs[field.name] = value + for k, v in init_kwargs.items(): + setattr(self, k, v) + return self + + @classmethod + def deserialize_immutable(cls, data: dict) -> "Serializable": + """Parse input dictionary and desrialize its fields to a dataclass. + + Returns: + Newly created deserialized object. + """ + if not isinstance(data, dict): + raise ValueError() + data = data.copy() + init_kwargs = {} + for field in fields(cls): + # if field.name == 'dataset_config': + if field.name not in data: + if field.name in vars(cls): + init_kwargs[field.name] = vars(cls)[field.name] + continue + # if not in cls and the default value is not Missing use it + default_value = _default_value(field) + if default_value not in (MISSING, _MISSING): + init_kwargs[field.name] = default_value + continue + raise ValueError(f' [!] Missing required field "{field.name}"') + value = data.get(field.name, _default_value(field)) + if value is None: + init_kwargs[field.name] = value + continue + if value == MISSING: + raise ValueError(f"Deserialized with unknown value for {field.name} in {cls.__name__}") + value = _deserialize(value, field.type) + init_kwargs[field.name] = value + return cls(**init_kwargs) + + +# # ---------------------------------------------------------------------------- # +# # Argument Parsing from `argparse` # +# # ---------------------------------------------------------------------------- # + + +# def _get_help(field): +# try: +# field_help = field.metadata["help"] +# except KeyError: +# field_help = "" +# return field_help + + +# def _init_argparse( +# parser, +# field_name, +# field_type, +# field_default, +# field_default_factory, +# field_help, +# arg_prefix="", +# help_prefix="", +# relaxed_parser=False, +# ): +# has_default = False +# default = None +# if field_default: +# has_default = True +# default = field_default +# elif field_default_factory not in (None, _MISSING): +# has_default = True +# default = field_default_factory() + +# if not has_default and not is_primitive_type(field_type) and not is_list(field_type): +# # aggregate types (fields with a Coqpit subclass as type) are not supported without None +# return parser +# arg_prefix = field_name if arg_prefix == "" else f"{arg_prefix}.{field_name}" +# help_prefix = field_help if help_prefix == "" else f"{help_prefix} - {field_help}" +# if is_dict(field_type): # pylint: disable=no-else-raise +# # NOTE: accept any string in json format as input to dict field. +# parser.add_argument( +# f"--{arg_prefix}", +# dest=arg_prefix, +# default=json.dumps(field_default) if field_default else None, +# type=json.loads, +# ) +# elif is_list(field_type): +# # TODO: We need a more clear help msg for lists. +# if hasattr(field_type, "__args__"): # if the list is hinted +# if len(field_type.__args__) > 1 and not relaxed_parser: +# raise ValueError(" [!] Coqpit does not support multi-type hinted 'List'") +# list_field_type = field_type.__args__[0] +# else: +# raise ValueError(" [!] 
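# Note: the recursive attribute helpers (rsetattr/rgetattr/rsetitem/rgetitem) and the
# argparse helpers (_get_help/_init_argparse) are commented out in this file but are
# still referenced by init_from_argparse()/parse_args()/init_argparse() further down,
# so those code paths would fail with NameError as committed. A minimal dotted-name
# getattr/setattr in the same spirit (illustrative, not the original implementation):
import functools

def rgetattr(obj, attr):
    return functools.reduce(getattr, attr.split("."), obj)

def rsetattr(obj, attr, val):
    pre, _, post = attr.rpartition(".")
    setattr(rgetattr(obj, pre) if pre else obj, post, val)

class Node:
    pass

root = Node(); root.child = Node(); root.child.lr = 1e-3
rsetattr(root, "child.lr", 5e-5)
print(rgetattr(root, "child.lr"))    # 5e-05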
Coqpit does not support un-hinted 'List'") + +# # TODO: handle list of lists +# if is_list(list_field_type) and relaxed_parser: +# return parser + +# if not has_default or field_default_factory is list: +# if not is_primitive_type(list_field_type) and not relaxed_parser: +# raise NotImplementedError(" [!] Empty list with non primitive inner type is currently not supported.") + +# # If the list's default value is None, the user can specify the entire list by passing multiple parameters +# parser.add_argument( +# f"--{arg_prefix}", +# nargs="*", +# type=list_field_type, +# help=f"Coqpit Field: {help_prefix}", +# ) +# else: +# # If a default value is defined, just enable editing the values from argparse +# # TODO: allow inserting a new value/obj to the end of the list. +# for idx, fv in enumerate(default): +# parser = _init_argparse( +# parser, +# str(idx), +# list_field_type, +# fv, +# field_default_factory, +# field_help="", +# help_prefix=f"{help_prefix} - ", +# arg_prefix=f"{arg_prefix}", +# relaxed_parser=relaxed_parser, +# ) +# elif is_union(field_type): +# # TODO: currently I don't know how to handle Union type on argparse +# if not relaxed_parser: +# raise NotImplementedError( +# " [!] Parsing `Union` field from argparse is not yet implemented. Please create an issue." +# ) +# elif issubclass(field_type, Serializable): +# return default.init_argparse( +# parser, arg_prefix=arg_prefix, help_prefix=help_prefix, relaxed_parser=relaxed_parser +# ) +# elif isinstance(field_type(), bool): + +# def parse_bool(x): +# if x not in ("true", "false"): +# raise ValueError(f' [!] Value for boolean field must be either "true" or "false". Got "{x}".') +# return x == "true" + +# parser.add_argument( +# f"--{arg_prefix}", +# type=parse_bool, +# default=field_default, +# help=f"Coqpit Field: {help_prefix}", +# metavar="true/false", +# ) +# elif is_primitive_type(field_type): +# parser.add_argument( +# f"--{arg_prefix}", +# default=field_default, +# type=field_type, +# help=f"Coqpit Field: {help_prefix}", +# ) +# else: +# if not relaxed_parser: +# raise NotImplementedError(f" [!] '{field_type}' is not supported by arg_parser. Please file a bug report.") +# return parser + + +# ---------------------------------------------------------------------------- # +# Main Coqpit Class # +# ---------------------------------------------------------------------------- # + + +@dataclass +class Coqpit(Serializable, MutableMapping): + """Coqpit base class to be inherited by any Coqpit dataclasses. + It overrides Python `dict` interface and provides `dict` compatible API. + It also enables serializing/deserializing a dataclass to/from a json file, plus some semi-dynamic type and value check. + Note that it does not support all datatypes and likely to fail in some cases. + """ + + _initialized = False + + def _is_initialized(self): + """Check if Coqpit is initialized. 
Useful to prevent running some aux functions + at the initialization when no attribute has been defined.""" + return "_initialized" in vars(self) and self._initialized + + def __post_init__(self): + self._initialized = True + try: + self.check_values() + except AttributeError: + pass + + ## `dict` API functions + + def __iter__(self): + return iter(asdict(self)) + + def __len__(self): + return len(fields(self)) + + def __setitem__(self, arg: str, value: Any): + setattr(self, arg, value) + + def __getitem__(self, arg: str): + """Access class attributes with ``[arg]``.""" + return self.__dict__[arg] + + def __delitem__(self, arg: str): + delattr(self, arg) + + def _keytransform(self, key): # pylint: disable=no-self-use + return key + + ## end `dict` API functions + + def __getattribute__(self, arg: str): # pylint: disable=no-self-use + """Check if the mandatory field is defined when accessing it.""" + value = super().__getattribute__(arg) + if isinstance(value, str) and value == "???": + raise AttributeError(f" [!] MISSING field {arg} must be defined.") + return value + + def __contains__(self, arg: str): + return arg in self.to_dict() + + def get(self, key: str, default: Any = None): + if self.has(key): + return asdict(self)[key] + return default + + def items(self): + return asdict(self).items() + + def merge(self, coqpits: Union["Coqpit", List["Coqpit"]]): + """Merge a coqpit instance or a list of coqpit instances to self. + Note that it does not pass the fields and overrides attributes with + the last Coqpit instance in the given List. + TODO: find a way to merge instances with all the class internals. + + Args: + coqpits (Union[Coqpit, List[Coqpit]]): coqpit instance or list of instances to be merged. + """ + + def _merge(coqpit): + self.__dict__.update(coqpit.__dict__) + self.__annotations__.update(coqpit.__annotations__) + self.__dataclass_fields__.update(coqpit.__dataclass_fields__) + + if isinstance(coqpits, list): + for coqpit in coqpits: + _merge(coqpit) + else: + _merge(coqpits) + + def check_values(self): + pass + + def has(self, arg: str) -> bool: + return arg in vars(self) + + def copy(self): + return replace(self) + + def update(self, new: dict, allow_new=False) -> None: + """Update Coqpit fields by the input ```dict```. + + Args: + new (dict): dictionary with new values. + allow_new (bool, optional): allow new fields to add. Defaults to False. + """ + for key, value in new.items(): + if allow_new: + setattr(self, key, value) + else: + if hasattr(self, key): + setattr(self, key, value) + else: + raise KeyError(f" [!] No key - {key}") + + def pprint(self) -> None: + """Print Coqpit fields in a format.""" + pprint(asdict(self)) + + def to_dict(self) -> dict: + # return asdict(self) + return self.serialize() + + def from_dict(self, data: dict) -> None: + self = self.deserialize(data) # pylint: disable=self-cls-assignment + + @classmethod + def new_from_dict(cls: Serializable, data: dict) -> "Coqpit": + return cls.deserialize_immutable(data) + + def to_json(self) -> str: + """Returns a JSON string representation.""" + return json.dumps(asdict(self), indent=4, default=_coqpit_json_default) + + def save_json(self, file_name: str) -> None: + """Save Coqpit to a json file. + + Args: + file_name (str): path to the output json file. + """ + with open(file_name, "w", encoding="utf8") as f: + json.dump(asdict(self), f, indent=4) + + def load_json(self, file_name: str) -> None: + """Load a json file and update matching config fields with type checking. 
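# Note: to_json() above passes `default=_coqpit_json_default`, but that helper is
# commented out near the top of this file, so calling to_json() would raise NameError
# as committed. A hedged sketch of the default hook it appears to provide (stringify
# pathlib.Path values during json.dumps), mirroring the commented-out version:
import json
from pathlib import Path

def coqpit_json_default(obj):
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Can't encode object of type {type(obj).__name__}")

print(json.dumps({"stats_path": Path("stats.npy"), "num_mels": 80},
                 default=coqpit_json_default))
# {"stats_path": "stats.npy", "num_mels": 80}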
+ Non-matching parameters in the json file are ignored. + + Args: + file_name (str): path to the json file. + + Returns: + Coqpit: new Coqpit with updated config fields. + """ + with open(file_name, "r", encoding="utf8") as f: + input_str = f.read() + dump_dict = json.loads(input_str) + # TODO: this looks stupid 💆 + self = self.deserialize(dump_dict) # pylint: disable=self-cls-assignment + self.check_values() + + @classmethod + def init_from_argparse( + cls, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit" + ) -> "Coqpit": + """Create a new Coqpit instance from argparse input. + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = cls.init_argparse(cls, arg_prefix=arg_prefix) + args = parser.parse_args() # pylint: disable=E1120, E1111 + if isinstance(args, list): + # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = cls.init_argparse(cls, arg_prefix=arg_prefix) + args = parser.parse_args(args) # pylint: disable=E1120, E1111 + + # Handle list and object attributes with defaults, which can be modified + # directly (eg. --coqpit.list.0.val_a 1), by constructing real objects + # from defaults and passing those to `cls.__init__` + args_with_lists_processed = {} + class_fields = fields(cls) + for field in class_fields: + has_default = False + default = None + field_default = field.default if field.default is not _MISSING else None + field_default_factory = field.default_factory if field.default_factory is not _MISSING else None + if field_default: + has_default = True + default = field_default + elif field_default_factory: + has_default = True + default = field_default_factory() + + if has_default and (not is_primitive_type(field.type) or is_list(field.type)): + args_with_lists_processed[field.name] = default + + args_dict = vars(args) + for k, v in args_dict.items(): + # Remove argparse prefix (eg. "--coqpit." if present) + if k.startswith(f"{arg_prefix}."): + k = k[len(f"{arg_prefix}.") :] + + rsetitem(args_with_lists_processed, k, v) + + return cls(**args_with_lists_processed) + + def parse_args( + self, args: Optional[Union[argparse.Namespace, List[str]]] = None, arg_prefix: str = "coqpit" + ) -> None: + """Update config values from argparse arguments with some meta-programming ✨. + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = self.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args() + if isinstance(args, list): + # If a list was passed in (eg. 
the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = self.init_argparse(arg_prefix=arg_prefix) + args = parser.parse_args(args) + + args_dict = vars(args) + + for k, v in args_dict.items(): + if k.startswith(f"{arg_prefix}."): + k = k[len(f"{arg_prefix}.") :] + try: + rgetattr(self, k) + except (TypeError, AttributeError) as e: + raise Exception(f" [!] '{k}' not exist to override from argparse.") from e + + rsetattr(self, k, v) + + self.check_values() + + def parse_known_args( + self, + args: Optional[Union[argparse.Namespace, List[str]]] = None, + arg_prefix: str = "coqpit", + relaxed_parser=False, + ) -> List[str]: + """Update config values from argparse arguments. Ignore unknown arguments. + This is analog to argparse.ArgumentParser.parse_known_args (vs parse_args). + + Args: + args (namespace or list of str, optional): parsed argparse.Namespace or list of command line parameters. If unspecified will use a newly created parser with ```init_argparse()```. + arg_prefix: prefix to add to CLI parameters. Gets forwarded to ```init_argparse``` when ```args``` is not passed. + relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + + Returns: + List of unknown parameters. + """ + if not args: + # If args was not specified, parse from sys.argv + parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + args, unknown = parser.parse_known_args() + if isinstance(args, list): + # If a list was passed in (eg. the second result of `parse_known_args`, run that through argparse first to get a parsed Namespace + parser = self.init_argparse(arg_prefix=arg_prefix, relaxed_parser=relaxed_parser) + args, unknown = parser.parse_known_args(args) + + self.parse_args(args) + return unknown + + def init_argparse( + self, + parser: Optional[argparse.ArgumentParser] = None, + arg_prefix="coqpit", + help_prefix="", + relaxed_parser=False, + ) -> argparse.ArgumentParser: + """Pass Coqpit fields as argparse arguments. This allows to edit values through command-line. + + Args: + parser (argparse.ArgumentParser, optional): argparse.ArgumentParser instance. If unspecified a new one will be created. + arg_prefix (str, optional): Prefix to be used for the argument name. Defaults to 'coqpit'. + help_prefix (str, optional): Prefix to be used for the argument description. Defaults to ''. + relaxed_parser (bool, optional): If True, do not force all the fields to have compatible types with the argparser. Defaults to False. + + Returns: + argparse.ArgumentParser: parser instance with the new arguments. 
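+
+        Example (illustrative sketch, not from the original docs; assumes a hypothetical
+        ```MyConfig``` dataclass derived from ```Coqpit``` with a ```batch_size: int``` field):
+            >>> config = MyConfig()
+            >>> parser = config.init_argparse(arg_prefix="coqpit")
+            >>> args = parser.parse_args(["--coqpit.batch_size", "16"])
+            >>> config.parse_args(args)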
+ """ + if not parser: + parser = argparse.ArgumentParser() + class_fields = fields(self) + for field in class_fields: + if field.name in vars(self): + # use the current value of the field + # prevent dropping the current value + field_default = vars(self)[field.name] + else: + # use the default value of the field + field_default = field.default if field.default is not _MISSING else None + field_type = field.type + field_default_factory = field.default_factory + field_help = _get_help(field) + _init_argparse( + parser, + field.name, + field_type, + field_default, + field_default_factory, + field_help, + arg_prefix, + help_prefix, + relaxed_parser, + ) + return parser + + +def check_argument( + name, + c, + is_path: bool = False, + prerequest: str = None, + enum_list: list = None, + max_val: float = None, + min_val: float = None, + restricted: bool = False, + alternative: str = None, + allow_none: bool = True, +) -> None: + """Simple type and value checking for Coqpit. + It is intended to be used under ```__post_init__()``` of config dataclasses. + + Args: + name (str): name of the field to be checked. + c (dict): config dictionary. + is_path (bool, optional): if ```True``` check if the path is exist. Defaults to False. + prerequest (list or str, optional): a list of field name that are prerequestedby the target field name. + Defaults to ```[]```. + enum_list (list, optional): list of possible values for the target field. Defaults to None. + max_val (float, optional): maximum possible value for the target field. Defaults to None. + min_val (float, optional): minimum possible value for the target field. Defaults to None. + restricted (bool, optional): if ```True``` the target field has to be defined. Defaults to False. + alternative (str, optional): a field name superceding the target field. Defaults to None. + allow_none (bool, optional): if ```True``` allow the target field to be ```None```. Defaults to False. + + + Example: + >>> num_mels = 5 + >>> check_argument('num_mels', c, restricted=True, min_val=10, max_val=2056) + >>> fft_size = 128 + >>> check_argument('fft_size', c, restricted=True, min_val=128, max_val=4058) + """ + # check if None allowed + if allow_none and c[name] is None: + return + if not allow_none: + assert c[name] is not None, f" [!] None value is not allowed for {name}." + # check if restricted and it it is check if it exists + if isinstance(restricted, bool) and restricted: + assert name in c.keys(), f" [!] {name} not defined in config.json" + # check prerequest fields are defined + if isinstance(prerequest, list): + assert any( + f not in c.keys() for f in prerequest + ), f" [!] prequested fields {prerequest} for {name} are not defined." + else: + assert ( + prerequest is None or prerequest in c.keys() + ), f" [!] prequested fields {prerequest} for {name} are not defined." + # check if the path exists + if is_path: + assert os.path.exists(c[name]), f' [!] path for {name} ("{c[name]}") does not exist.' + # skip the rest if the alternative field is defined. + if alternative in c.keys() and c[alternative] is not None: + return + # check value constraints + if name in c.keys(): + if max_val is not None: + assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" + if min_val is not None: + assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" + if enum_list is not None: + assert c[name].lower() in enum_list, f" [!] 
{name} is not a valid value" diff --git a/paddlemix/models/vits-svc/speaker/utils/io.py b/paddlemix/models/vits-svc/speaker/utils/io.py new file mode 100644 index 000000000..0297db10a --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/utils/io.py @@ -0,0 +1,199 @@ +import fsspec +import datetime +import json +import os +import pickle as pickle_tts +import shutil +from typing import Any, Callable, Dict, Union + + +import torch +from .coqpit import Coqpit + + +class RenamingUnpickler(pickle_tts.Unpickler): + """Overload default pickler to solve module renaming problem""" + + def find_class(self, module, name): + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) + + +class AttrDict(dict): + """A custom dict which converts dict keys + to class attributes""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__dict__ = self + + +def copy_model_files(config: Coqpit, out_path, new_fields): + """Copy config.json and other model files to the training folder and add + new fields. + + Args: + config (Coqpit): Coqpit config defining the training run. + out_path (str): output path to copy the file. + new_fields (dict): new fields to be added or edited + in the config file. + """ + copy_config_path = os.path.join(out_path, "config.json") + # add extra information fields + config.update(new_fields, allow_new=True) + # TODO: Revert to config.save_json() once Coqpit supports arbitrary paths. + with fsspec.open(copy_config_path, "w", encoding="utf8") as f: + json.dump(config.to_dict(), f, indent=4) + + # copy model stats file if available + if config.audio.stats_path is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") + filesystem = fsspec.get_mapper(copy_stats_path).fs + if not filesystem.exists(copy_stats_path): + with fsspec.open(config.audio.stats_path, "rb") as source_file: + with fsspec.open(copy_stats_path, "wb") as target_file: + shutil.copyfileobj(source_file, target_file) + + +def load_fsspec_torch( + path: str, + map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None, + **kwargs, +) -> Any: + """Like torch.load but can load from other locations (e.g. s3:// , gs://). + + Args: + path: Any path or url supported by fsspec. + map_location: torch.device or str. + **kwargs: Keyword arguments forwarded to torch.load. + + Returns: + Object stored in path. + """ + with fsspec.open(path, "rb") as f: + return torch.load(f, map_location=map_location, **kwargs) + + +def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin + try: + state = load_fsspec_torch(checkpoint_path, map_location=torch.device("cpu")) + except ModuleNotFoundError: + pickle_tts.Unpickler = RenamingUnpickler + state = load_fsspec_torch(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if use_cuda: + model.cuda() + if eval: + model.eval() + return model, state + + +def save_fsspec(state: Any, path: str, **kwargs): + """Like torch.save but can save to other locations (e.g. s3:// , gs://). + + Args: + state: State object to save + path: Any path or url supported by fsspec. + **kwargs: Keyword arguments forwarded to torch.save.
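+
+    Example (illustrative sketch; the target URL below is a placeholder, any local or fsspec-supported path works):
+        >>> state = {"step": 1000}
+        >>> save_fsspec(state, "s3://my-bucket/checkpoints/checkpoint_1000.pth")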
+ """ + with fsspec.open(path, "wb") as f: + torch.save(state, f, **kwargs) + + +def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs): + if hasattr(model, "module"): + model_state = model.module.state_dict() + else: + model_state = model.state_dict() + if isinstance(optimizer, list): + optimizer_state = [optim.state_dict() for optim in optimizer] + else: + optimizer_state = optimizer.state_dict() if optimizer is not None else None + + if isinstance(scaler, list): + scaler_state = [s.state_dict() for s in scaler] + else: + scaler_state = scaler.state_dict() if scaler is not None else None + + if isinstance(config, Coqpit): + config = config.to_dict() + + state = { + "config": config, + "model": model_state, + "optimizer": optimizer_state, + "scaler": scaler_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + } + state.update(kwargs) + save_fsspec(state, output_path) + + +def save_checkpoint( + config, + model, + optimizer, + scaler, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth.tar".format(current_step) + checkpoint_path = os.path.join(output_folder, file_name) + print("\n > CHECKPOINT : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) + + +def save_best_model( + current_loss, + best_loss, + config, + model, + optimizer, + scaler, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): + if current_loss < best_loss: + best_model_name = f"best_model_{current_step}.pth.tar" + checkpoint_path = os.path.join(out_path, best_model_name) + print(" > BEST MODEL : {}".format(checkpoint_path)) + save_model( + config, + model, + optimizer, + scaler, + current_step, + epoch, + checkpoint_path, + model_loss=current_loss, + **kwargs, + ) + fs = fsspec.get_mapper(out_path).fs + # only delete previous if current is saved successfully + if not keep_all_best or (current_step < keep_after): + model_names = fs.glob(os.path.join(out_path, "best_model*.pth.tar")) + for model_name in model_names: + if os.path.basename(model_name) != best_model_name: + fs.rm(model_name) + # create a shortcut which always points to the currently best model + shortcut_name = "best_model.pth.tar" + shortcut_path = os.path.join(out_path, shortcut_name) + fs.copy(checkpoint_path, shortcut_path) + best_loss = current_loss + return best_loss diff --git a/paddlemix/models/vits-svc/speaker/utils/shared_configs.py b/paddlemix/models/vits-svc/speaker/utils/shared_configs.py new file mode 100644 index 000000000..a89d3a91c --- /dev/null +++ b/paddlemix/models/vits-svc/speaker/utils/shared_configs.py @@ -0,0 +1,342 @@ +from dataclasses import asdict, dataclass +from typing import List + +from .coqpit import Coqpit, check_argument + + +@dataclass +class BaseAudioConfig(Coqpit): + """Base config to definge audio processing parameters. It is used to initialize + ```TTS.utils.audio.AudioProcessor.``` + + Args: + fft_size (int): + Number of STFT frequency levels aka.size of the linear spectogram frame. Defaults to 1024. + + win_length (int): + Each frame of audio is windowed by window of length ```win_length``` and then padded with zeros to match + ```fft_size```. Defaults to 1024. + + hop_length (int): + Number of audio samples between adjacent STFT columns. Defaults to 1024. + + frame_shift_ms (int): + Set ```hop_length``` based on milliseconds and sampling rate. 
+ + frame_length_ms (int): + Set ```win_length``` based on milliseconds and sampling rate. + + stft_pad_mode (str): + Padding method used in STFT. 'reflect' or 'center'. Defaults to 'reflect'. + + sample_rate (int): + Audio sampling rate. Defaults to 22050. + + resample (bool): + Enable / Disable resampling audio to ```sample_rate```. Defaults to ```False```. + + preemphasis (float): + Preemphasis coefficient. Defaults to 0.0. + + ref_level_db (int): 20 + Reference Db level to rebase the audio signal and ignore the level below. 20Db is assumed the sound of air. + Defaults to 20. + + do_sound_norm (bool): + Enable / Disable sound normalization to reconcile the volume differences among samples. Defaults to False. + + log_func (str): + Numpy log function used for amplitude to DB conversion. Defaults to 'np.log10'. + + do_trim_silence (bool): + Enable / Disable trimming silences at the beginning and the end of the audio clip. Defaults to ```True```. + + do_amp_to_db_linear (bool, optional): + enable/disable amplitude to dB conversion of linear spectrograms. Defaults to True. + + do_amp_to_db_mel (bool, optional): + enable/disable amplitude to dB conversion of mel spectrograms. Defaults to True. + + trim_db (int): + Silence threshold used for silence trimming. Defaults to 45. + + power (float): + Exponent used for expanding spectrogra levels before running Griffin Lim. It helps to reduce the + artifacts in the synthesized voice. Defaults to 1.5. + + griffin_lim_iters (int): + Number of Griffing Lim iterations. Defaults to 60. + + num_mels (int): + Number of mel-basis frames that defines the frame lengths of each mel-spectrogram frame. Defaults to 80. + + mel_fmin (float): Min frequency level used for the mel-basis filters. ~50 for male and ~95 for female voices. + It needs to be adjusted for a dataset. Defaults to 0. + + mel_fmax (float): + Max frequency level used for the mel-basis filters. It needs to be adjusted for a dataset. + + spec_gain (int): + Gain applied when converting amplitude to DB. Defaults to 20. + + signal_norm (bool): + enable/disable signal normalization. Defaults to True. + + min_level_db (int): + minimum db threshold for the computed melspectrograms. Defaults to -100. + + symmetric_norm (bool): + enable/disable symmetric normalization. If set True normalization is performed in the range [-k, k] else + [0, k], Defaults to True. + + max_norm (float): + ```k``` defining the normalization range. Defaults to 4.0. + + clip_norm (bool): + enable/disable clipping the our of range values in the normalized audio signal. Defaults to True. + + stats_path (str): + Path to the computed stats file. Defaults to None. 
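+
+    Example (illustrative sketch; the overridden values are placeholders, not recommended settings):
+        >>> audio_config = BaseAudioConfig(sample_rate=16000, num_mels=80, hop_length=320)
+        >>> audio_config.check_values()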
+ """ + + # stft parameters + fft_size: int = 1024 + win_length: int = 1024 + hop_length: int = 256 + frame_shift_ms: int = None + frame_length_ms: int = None + stft_pad_mode: str = "reflect" + # audio processing parameters + sample_rate: int = 22050 + resample: bool = False + preemphasis: float = 0.0 + ref_level_db: int = 20 + do_sound_norm: bool = False + log_func: str = "np.log10" + # silence trimming + do_trim_silence: bool = True + trim_db: int = 45 + # griffin-lim params + power: float = 1.5 + griffin_lim_iters: int = 60 + # mel-spec params + num_mels: int = 80 + mel_fmin: float = 0.0 + mel_fmax: float = None + spec_gain: int = 20 + do_amp_to_db_linear: bool = True + do_amp_to_db_mel: bool = True + # normalization params + signal_norm: bool = True + min_level_db: int = -100 + symmetric_norm: bool = True + max_norm: float = 4.0 + clip_norm: bool = True + stats_path: str = None + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("num_mels", c, restricted=True, min_val=10, max_val=2056) + check_argument("fft_size", c, restricted=True, min_val=128, max_val=4058) + check_argument("sample_rate", c, restricted=True, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c, + restricted=True, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument("frame_shift_ms", c, restricted=True, min_val=1, max_val=1000, alternative="hop_length") + check_argument("preemphasis", c, restricted=True, min_val=0, max_val=1) + check_argument("min_level_db", c, restricted=True, min_val=-1000, max_val=10) + check_argument("ref_level_db", c, restricted=True, min_val=0, max_val=1000) + check_argument("power", c, restricted=True, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c, restricted=True, min_val=10, max_val=1000) + + # normalization parameters + check_argument("signal_norm", c, restricted=True) + check_argument("symmetric_norm", c, restricted=True) + check_argument("max_norm", c, restricted=True, min_val=0.1, max_val=1000) + check_argument("clip_norm", c, restricted=True) + check_argument("mel_fmin", c, restricted=True, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c, restricted=True, min_val=500.0, allow_none=True) + check_argument("spec_gain", c, restricted=True, min_val=1, max_val=100) + check_argument("do_trim_silence", c, restricted=True) + check_argument("trim_db", c, restricted=True) + + +@dataclass +class BaseDatasetConfig(Coqpit): + """Base config for TTS datasets. + + Args: + name (str): + Dataset name that defines the preprocessor in use. Defaults to None. + + path (str): + Root path to the dataset files. Defaults to None. + + meta_file_train (str): + Name of the dataset meta file. Or a list of speakers to be ignored at training for multi-speaker datasets. + Defaults to None. + + unused_speakers (List): + List of speakers IDs that are not used at the training. Default None. + + meta_file_val (str): + Name of the dataset meta file that defines the instances used at validation. + + meta_file_attn_mask (str): + Path to the file that lists the attention mask files used with models that require attention masks to + train the duration predictor. 
+ """ + + name: str = "" + path: str = "" + meta_file_train: str = "" + ununsed_speakers: List[str] = None + meta_file_val: str = "" + meta_file_attn_mask: str = "" + + def check_values( + self, + ): + """Check config fields""" + c = asdict(self) + check_argument("name", c, restricted=True) + check_argument("path", c, restricted=True) + check_argument("meta_file_train", c, restricted=True) + check_argument("meta_file_val", c, restricted=False) + check_argument("meta_file_attn_mask", c, restricted=False) + + +@dataclass +class BaseTrainingConfig(Coqpit): + """Base config to define the basic training parameters that are shared + among all the models. + + Args: + model (str): + Name of the model that is used in the training. + + run_name (str): + Name of the experiment. This prefixes the output folder name. Defaults to `coqui_tts`. + + run_description (str): + Short description of the experiment. + + epochs (int): + Number training epochs. Defaults to 10000. + + batch_size (int): + Training batch size. + + eval_batch_size (int): + Validation batch size. + + mixed_precision (bool): + Enable / Disable mixed precision training. It reduces the VRAM use and allows larger batch sizes, however + it may also cause numerical unstability in some cases. + + scheduler_after_epoch (bool): + If true, run the scheduler step after each epoch else run it after each model step. + + run_eval (bool): + Enable / Disable evaluation (validation) run. Defaults to True. + + test_delay_epochs (int): + Number of epochs before starting to use evaluation runs. Initially, models do not generate meaningful + results, hence waiting for a couple of epochs might save some time. + + print_eval (bool): + Enable / Disable console logging for evalutaion steps. If disabled then it only shows the final values at + the end of the evaluation. Default to ```False```. + + print_step (int): + Number of steps required to print the next training log. + + log_dashboard (str): "tensorboard" or "wandb" + Set the experiment tracking tool + + plot_step (int): + Number of steps required to log training on Tensorboard. + + model_param_stats (bool): + Enable / Disable logging internal model stats for model diagnostic. It might be useful for model debugging. + Defaults to ```False```. + + project_name (str): + Name of the project. Defaults to config.model + + wandb_entity (str): + Name of W&B entity/team. Enables collaboration across a team or org. + + log_model_step (int): + Number of steps required to log a checkpoint as W&B artifact + + save_step (int):ipt + Number of steps required to save the next checkpoint. + + checkpoint (bool): + Enable / Disable checkpointing. + + keep_all_best (bool): + Enable / Disable keeping all the saved best models instead of overwriting the previous one. Defaults + to ```False```. + + keep_after (int): + Number of steps to wait before saving all the best models. In use if ```keep_all_best == True```. Defaults + to 10000. + + num_loader_workers (int): + Number of workers for training time dataloader. + + num_eval_loader_workers (int): + Number of workers for evaluation time dataloader. + + output_path (str): + Path for training output folder, either a local file path or other + URLs supported by both fsspec and tensorboardX, e.g. GCS (gs://) or + S3 (s3://) paths. The nonexist part of the given path is created + automatically. All training artefacts are saved there. 
+ """ + + model: str = None + run_name: str = "coqui_tts" + run_description: str = "" + # training params + epochs: int = 10000 + batch_size: int = None + eval_batch_size: int = None + mixed_precision: bool = False + scheduler_after_epoch: bool = False + # eval params + run_eval: bool = True + test_delay_epochs: int = 0 + print_eval: bool = False + # logging + dashboard_logger: str = "tensorboard" + print_step: int = 25 + plot_step: int = 100 + model_param_stats: bool = False + project_name: str = None + log_model_step: int = None + wandb_entity: str = None + # checkpointing + save_step: int = 10000 + checkpoint: bool = True + keep_all_best: bool = False + keep_after: int = 10000 + # dataloading + num_loader_workers: int = 0 + num_eval_loader_workers: int = 0 + use_noise_augment: bool = False + # paths + output_path: str = None + # distributed + distributed_backend: str = "nccl" + distributed_url: str = "tcp://localhost:54321" diff --git a/paddlemix/models/vits-svc/vits/__init__.py b/paddlemix/models/vits-svc/vits/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/vits/spectrogram.py b/paddlemix/models/vits-svc/vits/spectrogram.py new file mode 100644 index 000000000..7fe6f4769 --- /dev/null +++ b/paddlemix/models/vits-svc/vits/spectrogram.py @@ -0,0 +1,299 @@ +import torch +# import torch.utils.data +import paddle +import math +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +# def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): +# """ +# PARAMS +# ------ +# C: compression factor +# """ +# return torch.log(torch.clamp(x, min=clip_val) * C) + + +# def dynamic_range_decompression_torch(x, C=1): +# """ +# PARAMS +# ------ +# C: compression factor used to compress +# """ +# return torch.exp(x) / C + + +# def spectral_normalize_torch(magnitudes): +# output = dynamic_range_compression_torch(magnitudes) +# return output + + +# def spectral_de_normalize_torch(magnitudes): +# output = dynamic_range_decompression_torch(magnitudes) +# return output + + +mel_basis = {} +hann_window = {} + + + + +def custom_hann_window_torch(window_length, periodic=True, dtype=None, device=None): + if dtype is None: + dtype = torch.float32 + + if device is None: + device = torch.device('cpu') + + if periodic: + window_length += 1 + + n = torch.arange(window_length, dtype=dtype, device=device) + window = 0.5 - 0.5 * torch.cos(2 * math.pi * n / (window_length - 1)) + + if periodic: + window = window[:-1] + + return window + + +def custom_hann_window_paddle(window_length, periodic=True, dtype=None,): + if dtype is None: + dtype = 'float32' + if periodic: + window_length += 1 + n = paddle.arange(dtype=dtype, end=window_length) + window = 0.5 - 0.5 * paddle.cos(x=2 * math.pi * n / (window_length - 1)) + if periodic: + window = window[:-1] + return window + + +# if __name__ == "__main__": + +# win_size = 100 +# y_custom = custom_hann_window_torch(win_size) +# y = torch.hann_window(win_size) +# print(abs((y - y_custom).numpy()).max()) + + +# y_custom = custom_hann_window_torch(win_size, periodic=False) +# y = torch.hann_window(win_size, periodic=False) +# print(abs((y - y_custom).numpy()).max()) + + +# win_size = 997 +# y_paddle = custom_hann_window_paddle(win_size).numpy() +# y = torch.hann_window(win_size).numpy() +# print( abs(y - y_paddle).max() ) + + + + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + 
print("max value is ", torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + + + + +def spectrogram_paddle(y, n_fft, sampling_rate, hop_size, win_size, center=False): + if paddle.min(x=y) < -1.0: + print('min value is ', paddle.min(x=y)) + if paddle.max(x=y) > 1.0: + print('max value is ', paddle.max(x=y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.place) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = custom_hann_window_paddle(win_size, dtype=y.dtype) + + y = paddle.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + data_format="NCL" + ) + y = y.squeeze(1) + + spec = paddle.signal.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode='reflect', + normalized=False, + onesided=True) + + # spec = spec.real() + spec = paddle.stack( [spec.real(), spec.imag()], axis=-1 ) + spec = paddle.sqrt(x=spec.pow(y=2).sum(axis=-1) + 1e-06) + + return spec + + + + +if __name__ == "__main__": + + + + import numpy as np + + y = np.random.normal(loc=0.1406, scale=0.2395, size=[1, 1112384]).astype("float32") + y_pd = paddle.to_tensor(y) + y_tc = torch.from_numpy(y).cpu() + + hop_size = 320 + n_fft = 1024 + hop_length = 320 + win_length = 1024 + + window_pd = custom_hann_window_paddle(win_length) + window_tc = torch.hann_window(win_length) + + center = False + pad_mode = "reflect" + normalized = False + onesided = True + return_complex = False + + spec_tc = torch.stft( + y_tc, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window_tc, + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=False, + ).cpu().numpy() + + + spec_pd = paddle.signal.stft( + y_pd, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=window_pd, + center=center, + pad_mode='reflect', + normalized=False, + onesided=True) + # spec_pd = paddle.stack( [spec_pd.real(), spec_pd.imag()], axis=-1 ).cpu().numpy() + spec_pd = paddle.as_real(spec_pd).cpu().numpy() + + print( + abs(spec_tc - spec_pd).max().item() + ) + # print(spec_tc.mean().item(), spec_pd.mean().item()) + + # print( + # abs( spec_tc.mean().numpy() - spec_pd.mean().numpy() ) + # ) + + +# def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): +# global mel_basis +# dtype_device = str(spec.dtype) + "_" + str(spec.device) +# fmax_dtype_device = str(fmax) + "_" + dtype_device +# if fmax_dtype_device not in mel_basis: +# mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) +# mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( +# dtype=spec.dtype, device=spec.device +# ) +# spec = 
torch.matmul(mel_basis[fmax_dtype_device], spec) +# spec = spectral_normalize_torch(spec) +# return spec + + +# def mel_spectrogram_torch( +# y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +# ): +# if torch.min(y) < -1.0: +# print("min value is ", torch.min(y)) +# if torch.max(y) > 1.0: +# print("max value is ", torch.max(y)) + +# global mel_basis, hann_window +# dtype_device = str(y.dtype) + "_" + str(y.device) +# fmax_dtype_device = str(fmax) + "_" + dtype_device +# wnsize_dtype_device = str(win_size) + "_" + dtype_device +# if fmax_dtype_device not in mel_basis: +# mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) +# mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( +# dtype=y.dtype, device=y.device +# ) +# if wnsize_dtype_device not in hann_window: +# hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( +# dtype=y.dtype, device=y.device +# ) + +# y = torch.nn.functional.pad( +# y.unsqueeze(1), +# (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), +# mode="reflect", +# ) +# y = y.squeeze(1) + +# spec = torch.stft( +# y, +# n_fft, +# hop_length=hop_size, +# win_length=win_size, +# window=hann_window[wnsize_dtype_device], +# center=center, +# pad_mode="reflect", +# normalized=False, +# onesided=True, +# return_complex=False, +# ) + +# spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + +# spec = torch.matmul(mel_basis[fmax_dtype_device], spec) +# spec = spectral_normalize_torch(spec) + +# return spec diff --git a/paddlemix/models/vits-svc/vits/utils.py b/paddlemix/models/vits-svc/vits/utils.py new file mode 100644 index 000000000..d276a964c --- /dev/null +++ b/paddlemix/models/vits-svc/vits/utils.py @@ -0,0 +1,37 @@ +import torch +import paddle +import numpy as np +from scipy.io.wavfile import read + +MATPLOTLIB_FLAG = False + + +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate + +def load_wav_to_paddle(full_path): + sampling_rate, data = read(full_path) + return paddle.to_tensor(data.astype(np.float32)), sampling_rate + +f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + +def f0_to_coarse(f0): + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * \ np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * \ (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_coarse = ( + f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64) + assert f0_coarse.max() <= 255 and f0_coarse.min( + ) >= 1, (f0_coarse.max(), f0_coarse.min()) + return f0_coarse diff --git a/paddlemix/models/vits-svc/whisper/__init__.py b/paddlemix/models/vits-svc/whisper/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/paddlemix/models/vits-svc/whisper/audio.py b/paddlemix/models/vits-svc/whisper/audio.py new file mode 100644 index 000000000..79202917c --- /dev/null +++ b/paddlemix/models/vits-svc/whisper/audio.py @@ -0,0 +1,186 @@ +import os +import math +from functools import lru_cache +from typing import Union + +import librosa +import numpy as np +import paddle +import torch +import torch.nn.functional as F + +from .utils import exact_div +# def exact_div(x, y): +# assert x % y == 0 +# return x // y + +from librosa.filters import mel as librosa_mel_fn + +# hard-coded
audio hyperparameters +SAMPLE_RATE = 16000 +N_FFT = 400 +N_MELS = 80 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000: number of samples in a chunk +N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000: number of frames in a mel spectrogram input + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + x, sr = librosa.load(file, sr=sr) + return x + + +def pad_or_trim(array, length_max: int = N_SAMPLES, length_min: int = N_SAMPLES // 2, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length_max: + array = array.index_select(dim=axis, index=torch.arange(length_max, device=array.device)) + + if array.shape[axis] < length_min: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length_min - array.shape[axis]) + array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length_max: + array = array.take(indices=range(length_max), axis=axis) + + if array.shape[axis] < length_min: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length_min - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, n_mels: int = N_MELS) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + return torch.from_numpy(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)).to(device) + + + +@lru_cache(maxsize=None) +def mel_filters_paddle(device, n_mels: int = N_MELS) -> paddle.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels == 80, f"Unsupported n_mels: {n_mels}" + return paddle.to_tensor(librosa_mel_fn(sr=SAMPLE_RATE,n_fft=N_FFT,n_mels=n_mels)) + + + +def custom_hann_window_paddle(window_length, periodic=True, dtype=None,): + if dtype is None: + dtype = 'float32' + if periodic: + window_length += 1 + n = paddle.arange(dtype=dtype, end=window_length) + window = 0.5 - 0.5 * paddle.cos(x=2 * math.pi * n / (window_length - 1)) + if periodic: + window = window[:-1] + return window + + +def log_mel_spectrogram_torch(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = torch.from_numpy(audio) + + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters(audio.device, n_mels) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec + + +def log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 is supported + + Returns + ------- + torch.Tensor, shape = (80, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not paddle.is_tensor(audio): + if isinstance(audio, str): + audio = load_audio(audio) + audio = paddle.to_tensor(audio) + + window = custom_hann_window_paddle(N_FFT) + stft = paddle.signal.stft(audio, + N_FFT, + HOP_LENGTH, + window=window + ) + magnitudes = stft[..., :-1].abs() ** 2 + + filters = mel_filters_paddle(None, n_mels) + mel_spec = filters @ magnitudes + + log_spec = paddle.clip(x=mel_spec, min=1e-10).log10() + log_spec = paddle.maximum(x=log_spec, y=log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + + return log_spec + + +if __name__ == "__main__": + + x = np.random.rand(480000).astype("float32") + + y_pd = log_mel_spectrogram(x).detach().cpu().numpy() + y_tc = log_mel_spectrogram_torch(x).detach().cpu().numpy() + + print( + abs(y_pd - y_tc).max() + ) \ No newline at end of file diff --git a/paddlemix/models/vits-svc/whisper/model.py b/paddlemix/models/vits-svc/whisper/model.py new file mode 100644 index 000000000..4e17bc96a --- /dev/null +++ b/paddlemix/models/vits-svc/whisper/model.py @@ -0,0 +1,844 @@ +from dataclasses import dataclass +from typing import Dict +from typing import Iterable, Optional + +import numpy as np +import paddle +import torch +# import torch.nn.functional as F +# from torch import torch.Tensor +# from torch import nn + +# from .decoding import 
detect_language as detect_language_function, decode as decode_function + + +def init_weights(model): + for param in model.parameters(): + # print(param.shape) + torch.nn.init.uniform_(param) + +def LayerNorm_torch2paddle(model_torch, model_paddle): + + model_paddle.weight.set_value( + model_torch.weight.data.cpu().numpy() + ) + model_paddle.bias.set_value( + model_torch.bias.data.cpu().numpy() + ) + +# if __name__ == "__main__": + +# model_tc = torch.nn.LayerNorm(256).cuda() +# init_weights(model_tc) +# model_pd = paddle.nn.LayerNorm(256) + +# LayerNorm_torch2paddle(model_tc, model_pd) + +# x = np.random.randn(1, 1500, 256).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# y_tc = model_tc(x_tc) +# y_pd = model_pd(x_pd) + +# y_tc = y_tc.detach().cpu().numpy() +# y_pd = y_pd.detach().cpu().numpy() + +# print( +# abs(y_tc - y_pd).max(), +# ) + + + + +@dataclass +class ModelDimensions: + n_mels: int + n_audio_ctx: int + n_audio_state: int + n_audio_head: int + n_audio_layer: int + n_vocab: int + n_text_ctx: int + n_text_state: int + n_text_head: int + n_text_layer: int + + +# class LayerNorm(torch.nn.LayerNorm): +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# # return super().forward(x.float()).type(x.dtype) sovits5.0 +# return super().forward(x).type(x.dtype) + + +# class Linear(torch.nn.Linear): +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# return F.linear( +# x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) +# ) + + +# class Conv1d(torch.nn.Conv1d): +# def _conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: +# return super()._conv_forward( +# x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) +# ) + + +def sinusoids_torch(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) + scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) + + +def sinusoids(length, channels, max_timescale=10000): + """Returns sinusoids for positional embedding""" + assert channels % 2 == 0 + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = paddle.exp(-log_timescale_increment * paddle.arange(channels // 2)) + scaled_time = paddle.arange(length)[:, np.newaxis].astype("float32") * inv_timescales[np.newaxis, :] + return paddle.concat([paddle.sin(scaled_time), paddle.cos(scaled_time)], axis=1) + + +# if __name__ == "__main__": + +# y_tc = sinusoids_torch(1500, 1280) +# y_pd = sinusoids(1500, 1280) + +# y_delta = y_pd.cpu().numpy() - y_tc.cpu().numpy() + +# print( +# abs(y_delta).max() +# ) + + + + + +class MultiHeadAttention_torch(torch.nn.Module): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = torch.nn.Linear(n_state, n_state) + self.key = torch.nn.Linear(n_state, n_state, bias=False) + self.value = torch.nn.Linear(n_state, n_state) + self.out = torch.nn.Linear(n_state, n_state) + + def forward( + self, + x: torch.Tensor, + xa: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if 
installed (i.e. kv_cache is not None), will prepend the cached kv torch.Tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. + k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale + k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale + v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.float() + + w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype) + return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() + + +class MultiHeadAttention(paddle.nn.Layer): + def __init__(self, n_state: int, n_head: int): + super().__init__() + self.n_head = n_head + self.query = paddle.nn.Linear(n_state, n_state) + self.key = paddle.nn.Linear(n_state, n_state, bias_attr=False) + self.value = paddle.nn.Linear(n_state, n_state) + self.out = paddle.nn.Linear(n_state, n_state) + + def forward( + self, + x: paddle.Tensor, + xa: Optional[paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + kv_cache: Optional[dict] = None, + ): + q = self.query(x) + + if kv_cache is None or xa is None or self.key not in kv_cache: + # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv torch.Tensors; + # otherwise, perform key/value projections for self- or cross-attention as usual. + k = self.key(x if xa is None else xa) + v = self.value(x if xa is None else xa) + else: + # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
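+            # The cache maps the key/value projection sub-layers themselves (self.key / self.value)
+            # to their cached outputs, so the lookups below reuse tensors computed on an earlier call.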
+ k = kv_cache[self.key] + v = kv_cache[self.value] + + wv, qk = self.qkv_attention(q, k, v, mask) + return self.out(wv), qk + + def qkv_attention(self, q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + mask: Optional[paddle.Tensor] = None): + n_batch, n_ctx, n_state = q.shape + scale = (n_state // self.n_head) ** -0.25 + q = q.reshape([*q.shape[:2], self.n_head, -1]).transpose([0, 2, 1, 3]) * scale + k = k.reshape([*k.shape[:2], self.n_head, -1]).transpose([0, 2, 3, 1]) * scale + v = v.reshape([*v.shape[:2], self.n_head, -1]).transpose([0, 2, 1, 3]) + + qk = q @ k + if mask is not None: + qk = qk + mask[:n_ctx, :n_ctx] + qk = qk.astype("float32") + + w = paddle.nn.functional.softmax(qk, axis=-1).astype(q.dtype) + return (w @ v).transpose([0, 2, 1, 3]).flatten(start_axis=2), qk.detach() + + +def MultiHeadAttention_torch2paddle(model_torch, model_paddle): + + model_paddle.query.weight.set_value( model_torch.query.weight.data.T.cpu().numpy() ) + model_paddle.query.bias.set_value( model_torch.query.bias.data.cpu().numpy() ) + + model_paddle.key.weight.set_value( model_torch.key.weight.data.T.cpu().numpy() ) + # model_paddle.key.bias.set_value( model_torch.query.weight.data ) + + model_paddle.value.weight.set_value( model_torch.value.weight.data.T.cpu().numpy() ) + model_paddle.value.bias.set_value( model_torch.value.bias.data.cpu().numpy() ) + model_paddle.out.weight.set_value( model_torch.out.weight.data.T.cpu().numpy() ) + model_paddle.out.bias.set_value( model_torch.out.bias.data.cpu().numpy() ) + + # attn_ln + + + +# if __name__ == "__main__": + +# model_tc = MultiHeadAttention_torch(1280, 20).cuda() +# model_pd = MultiHeadAttention(1280, 20) + +# x = np.random.randn(1, 1500, 1280).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# MultiHeadAttention_torch2paddle(model_tc, model_pd) + +# y_tc1, y_tc2 = model_tc(x_tc) +# y_pd1, y_pd2 = model_pd(x_pd) + +# y_tc1, y_tc2 = y_tc1.detach().cpu().numpy(), y_tc2.detach().cpu().numpy() +# y_pd1, y_pd2 = y_pd1.detach().cpu().numpy(), y_pd2.detach().cpu().numpy() + +# print( +# abs(y_tc1 - y_pd1).max(), + +# abs(y_tc2 - y_pd2).max(), +# ) + + + + + + + + + +class ResidualAttentionBlock_torch(torch.nn.Module): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention_torch(n_state, n_head) + self.attn_ln = torch.nn.LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention_torch(n_state, n_head) if cross_attention else None + self.cross_attn_ln = torch.nn.LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = torch.nn.Sequential( + torch.nn.Linear(n_state, n_mlp), + torch.nn.GELU(), + torch.nn.Linear(n_mlp, n_state) + ) + self.mlp_ln = torch.nn.LayerNorm(n_state) + + def forward( + self, + x: torch.Tensor, + xa: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + + +class ResidualAttentionBlock(paddle.nn.Layer): + def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): + super().__init__() + + self.attn = MultiHeadAttention(n_state, n_head) + self.attn_ln = paddle.nn.LayerNorm(n_state) + + self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None + self.cross_attn_ln = 
paddle.nn.LayerNorm(n_state) if cross_attention else None + + n_mlp = n_state * 4 + self.mlp = paddle.nn.Sequential( + paddle.nn.Linear(n_state, n_mlp), + paddle.nn.GELU(), + paddle.nn.Linear(n_mlp, n_state) + ) + self.mlp_ln = paddle.nn.LayerNorm(n_state) + + def forward( + self, + x: paddle.Tensor, + xa: Optional[paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + kv_cache: Optional[dict] = None, + ): + x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] + if self.cross_attn: + x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] + x = x + self.mlp(self.mlp_ln(x)) + return x + + + +if __name__ == "__main__": + + x = np.random.rand(1, 1499, 1280).astype("float32") + x_tc = torch.from_numpy(x).cuda() + x_pd = paddle.to_tensor(x) + + model_tc = ResidualAttentionBlock_torch(1280, 20, False).cuda() + # init_weights(model_tc) + + model_pd = ResidualAttentionBlock(1280, 20, False) + + model_tc_state_dict = model_tc.state_dict() + model_pd_state_dict = model_pd.state_dict() + + print( + set( model_tc.state_dict().keys() ) == set( model_pd.state_dict().keys() ) + ) + + for torch_key, torch_value in model_pd.state_dict().items(): + if list(torch_value.shape) == model_pd_state_dict[torch_key].shape: + model_pd_state_dict[torch_key] = paddle.to_tensor( + torch_value.detach().cpu().numpy() + ) + else: + print(torch_key) + + model_pd.set_state_dict( model_pd_state_dict ) + + MultiHeadAttention_torch2paddle(model_tc.attn, model_pd.attn) + # MultiHeadAttention_torch2paddle(model_tc.cross_attn, model_pd.cross_attn) + + model_pd.mlp[0].weight.set_value( + paddle.to_tensor( + model_tc.mlp[0].weight.data.cpu().numpy().T + ) + ) + model_pd.mlp[0].bias.set_value( + paddle.to_tensor( + model_tc.mlp[0].bias.data.cpu().numpy() + ) + ) + + model_pd.mlp[2].weight.set_value( + paddle.to_tensor( + model_tc.mlp[2].weight.data.cpu().numpy().T + ) + ) + model_pd.mlp[2].bias.set_value( + paddle.to_tensor( + model_tc.mlp[2].bias.data.cpu().numpy() + ) + ) + + # ----------- 一些 LayerNorm ----------- + model_pd.mlp_ln.weight.set_value( + paddle.to_tensor( + model_tc.mlp_ln.weight.data.cpu().numpy() + ) + ) + model_pd.mlp_ln.bias.set_value( + paddle.to_tensor( + model_tc.mlp_ln.bias.data.cpu().numpy() + ) + ) + + model_pd.attn_ln.weight.set_value( + paddle.to_tensor( + model_tc.attn_ln.weight.data.cpu().numpy() + ) + ) + model_pd.attn_ln.bias.set_value( + paddle.to_tensor( + model_tc.attn_ln.bias.data.cpu().numpy() + ) + ) + + y_tc = model_tc(x_tc).detach().cpu().numpy() + y_pd = model_pd(x_pd).detach().cpu().numpy() + + print( + abs(y_tc - y_pd).max() + ) + + + +def ResidualAttentionBlock_torch2paddle(model_tc, model_pd): + + model_tc_state_dict = model_tc.state_dict() + model_pd_state_dict = model_pd.state_dict() + + # print( + # set( model_tc.state_dict().keys() ) == set( model_pd.state_dict().keys() ) + # ) + + for torch_key, torch_value in model_pd.state_dict().items(): + if list(torch_value.shape) == model_pd_state_dict[torch_key].shape: + model_pd_state_dict[torch_key] = paddle.to_tensor( + torch_value.detach().cpu().numpy() + ) + else: + print(torch_key) + + model_pd.set_state_dict( model_pd_state_dict ) + + MultiHeadAttention_torch2paddle(model_tc.attn, model_pd.attn) + # MultiHeadAttention_torch2paddle(model_tc.cross_attn, model_pd.cross_attn) + + model_pd.mlp[0].weight.set_value( + paddle.to_tensor( + model_tc.mlp[0].weight.data.cpu().numpy().T + ) + ) + model_pd.mlp[0].bias.set_value( + paddle.to_tensor( + model_tc.mlp[0].bias.data.cpu().numpy() + ) + ) + + 
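+    # The .T on the Linear weights above and below is needed because paddle.nn.Linear stores
+    # weights as (in_features, out_features) while torch.nn.Linear uses (out_features, in_features).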
model_pd.mlp[2].weight.set_value( + paddle.to_tensor( + model_tc.mlp[2].weight.data.cpu().numpy().T + ) + ) + model_pd.mlp[2].bias.set_value( + paddle.to_tensor( + model_tc.mlp[2].bias.data.cpu().numpy() + ) + ) + + LayerNorm_torch2paddle( model_tc.attn_ln, model_pd.attn_ln ) + LayerNorm_torch2paddle( model_tc.mlp_ln, model_pd.mlp_ln ) + + + +# if __name__ == "__main__": + +# x = np.random.rand(1, 1499, 1280).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# model_tc = ResidualAttentionBlock_torch(1280, 20, False).cuda() +# # init_weights(model_tc) + +# model_pd = ResidualAttentionBlock(1280, 20, False) + +# ResidualAttentionBlock_torch2paddle(model_tc, model_pd) + +# y_tc = model_tc(x_tc).detach().cpu().numpy() +# y_pd = model_pd(x_pd).detach().cpu().numpy() + +# print( +# abs(y_tc - y_pd).max() +# ) + + + + +class AudioEncoder_torch(torch.nn.Module): + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + self.conv1 = torch.nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = torch.nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) + self.register_buffer("positional_embedding", sinusoids_torch(n_ctx, n_state)) + + self.blocks: Iterable[ResidualAttentionBlock_torch] = torch.nn.ModuleList( + [ResidualAttentionBlock_torch(n_state, n_head) for _ in range(n_layer)] + ) + self.ln_post = torch.nn.LayerNorm(n_state) + + + + def forward(self, x: torch.Tensor): + """ + x : torch.torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = torch.nn.functional.gelu(self.conv1(x)) + x = torch.nn.functional.gelu(self.conv2(x)) + x = x.permute(0, 2, 1) + + len_x = x.shape[1] + len_e = self.positional_embedding.shape[0] + assert len_x <= len_e, "incorrect audio shape" + pos_e = self.positional_embedding[:len_x, :] + x = (x + pos_e).to(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + +class AudioEncoder(paddle.nn.Layer): + def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): + super().__init__() + + self.conv1 = paddle.nn.Conv1D(n_mels, n_state, kernel_size=3, padding=1) + self.conv2 = paddle.nn.Conv1D(n_state, n_state, kernel_size=3, stride=2, + padding=1) + self.register_buffer(name='positional_embedding', tensor= + sinusoids(n_ctx, n_state)) + self.blocks: Iterable[ResidualAttentionBlock] = paddle.nn.LayerList( + sublayers=[ResidualAttentionBlock(n_state, n_head) for _ in + range(n_layer)]) + self.ln_post = paddle.nn.LayerNorm(n_state) + + + def forward(self, x: torch.Tensor): + """ + x : torch.torch.Tensor, shape = (batch_size, n_mels, n_ctx) + the mel spectrogram of the audio + """ + x = paddle.nn.functional.gelu(self.conv1(x)) + x = paddle.nn.functional.gelu(self.conv2(x)) + x = x.transpose([0, 2, 1]) + + len_x = x.shape[1] + len_e = self.positional_embedding.shape[0] + assert len_x <= len_e, "incorrect audio shape" + pos_e = self.positional_embedding[:len_x, :] + x = (x + pos_e).astype(x.dtype) + + for block in self.blocks: + x = block(x) + + x = self.ln_post(x) + return x + + + +def AudioEncoder_torch2paddle(model_torch, model_paddle): + + model_paddle.conv1.weight.set_value( + paddle.to_tensor( + model_torch.conv1.weight.data.detach().cpu().numpy() + ) + ) + model_paddle.conv1.bias.set_value( + paddle.to_tensor( + model_torch.conv1.bias.data.detach().cpu().numpy() + ) + ) + + model_paddle.conv2.weight.set_value( + paddle.to_tensor( + 
model_torch.conv2.weight.data.detach().cpu().numpy() + ) + ) + model_paddle.conv2.bias.set_value( + paddle.to_tensor( + model_torch.conv2.bias.data.detach().cpu().numpy() + ) + ) + + model_paddle.ln_post.weight.set_value( + paddle.to_tensor( + model_torch.ln_post.weight.data.detach().cpu().numpy() + ) + ) + model_paddle.ln_post.bias.set_value( + paddle.to_tensor( + model_torch.ln_post.bias.data.detach().cpu().numpy() + ) + ) + + for i in range(len(model_paddle.blocks)): + ResidualAttentionBlock_torch2paddle( + model_torch.blocks[i], + model_paddle.blocks[i] + ) + + +# if __name__ == "__main__": + +# model_tc = AudioEncoder_torch(80, 1500, 1280, 20, 4).cuda() +# model_pd = AudioEncoder(80, 1500, 1280, 20, 4) + +# x = np.random.rand(1, 80, 3000).astype("float32") +# x_tc = torch.from_numpy(x).cuda() +# x_pd = paddle.to_tensor(x) + +# AudioEncoder_torch2paddle(model_tc, model_pd) + +# y_tc = model_tc( x_tc ).detach().cpu().numpy() +# y_pd = model_pd( x_pd ).detach().cpu().numpy() + +# print( +# abs(y_tc - y_pd).max() +# ) + + + + + + + + + + +# class TextDecoder(torch.nn.Module): +# def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int): +# super().__init__() + +# self.token_embedding = nn.Embedding(n_vocab, n_state) +# self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) + +# self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( +# [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)] +# ) +# self.ln = LayerNorm(n_state) + +# mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) +# self.register_buffer("mask", mask, persistent=False) + +# def forward(self, x: torch.Tensor, xa: torch.Tensor, kv_cache: Optional[dict] = None): +# """ +# x : torch.Longtorch.Tensor, shape = (batch_size, <= n_ctx) +# the text tokens +# xa : torch.torch.Tensor, shape = (batch_size, n_mels, n_audio_ctx) +# the encoded audio features to be attended on +# """ +# offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 +# x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]] +# x = x.to(xa.dtype) + +# for block in self.blocks: +# x = block(x, xa, mask=self.mask, kv_cache=kv_cache) + +# x = self.ln(x) +# logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float() + +# return logits + + +class Whisper(paddle.nn.Layer): + def __init__(self, dims: ModelDimensions): + super().__init__() + self.dims = dims + self.encoder = AudioEncoder( + self.dims.n_mels, + self.dims.n_audio_ctx, + self.dims.n_audio_state, + self.dims.n_audio_head, + self.dims.n_audio_layer, + ) + # self.decoder = TextDecoder( + # self.dims.n_vocab, + # self.dims.n_text_ctx, + # self.dims.n_text_state, + # self.dims.n_text_head, + # self.dims.n_text_layer, + # ) + + # def embed_audio(self, mel: torch.torch.Tensor): + # return self.encoder(mel) + + # def logits(self, tokens: torch.torch.Tensor, audio_features: torch.torch.Tensor): + # return self.decoder(tokens, audio_features) + + # def forward(self, mel: torch.torch.Tensor, tokens: torch.torch.Tensor) -> Dict[str, torch.torch.Tensor]: + # return self.decoder(tokens, self.encoder(mel)) + + # @property + # def device(self): + # return next(self.parameters()).device + + # @property + # def is_multilingual(self): + # return self.dims.n_vocab == 51865 + + # def install_kv_cache_hooks(self, cache: Optional[dict] = None): + # """ + # The `MultiHeadAttention_torch` module optionally accepts `kv_cache` which stores the key and value + # torch.Tensors 
+
+class Whisper(paddle.nn.Layer):
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        # self.decoder = TextDecoder(
+        #     self.dims.n_vocab,
+        #     self.dims.n_text_ctx,
+        #     self.dims.n_text_state,
+        #     self.dims.n_text_head,
+        #     self.dims.n_text_layer,
+        # )
+
+    # def embed_audio(self, mel: torch.Tensor):
+    #     return self.encoder(mel)
+
+    # def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+    #     return self.decoder(tokens, audio_features)
+
+    # def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+    #     return self.decoder(tokens, self.encoder(mel))
+
+    # @property
+    # def device(self):
+    #     return next(self.parameters()).device
+
+    # @property
+    # def is_multilingual(self):
+    #     return self.dims.n_vocab == 51865
+
+    # def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+    #     """
+    #     The `MultiHeadAttention_torch` module optionally accepts `kv_cache` which stores the key and value
+    #     tensors calculated for the previous positions. This method returns a dictionary that stores
+    #     all caches, and the necessary hooks for the key and value projection modules that save the
+    #     intermediate tensors to be reused during later calculations.
+    #
+    #     Returns
+    #     -------
+    #     cache : Dict[nn.Module, torch.Tensor]
+    #         A dictionary object mapping the key/value projection modules to its cache
+    #     hooks : List[RemovableHandle]
+    #         List of PyTorch RemovableHandle objects that can be used to remove the hooks
+    #     """
+    #     cache = {**cache} if cache is not None else {}
+    #     hooks = []
+
+    #     def save_to_cache(module, _, output):
+    #         if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
+    #             cache[module] = output  # save as-is, for the first token or cross attention
+    #         else:
+    #             cache[module] = torch.cat([cache[module], output], dim=1).detach()
+    #         return cache[module]
+
+    #     def install_hooks(layer: nn.Module):
+    #         if isinstance(layer, MultiHeadAttention_torch):
+    #             hooks.append(layer.key.register_forward_hook(save_to_cache))
+    #             hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+    #     self.decoder.apply(install_hooks)
+    #     return cache, hooks
+
+    # detect_language = detect_language_function
+    # decode = decode_function
+
+
+class Whisper_torch(torch.nn.Module):
+    def __init__(self, dims: ModelDimensions):
+        super().__init__()
+        self.dims = dims
+        self.encoder = AudioEncoder_torch(
+            self.dims.n_mels,
+            self.dims.n_audio_ctx,
+            self.dims.n_audio_state,
+            self.dims.n_audio_head,
+            self.dims.n_audio_layer,
+        )
+        # self.decoder = TextDecoder(
+        #     self.dims.n_vocab,
+        #     self.dims.n_text_ctx,
+        #     self.dims.n_text_state,
+        #     self.dims.n_text_head,
+        #     self.dims.n_text_layer,
+        # )
+
+    # def embed_audio(self, mel: torch.Tensor):
+    #     return self.encoder(mel)
+
+    # def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
+    #     return self.decoder(tokens, audio_features)
+
+    # def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
+    #     return self.decoder(tokens, self.encoder(mel))
+
+    # @property
+    # def device(self):
+    #     return next(self.parameters()).device
+
+    # @property
+    # def is_multilingual(self):
+    #     return self.dims.n_vocab == 51865
+
+    # def install_kv_cache_hooks(self, cache: Optional[dict] = None):
+    #     """
+    #     The `MultiHeadAttention_torch` module optionally accepts `kv_cache` which stores the key and value
+    #     tensors calculated for the previous positions. This method returns a dictionary that stores
+    #     all caches, and the necessary hooks for the key and value projection modules that save the
+    #     intermediate tensors to be reused during later calculations.
+    #
+    #     Returns
+    #     -------
+    #     cache : Dict[nn.Module, torch.Tensor]
+    #         A dictionary object mapping the key/value projection modules to its cache
+    #     hooks : List[RemovableHandle]
+    #         List of PyTorch RemovableHandle objects that can be used to remove the hooks
+    #     """
+    #     cache = {**cache} if cache is not None else {}
+    #     hooks = []
+
+    #     def save_to_cache(module, _, output):
+    #         if module not in cache or output.shape[1] > self.decoder.positional_embedding.shape[0]:
+    #             cache[module] = output  # save as-is, for the first token or cross attention
+    #         else:
+    #             cache[module] = torch.cat([cache[module], output], dim=1).detach()
+    #         return cache[module]
+
+    #     def install_hooks(layer: nn.Module):
+    #         if isinstance(layer, MultiHeadAttention_torch):
+    #             hooks.append(layer.key.register_forward_hook(save_to_cache))
+    #             hooks.append(layer.value.register_forward_hook(save_to_cache))
+
+    #     self.decoder.apply(install_hooks)
+    #     return cache, hooks
+
+    # detect_language = detect_language_function
+    # decode = decode_function
+
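+
+# NOTE (illustrative sketch, not part of the original port): since only the encoder is
+# instantiated on both sides, a whole-model conversion helper reduces to converting the
+# encoder. The name `Whisper_torch2paddle` is hypothetical and mirrors the helpers above.
+#
+# def Whisper_torch2paddle(model_torch, model_paddle):
+#     AudioEncoder_torch2paddle(model_torch.encoder, model_paddle.encoder)
+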
+
+checkpoint_dims = {
+    'n_mels': 80,
+    'n_vocab': 51865,
+    'n_audio_ctx': 1500,
+    'n_audio_state': 1280,
+    'n_audio_head': 20,
+    'n_audio_layer': 32,
+    'n_text_ctx': 448,
+    'n_text_state': 1280,
+    'n_text_head': 20,
+    'n_text_layer': 32,
+}
+
+
+if __name__ == "__main__":
+
+    dims = ModelDimensions(**checkpoint_dims)
+
+    model_tc = Whisper_torch(dims).cuda()
+    model_pd = Whisper(dims)
+
+    x = np.random.rand(1, 80, 1000).astype("float32")
+    x_tc = torch.from_numpy(x).cuda()
+    x_pd = paddle.to_tensor(x)
+
+    AudioEncoder_torch2paddle(model_tc.encoder, model_pd.encoder)
+
+    y_tc = model_tc.encoder( x_tc ).detach().cpu().numpy()
+    y_pd = model_pd.encoder( x_pd ).detach().cpu().numpy()
+
+    # maximum absolute element-wise difference between the torch and paddle encoder outputs
+    print(
+        abs(y_tc - y_pd).max()
+    )
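+
+    # NOTE (optional, illustrative additions, not in the original script): a stricter
+    # numerical check and a way to persist the converted encoder could look like the two
+    # lines below; the tolerance and the output filename are only examples.
+    # np.testing.assert_allclose(y_tc, y_pd, atol=1e-4, rtol=1e-4)
+    # paddle.save(model_pd.state_dict(), "whisper_encoder.pdparams")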