Notebook to illustrate fish identification with contrastive learning
or download this notebook
identifying_fish_with_contrastive.ipynb
We will demonstrate that an embedding CNN (ResNet18) can be trained to tell apart images of different individual fish. For this, we will use miniCARP, a simplified version of a fish image library developed by the lab as part of the idtracker.ai 2019 paper.
[1]:
import random
from pathlib import Path
from random import randint
import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import MiniBatchKMeans
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
from torch import Tensor
from torch.nn.functional import pairwise_distance, relu
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.models.resnet import BasicBlock, ResNet
from tqdm.notebook import trange
# Choose the compute device once; models and batches are moved to DEVICE below.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print("GPU is enabled.")
else:
    print("WARNING: GPU is not enabled in this environment.")
GPU is enabled.
Let’s start by downloading the miniCARP dataset.
[2]:
CARP_PATH = Path("./miniCARP/")


def download_carp() -> None:
    """Download the miniCARP image and label arrays into ``CARP_PATH``.

    The directory is created if needed. Each ``.npy`` file is fetched only if
    it is not already on disk. Re-running the function is therefore cheap,
    and it also fills in a file that is missing from an existing directory.

    NOTE(review): a truncated file left on disk would not be re-downloaded;
    delete it manually in that case.
    """
    CARP_PATH.mkdir(parents=True, exist_ok=True)
    # Google Drive file id for each local file name.
    drive_ids = {
        "miniCARP_images.npy": "19u3X339wNDOYgTr4AZp-L8O_7F5I82Wp",
        "miniCARP_labels.npy": "1NeULdkj6HrPw8inKHJ4Sqhnts_6KDcik",
    }
    for file_name, drive_id in drive_ids.items():
        destination = CARP_PATH / file_name
        # Checking per file (not per directory) means a directory left
        # half-populated by an earlier run still gets its missing file.
        if not destination.is_file():
            gdown.download(id=drive_id, output=str(destination))
def load_carp(fraction: float) -> tuple[Tensor, Tensor]:
    """Load the first ``fraction`` of the miniCARP images and labels.

    Parameters
    ----------
    fraction : float
        Share of the dataset to keep, in [0, 1]. The first
        ``floor(n_images * fraction)`` entries are returned, in file order
        (no shuffling).

    Returns
    -------
    tuple[Tensor, Tensor]
        ``(images, labels)`` as CPU tensors with the dtypes stored in the
        ``.npy`` files.

    Raises
    ------
    ValueError
        If ``fraction`` is outside [0, 1] (or NaN). A negative value would
        otherwise silently slice from the end of the arrays.
    """
    if not 0 <= fraction <= 1:
        raise ValueError(f"fraction must be in [0, 1], got {fraction}")
    # Memory-map the files so that only the requested prefix is read into RAM.
    # np.array(...) then copies that prefix into a regular, writable array, so
    # the returned tensors do not keep the whole file alive.
    images = np.load(CARP_PATH / "miniCARP_images.npy", mmap_mode="r")
    labels = np.load(CARP_PATH / "miniCARP_labels.npy", mmap_mode="r")
    num_images = int(np.floor(images.shape[0] * fraction))
    return (
        torch.from_numpy(np.array(images[:num_images])),
        torch.from_numpy(np.array(labels[:num_images])),
    )
download_carp()
# Keep only the first 10% of miniCARP so the demo stays quick.
images, labels = load_carp(fraction=0.1)
# Insert a channel axis -> (N, 1, H, W) and divide by 255 (presumably 8-bit
# intensities, which brings the pixel values into [0, 1]).
images = images[:, None] / 255
Let’s get some information about the dataset.
[3]:
# Number of distinct identities; used later as the number of k-means clusters.
NUM_LABELS = labels.unique().numel()
print(f"Images shape: {images.shape}, dtype: {images.dtype}")
print(f"Labels shape: {labels.shape}, dtype: {labels.dtype}")
print(f"Label options: {sorted(set(labels.tolist()))}")
Images shape: torch.Size([7500, 1, 52, 52]), dtype: torch.float32
Labels shape: torch.Size([7500]), dtype: torch.uint8
Label options: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
We loaded 7500 images of size 52x52 corresponding to 15 different individuals, numbered 0-14. Here, we plot some examples together with their labels. Feel free to change the range of images and labels to be shown.
[4]:
def show_images(images, labels=None, nrows: int = 3):
    """Plot images on a grid with ``nrows`` rows, optionally titled by label.

    Parameters
    ----------
    images : Tensor
        Stack of images indexed along the first axis. Channel 0 of each image
        is drawn with the "binary_r" colormap and ``vmax=1`` (values are
        presumably scaled to [0, 1]; confirm for new callers).
    labels : Tensor, optional
        One label per image; ``label.item()`` becomes the subplot title.
    nrows : int, default 3
        Number of grid rows. Enough columns are created to show every image,
        and any leftover grid cells are hidden.

    Raises
    ------
    ValueError
        If ``nrows`` is smaller than 1.
    """
    if nrows < 1:
        raise ValueError(f"nrows must be positive, got {nrows}")
    n_images = len(images)
    if n_images == 0:
        return  # nothing to draw
    # Ceiling division gives every image a cell. Floor division would silently
    # drop the remainder, and would ask for 0 columns when n_images < nrows.
    ncols = -(-n_images // nrows)
    # squeeze=False always yields a 2-D array of Axes, so flatten() works for
    # any grid shape.
    _, axes = plt.subplots(nrows, ncols, squeeze=False)
    flat_axes = axes.flatten()
    for image, ax in zip(images, flat_axes):
        ax.imshow(image[0].numpy(), cmap="binary_r", vmax=1)
        ax.set(xticks=(), yticks=())
    if labels is not None:
        for label, ax in zip(labels, flat_axes):
            ax.set(title=label.item())
    # Hide cells beyond the last image.
    for ax in flat_axes[n_images:]:
        ax.set_axis_off()
    plt.show()
# Change range here to print other images
index = randint(0, len(images) - 15)  # random start of a 15-image window
show_images(images.narrow(0, index, 15), labels=labels.narrow(0, index, 15))
We use a custom Dataset to generate one positive and one negative pair of images with each __getitem__ call. Positive pairs consist of images sampled from the same class, while negative pairs are drawn from different classes.
[5]:
class PairsDataset(Dataset):
    """Dataset that yields random positive and negative image pairs.

    Every ``__getitem__`` call ignores ``index`` and returns a list of four
    images ``[pos_A, pos_B, neg_A, neg_B]``:

    - ``pos_A`` and ``pos_B`` share a label. They are drawn with replacement,
      so they can occasionally be the same image.
    - ``neg_A`` and ``neg_B`` come from two different labels.

    ``__len__`` only sets how many samples make up one epoch.

    Raises
    ------
    ValueError
        If ``labels`` contains fewer than two distinct values, since negative
        pairs cannot be formed in that case.
    """

    def __init__(self, images: Tensor, labels: Tensor):
        self.images = images
        self.labels = labels
        # Sample only from labels that actually occur. Iterating over
        # range(max_label + 1) would create empty pools for missing labels,
        # and random.choice would raise IndexError on them.
        self.present_labels = torch.unique(labels).tolist()
        if len(self.present_labels) < 2:
            raise ValueError(
                "PairsDataset needs at least two distinct labels to build negative pairs"
            )
        self.max_label = int(labels.max())
        self.images_per_class = {
            label: images[labels == label] for label in self.present_labels
        }

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        # When every label 0..max_label is present, these draws match
        # randint(0, max_label) / sample(range(max_label + 1), 2) exactly.
        positive_label = random.choice(self.present_labels)
        negative_label_a, negative_label_b = random.sample(self.present_labels, 2)
        images = random.choices(self.images_per_class[positive_label], k=2) + [
            random.choice(self.images_per_class[negative_label_a]),
            random.choice(self.images_per_class[negative_label_b]),
        ]
        return images
Let’s set up the DataLoaders:
train_dataloader: supplies positive and negative pairs of images for training the network.
val_dataloader: provides the same images as the training set, but as plain (image, label) batches. It is used to compute the Silhouette score during training.
[6]:
# No shuffle is needed for training: PairsDataset draws its pairs at random on
# every call. pin_memory puts batches in page-locked host memory, which is what
# lets the `.to(DEVICE, non_blocking=True)` copies in the training loop run
# asynchronously. It only applies to CUDA, so it is disabled on CPU.
train_dataloader = DataLoader(
    PairsDataset(images, labels),
    batch_size=400,
    num_workers=1,
    persistent_workers=True,
    pin_memory=DEVICE.type == "cuda",
)
# Plain (image, label) batches over the same images, used for evaluation.
val_dataloader = DataLoader(
    TensorDataset(images, labels),
    batch_size=400,
    num_workers=1,
    persistent_workers=True,
    pin_memory=DEVICE.type == "cuda",
)
We use ResNet18 as our embedding network, incorporating two modifications to the standard PyTorch implementation:
The first layer is adjusted to accept single-channel tensors for grayscale images.
The final layer is configured to have 8 neurons by setting num_classes=8, creating an 8-dimensional representation space.
[7]:
# from idtrackerai.base.network import ResNet18 # the network is originally defined in idtrackerai
class ResNet18(ResNet):
    """torchvision ResNet with two BasicBlocks per stage, used as an embedding network.

    Parameters
    ----------
    n_channels_in : int, default 1
        Channels of the input images (1 for grayscale). For any value other
        than 3, the first convolution is rebuilt to accept that many channels.
    n_dimensions_out : int, default 8
        Size of the output embedding. It is passed to ResNet as
        ``num_classes``, so the final fully connected layer has this many
        outputs.
    """

    def __init__(self, n_channels_in: int = 1, n_dimensions_out: int = 8) -> None:
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=n_dimensions_out)
        if n_channels_in == 3:
            return  # RGB input: keep the first convolution as built by ResNet
        # Replacement stem: 64 filters, 7x7 kernel, stride 2, padding 3, no bias.
        # NOTE(review): this layer is created after ResNet.__init__, so it
        # presumably keeps PyTorch's default Conv2d initialisation rather than
        # whatever ResNet applies to its own layers; confirm whether that matters.
        self.conv1 = torch.nn.Conv2d(
            in_channels=n_channels_in,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
resnet = ResNet18().to(DEVICE)
print(resnet)
# Model size in millions of parameters (tensors returned by parameters()).
print(sum(map(torch.numel, resnet.parameters())) / 1.0e6, "M parameters")
optimizer = torch.optim.Adam(resnet.parameters(), lr=1e-3)
ResNet18(
(conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=8, bias=True)
)
11.174344 M parameters
The Silhouette score measures how close a data point is to its own cluster (cohesion) compared to other clusters (separation). It ranges from -1 to 1 and increases as the clusters in our representation space become more compact and better separated.
[8]:
# from idtrackerai.base.tracker.contrastive import silhouette_scores # the silhouette_scores function is originally defined in idtrackerai
def silhouette_scores(X: Tensor, labels: Tensor) -> Tensor:
    """Per-sample Silhouette score computed with PyTorch (runs on X's device).

    .. seealso::
        `Wikipedia's entry on Silhouette score <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_

    Parameters
    ----------
    X : Tensor
        Data points of shape (n_samples, n_features)
    labels : Tensor
        Predicted label for each sample, shape (n_samples)

    Returns
    -------
    Tensor
        Silhouette score ``(b - a) / max(a, b)`` for each sample, shape
        (n_samples), with the same device and dtype as X. NaNs are replaced
        by 0. They arise, for example, from singleton clusters (where ``a``
        is 0/0) or from a single cluster overall (where ``b`` stays inf).
    """
    cluster_masks = [labels == cluster for cluster in torch.unique(labels)]
    like_x = {"dtype": X.dtype, "device": X.device}

    # a(i): mean distance from each sample to the other members of its cluster.
    # The self-distance is 0, so dividing the row sum by (size - 1) excludes it.
    a = torch.zeros(labels.size(), **like_x)
    for mask in cluster_masks:
        members = X[mask]
        d = torch.cdist(members, members)
        a[mask] = d.sum(dim=1) / (d.shape[0] - 1)

    # b(i): smallest mean distance from each sample to any other cluster.
    # Each unordered pair of clusters is visited once and updates both sides.
    b = torch.full(labels.size(), torch.inf, **like_x)
    for i, mask_i in enumerate(cluster_masks):
        for mask_j in cluster_masks[i + 1 :]:
            d = torch.cdist(X[mask_i], X[mask_j])
            b[mask_i] = torch.minimum(d.mean(dim=1), b[mask_i])
            b[mask_j] = torch.minimum(d.mean(dim=0), b[mask_j])

    silhouette = (b - a) / torch.maximum(a, b)
    return silhouette.nan_to_num()
[9]:
def contrastive_loss_function(
    pos_embs_A: Tensor,
    pos_embs_B: Tensor,
    neg_embs_A: Tensor,
    neg_embs_B: Tensor,
    *,
    positive_margin: float = 1.0,
    negative_margin: float = 10.0,
) -> Tensor:
    """Squared-hinge contrastive loss over positive and negative pairs.

    Parameters
    ----------
    pos_embs_A, pos_embs_B : Tensor
        Embeddings of the positive pairs (same identity), compared row by row.
    neg_embs_A, neg_embs_B : Tensor
        Embeddings of the negative pairs (different identities), compared row by row.
    positive_margin : float, default 1.0
        Positive pairs closer than this distance incur no loss.
    negative_margin : float, default 10.0
        Negative pairs farther apart than this distance incur no loss.

    Returns
    -------
    Tensor
        Scalar ``sum(relu(d_pos - positive_margin)**2)
        + sum(relu(negative_margin - d_neg)**2)``, summed (not averaged)
        over the batch.
    """
    positive_distance = pairwise_distance(pos_embs_A, pos_embs_B)
    # Pull positives together, but only until they are within positive_margin.
    positive_loss = relu(positive_distance - positive_margin).square().sum()
    negative_distance = pairwise_distance(neg_embs_A, neg_embs_B)
    # Push negatives apart, but only until they are at least negative_margin apart.
    negative_loss = relu(negative_margin - negative_distance).square().sum()
    return positive_loss + negative_loss
Let's train!
[10]:
train_loss = []
silhouettes = []
# Silhouette scores lie in [-1, 1]. Starting from -inf (instead of 0) guarantees
# that the first epoch is stored, so there is always a best model to restore
# below, even if the score never becomes positive.
best_silhouette = float("-inf")
best_state = None
for _ in trange(40, desc="Training"):
    running_loss = 0.0
    resnet.train()
    # One optimisation step per batch of (positive A, positive B, negative A, negative B) images.
    for pos_imgs_A, pos_imgs_B, neg_imgs_A, neg_imgs_B in train_dataloader:
        optimizer.zero_grad(set_to_none=True)
        pos_embs_A = resnet(pos_imgs_A.to(DEVICE, non_blocking=True))
        pos_embs_B = resnet(pos_imgs_B.to(DEVICE, non_blocking=True))
        neg_embs_A = resnet(neg_imgs_A.to(DEVICE, non_blocking=True))
        neg_embs_B = resnet(neg_imgs_B.to(DEVICE, non_blocking=True))
        total_loss = contrastive_loss_function(
            pos_embs_A, pos_embs_B, neg_embs_A, neg_embs_B
        )
        total_loss.backward()
        optimizer.step()
        running_loss += total_loss.item()
    train_loss.append(running_loss / len(train_dataloader))

    # Score the current embedding: cluster it with k-means and measure how
    # well separated those clusters are.
    resnet.eval()
    with torch.inference_mode():
        embeddings = torch.concatenate(
            [resnet(batch_images.to(DEVICE)) for batch_images, _ in val_dataloader]
        )
        kmeans = MiniBatchKMeans(n_clusters=NUM_LABELS)
        # A dedicated name keeps the ground-truth `labels` tensor intact.
        cluster_labels = torch.from_numpy(
            kmeans.fit_predict(embeddings.numpy(force=True))
        )
        silhouette = (
            silhouette_scores(embeddings, cluster_labels.to(DEVICE)).mean().item()
        )
    if silhouette > best_silhouette:
        best_silhouette = silhouette
        # Keep the best weights in memory rather than in a checkpoint file on
        # disk. clone() is required because state_dict() tensors share storage
        # with the live parameters.
        best_state = {
            name: tensor.clone() for name, tensor in resnet.state_dict().items()
        }
        print(f"Best silhouette: {best_silhouette:.4f}")
    silhouettes.append(silhouette)
resnet.load_state_dict(best_state)
print(f"Loaded best model with silhouette: {best_silhouette:.4f}")
Best silhouette: 0.2732
Best silhouette: 0.2986
Best silhouette: 0.3018
Best silhouette: 0.3153
Best silhouette: 0.4331
Best silhouette: 0.4374
Best silhouette: 0.5050
Best silhouette: 0.5091
Best silhouette: 0.5779
Best silhouette: 0.5998
Best silhouette: 0.6979
Best silhouette: 0.7969
Best silhouette: 0.8114
Best silhouette: 0.8731
Best silhouette: 0.8766
Best silhouette: 0.8860
Best silhouette: 0.9216
Best silhouette: 0.9311
Best silhouette: 0.9400
Loaded best model with silhouette: 0.9400
Observe how the loss decreases and the Silhouette score increases during training. Because the cluster structure is unstable (as shown by the frequent drops in the Silhouette score), we keep the network weights from the epoch with the highest Silhouette score and reload them at the end of training.
[11]:
# Left panel: mean training loss per epoch. Right panel: Silhouette score per epoch.
fig, (ax_loss, ax_silhouettes) = plt.subplots(1, 2, figsize=(8, 4))
ax_loss.plot(train_loss)
ax_loss.set(ylabel="Loss", xlabel="Epoch")
ax_silhouettes.plot(silhouettes)
ax_silhouettes.set(ylabel="Silhouette", xlabel="Epoch")
fig.tight_layout()
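Let’s examine the clustering results for the training dataset. The cluster indices assigned by k-means are arbitrary and need not match the ground-truth identities. We therefore first match each cluster to one identity (using the Hungarian algorithm) before computing the accuracy.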
[12]:
plt.style.use("dark_background")


def evaluate_predictions(
    gt_labels: np.ndarray, predicted_labels: np.ndarray, tsne: np.ndarray
) -> None:
    """Evaluate cluster predictions against ground truth and plot the results.

    Each predicted cluster is matched to one ground-truth identity with the
    Hungarian algorithm, maximising the number of agreeing samples. The
    matched accuracy is printed. A three-panel figure is then drawn:

    - per-sample errors on the t-SNE embedding,
    - ground-truth identities on the t-SNE embedding,
    - the reordered confusion matrix.

    Parameters
    ----------
    gt_labels : np.ndarray
        Ground-truth identity of each sample, shape (n_samples,).
    predicted_labels : np.ndarray
        Cluster id of each sample, shape (n_samples,).
    tsne : np.ndarray
        2-D coordinates used for the scatter plots, shape (n_samples, 2).

    NOTE(review): both label arrays are used directly as matrix indices, so
    they are assumed to be consecutive integers 0..K-1. Confirm this for
    other callers.
    """
    confusion = confusion_matrix(gt_labels, predicted_labels)
    # On the transpose, rows are predicted clusters. matches[p] is the
    # ground-truth identity assigned to cluster p.
    _, matches = linear_sum_assignment(-confusion.T)
    # Row p now holds identity matches[p], so correctly matched samples lie on
    # the diagonal.
    confusion = confusion[matches]
    n_classes = confusion.shape[0]
    print(f"Accuracy: {np.trace(confusion) / np.sum(confusion):.2%}")
    # Ground-truth identity implied by each sample's predicted cluster.
    matched_labels = matches[np.asarray(predicted_labels)]

    fig, (prediction_errors_ax, ground_truth_ax, confusion_ax) = plt.subplots(
        1, 3, figsize=(13, 5)
    )
    prediction_errors_ax.set_title("Error")
    ground_truth_ax.set_title("Ground truth")
    confusion_ax.set_title("Confusion matrix")

    confusion_ax.imshow(confusion, interpolation="none")
    confusion_ax.set(xlabel="Predicted", ylabel="Ground truth")
    confusion_ax.set_xticks(range(n_classes))
    confusion_ax.set_yticks(range(n_classes))
    # Label each (reordered) row with the identity it actually holds.
    confusion_ax.set_yticklabels(matches)

    for scatter_ax in (ground_truth_ax, prediction_errors_ax):
        scatter_ax.set_aspect("equal")
        scatter_ax.set(xticks=(), yticks=())

    # "hsv" is cyclic. With the default min/max autoscaling, the first and
    # last identities would both map to red. Fixing vmax one step past the
    # largest label keeps every identity's colour distinct. int() avoids uint8
    # overflow in the "+ 1".
    ground_truth_ax.scatter(
        *tsne.T,
        c=gt_labels,
        cmap="hsv",
        vmin=0,
        vmax=int(np.max(gt_labels)) + 1,
        s=3,
        lw=0,
    )
    # Green where the matched prediction agrees with the ground truth, red otherwise.
    prediction_errors_ax.scatter(
        *tsne.T,
        c=(gt_labels == matched_labels),
        cmap="RdYlGn",
        s=3,
        lw=0,
        vmin=0,
        vmax=1,
    )
    fig.tight_layout()
[13]:
# Explicitly use inference behaviour (e.g. BatchNorm running statistics);
# this is a no-op if the model is already in eval mode.
resnet.eval()
with torch.inference_mode():
    embedding_batches = []
    label_batches = []
    # Dedicated loop names: a module-level for-loop over `images, labels` would
    # rebind the dataset-wide `images` tensor to the last batch and overwrite
    # the ground-truth `labels` tensor.
    for batch_images, batch_labels in val_dataloader:
        label_batches.append(batch_labels)
        embedding_batches.append(resnet(batch_images.to(DEVICE)).numpy(force=True))
gt_labels = torch.cat(label_batches).numpy()
embeddings = np.concatenate(embedding_batches)
kmeans = MiniBatchKMeans(n_clusters=NUM_LABELS)
predicted_labels = kmeans.fit_predict(embeddings)
tsne = TSNE(n_jobs=4).fit_transform(embeddings)
evaluate_predictions(gt_labels, predicted_labels, tsne)
Accuracy: 99.97%
This shows that, by providing positive and negative pairs of images, we can train a CNN to learn a representation space where images from the same animal are clustered together, while images from different animals are separated.
In a real tracking scenario, we would obtain the positive and negative pairs from the fragments structure of the video, where positive pairs are images from the same fragment and negative pairs are images from different coexisting fragments.
or download this notebook
identifying_fish_with_contrastive.ipynb