diff --git a/dfencoder/autoencoder.py b/dfencoder/autoencoder.py index 191723c..9f951d8 100644 --- a/dfencoder/autoencoder.py +++ b/dfencoder/autoencoder.py @@ -164,7 +164,8 @@ def __init__(self, run=None, progress_bar=True, n_megabatches=1, - scaler='standard', # scaler for the numerical features + scaler='standard', + patience=5, preset_cats=None, loss_scaler='standard', # scaler for the losses (z score) *args, @@ -225,8 +226,9 @@ def __init__(self, self.logdir = logdir self.run = run self.project_embeddings = project_embeddings - self.scaler = scaler + self.patience = patience + # scaler class used to scale losses and collect loss stats self.loss_scaler_str = loss_scaler self.loss_scaler = self.get_scaler(loss_scaler) @@ -688,6 +690,9 @@ def fit( n_updates = len(df) // self.batch_size if len(df) % self.batch_size > 0: n_updates += 1 + last_loss = 5000 + + count_es = 0 for i in range(epochs): self.train() if self.verbose: @@ -726,6 +731,28 @@ def fit( _, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out, _id=True) id_loss.append(net_loss) + # Early stopping + current_net_loss = net_loss + if self.verbose: + print('The Current Net Loss:', current_net_loss) + + if current_net_loss > last_loss: + count_es += 1 + if self.verbose: + print('Early stop count:', count_es) + + if count_es >= self.patience: + if self.verbose: + print('Early stopping: early stop count({}) >= patience({})'.format(count_es, self.patience)) + break + + else: + if self.verbose: + print('Set count for earlystop: 0') + count_es = 0 + + last_loss = current_net_loss + self.logger.end_epoch() # if self.project_embeddings: # self.logger.show_embeddings(self.categorical_fts)