Fine-Tuning CLIP Models - A Beginner's Guide
Introduction
I recently wrote an article on how to fine-tune your own CLIP models after searching the web for resources and coming up empty. In this discussion, I've summarised that article so that beginners and anyone else looking to fine-tune CLIP models can do so with ease!
For the full code and a guided walk-through, visit this article.
1. Load a Dataset
To perform fine-tuning, we will use a small image classification dataset. We'll use the ceyda/fashion-products-small dataset, which is a collection of fashion products.
from datasets import load_dataset

# Load the dataset and take its training split
ds = load_dataset('ceyda/fashion-products-small')
dataset = ds['train']

# Each entry contains an image and its metadata
entry = dataset[0]
image = entry['image']
2. Load CLIP Model and Preprocessing
import clip
import torch

# OpenAI CLIP model and preprocessing
model, preprocess = clip.load("ViT-B/32", jit=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
Let's take a look at how well our base CLIP model performs image classification on this dataset.
import matplotlib.pyplot as plt

# Select indices for three example images
indices = [0, 2, 10]

# Get the list of possible subcategories from the dataset
subcategories = list(set(example['subCategory'] for example in dataset))

# Tokenize a text description for each subcategory
text_inputs = torch.cat([clip.tokenize(f"a photo of {c}") for c in subcategories]).to(device)

# Create a figure with subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Loop through the indices and process each image
for i, idx in enumerate(indices):
    # Select an example image from the dataset
    example = dataset[idx]
    image = example['image']
    subcategory = example['subCategory']

    # Preprocess the image
    image_input = preprocess(image).unsqueeze(0).to(device)

    # Calculate image and text features
    with torch.no_grad():
        image_features = model.encode_image(image_input)
        text_features = model.encode_text(text_inputs)

    # Normalize the features
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Calculate similarity between image and text features
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, top_indices = similarity[0].topk(1)

    # Display the image in the subplot
    axes[i].imshow(image)
    axes[i].set_title(f"Predicted: {subcategories[top_indices[0]]}, Actual: {subcategory}")
    axes[i].axis('off')

# Show the plot
plt.tight_layout()
plt.show()
3. Processing the Dataset
First, we must split our dataset into training and validation sets. This step is crucial: it lets us evaluate the model on data it never saw during training, so we can check that it generalizes to new, real-world examples rather than just fitting the training data.
We take 80% of the original dataset to train our model and the remaining 20% as the validation data.
from torch.utils.data import random_split

# Split dataset into training and validation sets
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
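Next, we wrap each split in a custom dataset class that returns (image tensor, label index) pairs for training. The exact class from the article isn't reproduced in this summary, so here is a minimal sketch: it assumes the CLIP preprocess function and the subcategories list defined earlier, and uses the name FashionDataset that the DataLoaders below expect.

from torch.utils.data import Dataset

class FashionDataset(Dataset):
    # Minimal reconstruction (an assumption, not the article's exact code)
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        entry = self.dataset[idx]
        image = preprocess(entry['image'])                  # CLIP's image preprocessing
        label = subcategories.index(entry['subCategory'])   # Integer class index
        return image, label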
Next, we create DataLoaders for the training and validation sets:
from torch.utils.data import DataLoader

# Create DataLoaders for training and validation sets
train_loader = DataLoader(FashionDataset(train_dataset), batch_size=32, shuffle=True)
val_loader = DataLoader(FashionDataset(val_dataset), batch_size=32, shuffle=False)
Next, we modify the model for fine-tuning:
import torch.nn as nn

# Modify the model to include a classifier for subcategories
class CLIPFineTuner(nn.Module):
    def __init__(self, model, num_classes):
        super(CLIPFineTuner, self).__init__()
        self.model = model
        self.classifier = nn.Linear(model.visual.output_dim, num_classes)

    def forward(self, x):
        with torch.no_grad():
            features = self.model.encode_image(x).float()  # Convert to float32
        return self.classifier(features)
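Finally, we instantiate the fine-tuning model. The classifier's output size is the number of distinct subcategories; the exact instantiation code isn't shown in this summary, so treat the lines below as a sketch.

# One output per subcategory label (assumed instantiation)
num_classes = len(subcategories)
model_ft = CLIPFineTuner(model, num_classes).to(device)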
4. Define Loss Function and Optimizer
import torch.optim as optim

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_ft.classifier.parameters(), lr=1e-4)
5. Fine-Tuning CLIP Model
from tqdm import tqdm

# Number of epochs for training
num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    model_ft.train()  # Set the model to training mode
    running_loss = 0.0  # Initialize running loss for the current epoch
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}, Loss: 0.0000")  # Initialize progress bar

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)  # Move images and labels to the device (GPU or CPU)
        optimizer.zero_grad()  # Clear the gradients of all optimized variables
        outputs = model_ft(images)  # Forward pass: compute predicted outputs by passing inputs to the model
        loss = criterion(outputs, labels)  # Calculate the loss
        loss.backward()  # Backward pass: compute gradient of the loss with respect to model parameters
        optimizer.step()  # Perform a single optimization step (parameter update)
        running_loss += loss.item()  # Update running loss
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")  # Update progress bar with current loss

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')  # Print average loss for the epoch

    # Validation
    model_ft.eval()  # Set the model to evaluation mode
    correct = 0  # Initialize correct predictions counter
    total = 0  # Initialize total samples counter

    with torch.no_grad():  # Disable gradient calculation for validation
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)  # Move images and labels to the device
            outputs = model_ft(images)  # Forward pass
            _, predicted = torch.max(outputs.data, 1)  # Get the class label with the highest probability
            total += labels.size(0)  # Update total samples
            correct += (predicted == labels).sum().item()  # Update correct predictions

    print(f'Validation Accuracy: {100 * correct / total}%')  # Print validation accuracy for the epoch

# Save the fine-tuned model
torch.save(model_ft.state_dict(), 'clip_finetuned.pth')  # Save the model's state dictionary
6. Performance
Amazing! Let's now take a look at how our new model performs on the same images we tested earlier.
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

# Load the saved model weights
model_ft.load_state_dict(torch.load('clip_finetuned.pth'))
model_ft.eval()  # Set the model to evaluation mode

# Define the indices for the three images
indices = [0, 2, 10]

# Preprocess the image (same resize and normalization CLIP uses)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

# Create a figure with subplots
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Loop through the indices and process each image
for i, idx in enumerate(indices):
    # Get the image and label from the dataset
    item = dataset[idx]
    image = item['image']
    true_label = item['subCategory']

    # Transform the image: add batch dimension and move to device
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        output = model_ft(image_tensor)
        _, predicted_label_idx = torch.max(output, 1)
        predicted_label = subcategories[predicted_label_idx.item()]

    # Display the image in the subplot
    axes[i].imshow(image)
    axes[i].set_title(f'True label: {true_label}\nPredicted label: {predicted_label}')
    axes[i].axis('off')

# Show the plot
plt.tight_layout()
plt.show()