Skip to content
This repository has been archived by the owner on Mar 29, 2023. It is now read-only.

Add early stop to autoencoder #2

Merged
8 commits merged on Jan 27, 2023
30 changes: 29 additions & 1 deletion dfencoder/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def __init__(self,
progress_bar=True,
n_megabatches=1,
scaler='standard',
patience=5,
preset_cats=None,
*args,
**kwargs):
Expand Down Expand Up @@ -226,6 +227,8 @@ def __init__(self,
self.project_embeddings = project_embeddings

self.scaler = scaler

self.patience = patience

self.n_megabatches = n_megabatches

Expand Down Expand Up @@ -656,6 +659,9 @@ def fit(self, df, epochs=1, val=None):
n_updates = len(df) // self.batch_size
if len(df) % self.batch_size > 0:
n_updates += 1
last_loss = 5000

count_es = 0
for i in range(epochs):
self.train()
if self.verbose:
Expand Down Expand Up @@ -694,6 +700,28 @@ def fit(self, df, epochs=1, val=None):
_, _, _, net_loss = self.compute_loss(num, bin, cat, slc_out, _id=True)
id_loss.append(net_loss)

# Early stopping
current_net_loss = net_loss
if self.verbose:
print('The Current Net Loss:', current_net_loss)

if current_net_loss > last_loss:
count_es += 1
if self.verbose:
print('Early stop count:', count_es)

if count_es >= self.patience:
if self.verbose:
print('Early stopping: early stop count({}) >= patience({})'.format(count_es, self.patience))
break

else:
if self.verbose:
print('Set count for earlystop: 0')
count_es = 0

last_loss = current_net_loss

self.logger.end_epoch()
# if self.project_embeddings:
# self.logger.show_embeddings(self.categorical_fts)
Expand Down Expand Up @@ -1050,4 +1078,4 @@ def get_results(self, df, return_abs=False):
pdf['max_abs_z'] = combined_loss.max(dim=1)[0].cpu().numpy()
pdf['mean_abs_z'] = combined_loss.mean(dim=1).cpu().numpy()

return pdf
return pdf