Identifying fish with contrastive learning

Download this notebook identifying_fish_with_contrastive.ipynb

We will demonstrate that an embedding CNN (ResNet18) can be trained to differentiate images of different individual fish. For this, we will use miniCARP, a simplified version of a fish image library developed by the lab as part of the idtracker.ai 2019 paper.

[1]:
import random
from pathlib import Path
from random import randint

import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import MiniBatchKMeans
from sklearn.manifold import TSNE
from sklearn.metrics import confusion_matrix
from torch import Tensor
from torch.nn.functional import pairwise_distance, relu
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision.models.resnet import BasicBlock, ResNet
from tqdm.notebook import trange

if torch.cuda.is_available():
    print("GPU is enabled.")
    DEVICE = torch.device("cuda")
else:
    print("WARNING: GPU is not enabled in this environment.")
    DEVICE = torch.device("cpu")
GPU is enabled.

Let’s start by downloading the miniCARP dataset.

[2]:
CARP_PATH = Path("./miniCARP/")


def download_carp() -> None:
    if not CARP_PATH.is_dir():
        CARP_PATH.mkdir()
        gdown.download(
            id="19u3X339wNDOYgTr4AZp-L8O_7F5I82Wp",
            output=str(CARP_PATH / "miniCARP_images.npy"),
        )
        gdown.download(
            id="1NeULdkj6HrPw8inKHJ4Sqhnts_6KDcik",
            output=str(CARP_PATH / "miniCARP_labels.npy"),
        )


def load_carp(fraction: float) -> tuple[Tensor, Tensor]:
    images: np.ndarray = np.load(CARP_PATH / "miniCARP_images.npy")
    labels: np.ndarray = np.load(CARP_PATH / "miniCARP_labels.npy")
    num_images = int(np.floor(images.shape[0] * fraction))
    return torch.from_numpy(images[:num_images]), torch.from_numpy(labels[:num_images])


download_carp()
all_images, all_labels = load_carp(fraction=0.1)
all_images = all_images.unsqueeze(1) / 255

Let’s get some information about the dataset.

[3]:
print(f"Images shape: {all_images.shape}, dtype: {all_images.dtype}")
print(f"Labels shape: {all_labels.shape}, dtype: {all_labels.dtype}")
print(f"Label options: {torch.unique(all_labels).tolist()}")
NUM_LABELS = len(torch.unique(all_labels))
Images shape: torch.Size([7500, 1, 52, 52]), dtype: torch.float32
Labels shape: torch.Size([7500]), dtype: torch.uint8
Label options: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

We loaded 7500 images of size 52x52 pixels corresponding to 15 different individuals, numbered 0-14. Here we display some examples together with their labels. Feel free to rerun the cell or change the index range to show other images and labels.

[4]:
def show_images(images, labels=None):
    fig, axes = plt.subplots(3, len(images) // 3)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image[0].numpy(), cmap="binary_r", vmax=1)
        ax.set(xticks=(), yticks=())
    if labels is not None:
        for label, ax in zip(labels, axes.flatten()):
            ax.set(title=label.item())
    plt.show()


# Random starting index; change it (or rerun the cell) to show other images
index = randint(0, len(all_images) - 15)
show_images(all_images[index : index + 15], labels=all_labels[index : index + 15])
[Output figure: a grid of 15 example images with their ground-truth labels as titles]

We divide the dataset into disjoint train/test datasets.

[5]:
test_size = 1000
labels = {"train": all_labels[:-test_size], "test": all_labels[-test_size:]}
images = {"train": all_images[:-test_size], "test": all_images[-test_size:]}

print(
    "Train and test are fractions {:.2%} and {:.2%}".format(
        len(images["train"]) / len(all_images), len(images["test"]) / len(all_images)
    )
)
Train and test are fractions 86.67% and 13.33%

We use a custom Dataset to generate one positive and one negative pair of images with each __getitem__ call. Positive pairs consist of images sampled from the same class, while negative pairs are drawn from different classes.

[6]:
class PairsDataset(Dataset):
    def __init__(self, images: Tensor, labels: Tensor):
        self.images = images
        self.labels = labels
        self.max_label = int(labels.max())
        self.images_per_class = {
            label: images[labels == label] for label in range(self.max_label + 1)
        }

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        positive_label = random.randint(0, self.max_label)
        negative_label_a, negative_label_b = random.sample(range(self.max_label + 1), 2)
        images = random.choices(self.images_per_class[positive_label], k=2) + [
            random.choice(self.images_per_class[negative_label_a]),
            random.choice(self.images_per_class[negative_label_b]),
        ]
        return images
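
As a quick sanity check (an illustrative snippet, not one of the original notebook cells), a single __getitem__ call should return four images: a positive pair drawn from one class and two images drawn from two other, distinct classes. A minimal sketch:

# Illustrative only: inspect one sample from the pairs dataset
pairs_dataset = PairsDataset(images["train"], labels["train"])
pos_a, pos_b, neg_a, neg_b = pairs_dataset[0]  # the index is ignored; sampling is random
print(pos_a.shape, pos_b.shape, neg_a.shape, neg_b.shape)  # four (1, 52, 52) tensors
# Wrapped in a DataLoader, the default collate function stacks each position
# across the batch, yielding four (batch_size, 1, 52, 52) tensors.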

Let’s set up the DataLoaders:

  • train_dataloader: This will supply positive and negative pairs of images for training the network.

  • val_dataloader: This will provide the same images as the training set, but in a plain, flat structure (no pairing). It will be used to compute the Silhouette score during training.

  • test_dataloader: This will provide test images in a flat structure, similar to the validation DataLoader, for evaluation.

[7]:
train_dataloader = DataLoader(
    PairsDataset(images["train"], labels["train"]),
    batch_size=400,
    num_workers=3,
    persistent_workers=True,
)
val_dataloader = DataLoader(
    TensorDataset(images["train"], labels["train"]),
    batch_size=400,
    num_workers=3,
    persistent_workers=True,
)
test_dataloader = DataLoader(
    TensorDataset(images["test"], labels["test"]),
    batch_size=400,
    num_workers=3,
    persistent_workers=True,
)

We use ResNet18 as our embedding network, incorporating two modifications to the standard PyTorch implementation:

  1. The first layer is adjusted to accept single-channel tensors for grayscale images.

  2. The final layer is configured to have 8 neurons by setting num_classes=8, creating an 8-dimensional representation space.

[8]:
# from idtrackerai.base.network import ResNet18 # the network is originally defined in idtrackerai


class ResNet18(ResNet):
    def __init__(self, n_channels_in: int = 1, n_dimensions_out: int = 8) -> None:
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=n_dimensions_out)
        if n_channels_in != 3:
            # adapt first conv layer to our single channel images (not RGB)
            self.conv1 = torch.nn.Conv2d(
                n_channels_in, 64, kernel_size=7, stride=2, padding=3, bias=False
            )


resnet = ResNet18().to(DEVICE)
optimizer = torch.optim.Adam(resnet.parameters(), lr=0.001)

print(resnet)
print(sum(p.numel() for p in resnet.parameters()) / 1.0e6, "M parameters")
ResNet18(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=8, bias=True)
)
11.174344 M parameters
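
As a quick shape check (an illustrative snippet, not part of the original notebook), a batch of single-channel 52x52 images should be mapped to 8-dimensional embeddings:

# Illustrative only: dummy batch of four grayscale 52x52 images
resnet.eval()
with torch.inference_mode():
    dummy_batch = torch.zeros(4, 1, 52, 52, device=DEVICE)
    print(resnet(dummy_batch).shape)  # expected: torch.Size([4, 8])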

The Silhouette score measures how close a datapoint is to its own cluster (cohesion) compared to other clusters (separation). The better defined the clusters in our representation space, the higher the score.

[9]:
# from idtrackerai.base.tracker.contrastive import silhouette_scores # the silhouette_scores function is originally defined in idtrackerai


def silhouette_scores(X: Tensor, labels: Tensor) -> Tensor:
    """Silhouette score implemented in PyTorch for GPU acceleration

    .. seealso::

        `Wikipedia's entry on Silhouette score <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_

    Parameters
    ----------
    X : Tensor
        Data points of shape (n_samples, n_features)
    labels : Tensor
        Predicted label for each sample, shape (n_samples)

    Returns
    -------
    Tensor
        Silhouette score for each sample, shape (n_samples). Same device and dtype as X
    """
    unique_labels = torch.unique(labels)

    intra_dist = torch.zeros(labels.size(), dtype=X.dtype, device=X.device)
    for label in unique_labels:
        where = labels == label
        X_where = X[where]
        distances = torch.cdist(X_where, X_where)
        intra_dist[where] = distances.sum(dim=1) / (len(distances) - 1)

    inter_dist = torch.full(labels.size(), torch.inf, dtype=X.dtype, device=X.device)

    for label_a, label_b in torch.combinations(unique_labels, 2):
        where_a = labels == label_a
        where_b = labels == label_b

        dist = torch.cdist(X[where_a], X[where_b])
        dist_a = dist.mean(dim=1)
        dist_b = dist.mean(dim=0)

        inter_dist[where_a] = torch.minimum(dist_a, inter_dist[where_a])
        inter_dist[where_b] = torch.minimum(dist_b, inter_dist[where_b])

    sil_samples = (inter_dist - intra_dist) / torch.maximum(intra_dist, inter_dist)

    return sil_samples.nan_to_num()
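
As a small illustration (not an original notebook cell), two tight, well-separated toy clusters should yield silhouette values close to 1:

# Illustrative only: two tight, well-separated 2D clusters
toy_points = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
toy_labels = torch.tensor([0, 0, 1, 1])
print(silhouette_scores(toy_points, toy_labels))  # all values close to 1

Next, we define the contrastive loss used to train the embedding network.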
[10]:
def contrastive_loss_function(
    pos_embs_A: Tensor, pos_embs_B: Tensor, neg_embs_A: Tensor, neg_embs_B: Tensor
) -> Tensor:
    positive_distance = pairwise_distance(pos_embs_A, pos_embs_B)
    positive_loss = relu(positive_distance - 1).square().sum()

    negative_distance = pairwise_distance(neg_embs_A, neg_embs_B)
    negative_loss = relu(10 - negative_distance).square().sum()

    return positive_loss + negative_loss
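
This loss penalizes positive pairs whose embeddings are more than 1 unit apart and negative pairs whose embeddings are less than 10 units apart (the margins 1 and 10 are the choices made in the cell above). A quick hand check on dummy embeddings, for illustration only:

# Illustrative only: a positive pair 0.5 apart (within the margin of 1) adds no loss;
# a negative pair 8 apart violates the margin of 10 and contributes (10 - 8)^2 = 4.
pos_a, pos_b = torch.zeros(1, 8), torch.zeros(1, 8)
pos_b[0, 0] = 0.5
neg_a, neg_b = torch.zeros(1, 8), torch.zeros(1, 8)
neg_b[0, 0] = 8.0
print(contrastive_loss_function(pos_a, pos_b, neg_a, neg_b))  # approximately tensor(4.)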

Let's train!

[11]:
train_loss = []
silhouettes = []
best_silhouette = 0

for _ in trange(40, desc="Training"):
    running_loss = 0.0

    resnet.train()
    # Train on a batch of images
    for pos_imgs_A, pos_imgs_B, neg_imgs_A, neg_imgs_B in train_dataloader:
        optimizer.zero_grad(set_to_none=True)
        pos_imgs_A = pos_imgs_A.to(DEVICE, non_blocking=True)
        pos_imgs_B = pos_imgs_B.to(DEVICE, non_blocking=True)
        neg_imgs_A = neg_imgs_A.to(DEVICE, non_blocking=True)
        neg_imgs_B = neg_imgs_B.to(DEVICE, non_blocking=True)

        pos_embs_A = resnet(pos_imgs_A)
        pos_embs_B = resnet(pos_imgs_B)
        neg_embs_A = resnet(neg_imgs_A)
        neg_embs_B = resnet(neg_imgs_B)

        total_loss = contrastive_loss_function(
            pos_embs_A, pos_embs_B, neg_embs_A, neg_embs_B
        )
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    train_loss.append(running_loss / len(train_dataloader))

    resnet.eval()
    with torch.inference_mode():
        embeddings = torch.concatenate(
            [resnet(images.to(DEVICE)) for (images, _labels) in val_dataloader]
        ).detach()
        kmeans = MiniBatchKMeans(n_clusters=NUM_LABELS)
        labels = torch.from_numpy(kmeans.fit_predict(embeddings.numpy(force=True)))
        silhouette = silhouette_scores(embeddings, labels.to(DEVICE)).mean().item()
    if silhouette > best_silhouette:
        best_silhouette = silhouette
        torch.save(resnet.state_dict(), "checkpoint.pt")
        print(f"Best silhouette: {best_silhouette:.4f}")
    silhouettes.append(silhouette)
resnet.load_state_dict(torch.load("checkpoint.pt", weights_only=True))
Path("checkpoint.pt").unlink()
print(f"Loaded best model with silhouette: {best_silhouette:.4f}")
Best silhouette: 0.2467
Best silhouette: 0.2833
Best silhouette: 0.3005
Best silhouette: 0.3695
Best silhouette: 0.4218
Best silhouette: 0.4476
Best silhouette: 0.5794
Best silhouette: 0.6316
Best silhouette: 0.6359
Best silhouette: 0.7191
Best silhouette: 0.7539
Best silhouette: 0.7689
Best silhouette: 0.8047
Best silhouette: 0.8284
Best silhouette: 0.8490
Best silhouette: 0.9127
Best silhouette: 0.9210
Best silhouette: 0.9265
Loaded best model with silhouette: 0.9265

Observe how the loss decreases and the Silhouette score increases during training. Due to the instability in cluster shapes (as indicated by the frequent drops in the Silhouette score), we use checkpoints to reload the network with the highest Silhouette score.

[12]:
fig, (ax_loss, ax_silhouettes) = plt.subplots(1, 2, figsize=(8, 4))
ax_loss.plot(train_loss)
ax_silhouettes.plot(silhouettes)
ax_loss.set(ylabel="Loss", xlabel="Epoch")
ax_silhouettes.set(ylabel="Silhouette", xlabel="Epoch")
fig.tight_layout()
[Output figure: training loss (left) and Silhouette score (right) per epoch]

Let’s examine the clustering results for the training dataset. Since the cluster identities may not align with the ground truth, we will permute them to ensure they match.
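
The matching is done with scipy's linear_sum_assignment applied to the confusion matrix (see the evaluate_predictions function below). As a tiny standalone illustration with a made-up 3x3 confusion matrix:

# Illustrative only, with made-up numbers: cluster 0 mostly holds ground-truth
# class 1, cluster 1 holds class 2, and cluster 2 holds class 0.
toy_confusion = np.array(
    [
        [1, 0, 90],  # ground truth 0 ends up mostly in cluster 2
        [95, 2, 3],  # ground truth 1 ends up mostly in cluster 0
        [4, 88, 1],  # ground truth 2 ends up mostly in cluster 1
    ]
)
ids, matches = linear_sum_assignment(-toy_confusion.T)
print(matches)  # [1 2 0]: cluster 0 -> class 1, cluster 1 -> class 2, cluster 2 -> class 0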

[13]:
plt.style.use("dark_background")


def evaluate_predictions(
    gt_labels: np.ndarray, predicted_labels: np.ndarray, tsne: np.ndarray
):
    """Evaluate the predictions and plot the confusion matrix and t-SNE embeddings"""
    confusion = confusion_matrix(gt_labels, predicted_labels)
    ids, matches = linear_sum_assignment(-confusion.T)

    fig, (confusion_ax, ground_truth_ax, prediction_errors_ax) = plt.subplots(
        1, 3, figsize=(13, 5)
    )
    confusion_ax.set_title("Confusion matrix")
    ground_truth_ax.set_title("Ground truth")
    prediction_errors_ax.set_title("Error")

    confusion_ax.imshow(confusion, interpolation="none")
    confusion_ax.plot(ids, matches, "ro", label="Identity matching")
    confusion_ax.legend()
    confusion_ax.set(xlabel="Predicted", ylabel="Ground truth")
    confusion_ax.set(xticks=range(NUM_LABELS), yticks=range(NUM_LABELS))

    print(f"Accuracy: {confusion[matches, ids].sum() / len(gt_labels):.2%}")
    matched_labels = np.asarray([matches[label] for label in predicted_labels])

    for ax in (ground_truth_ax, prediction_errors_ax):
        ax.set_aspect("equal")
        ax.set(xticks=(), yticks=())

    ground_truth_ax.scatter(
        *tsne.T, c=gt_labels / (np.max(gt_labels) + 1), cmap="hsv", s=3, lw=0
    )
    prediction_errors_ax.scatter(
        *tsne.T,
        c=(gt_labels == matched_labels),
        cmap="RdYlGn",
        s=3,
        lw=0,
        vmin=0,
        vmax=1,
    )
    fig.tight_layout()
[14]:
with torch.inference_mode():
    embeddings = []
    labels = []
    for images, _labels in val_dataloader:
        labels.append(_labels)
        embeddings.append(resnet(images.to(DEVICE)).numpy(force=True))
    gt_labels = np.concatenate(labels)
    embeddings = np.concatenate(embeddings)
    kmeans = MiniBatchKMeans(n_clusters=NUM_LABELS)
    predicted_labels = kmeans.fit_predict(embeddings)

tsne = TSNE(n_jobs=4).fit_transform(embeddings)
evaluate_predictions(gt_labels, predicted_labels, tsne)
Accuracy: 99.85%
[Output figure: confusion matrix with identity matching, and t-SNE embeddings colored by ground truth and by prediction errors (training set)]

Now, let’s evaluate on the test dataset. These images have not been seen by the network during training. We will assign images to clusters by using the cluster centers obtained from the training step.

[15]:
with torch.inference_mode():
    embeddings = []
    labels = []
    for images, _labels in test_dataloader:
        labels.append(_labels)
        embeddings.append(resnet(images.to(DEVICE)).numpy(force=True))
    gt_labels = np.concatenate(labels)
    embeddings = np.concatenate(embeddings)
    # get predicted labels from previous cluster centers
    predicted_labels = kmeans.predict(embeddings)

tsne = TSNE(n_jobs=4).fit_transform(embeddings)
evaluate_predictions(gt_labels, predicted_labels, tsne)
Accuracy: 93.90%
[Output figure: confusion matrix with identity matching, and t-SNE embeddings colored by ground truth and by prediction errors (test set)]

It’s important to note that when tracking a video, all images are used to train the embedding network, so there is no separate test dataset. Additionally, positive and negative pairs are not generated from labeled ground truth, as such labels do not exist in this scenario. Instead, positive pairs are formed from images within the same fragment, while negative pairs are created from pairs of coexisting fragments.
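
In that setting, a pairs dataset analogous to PairsDataset above could be built from fragments instead of labels. The following is only a conceptual sketch with hypothetical names (fragments and coexisting_fragment_pairs are assumed data structures, not the actual idtrackerai API):

# Conceptual sketch, not idtrackerai's actual implementation.
# `fragments` is assumed to be a list of image tensors, one per fragment, and
# `coexisting_fragment_pairs` a list of (i, j) index pairs of fragments that
# appear in the same frames and therefore belong to different animals.
class FragmentPairsDataset(Dataset):
    def __init__(self, fragments, coexisting_fragment_pairs):
        self.fragments = fragments
        self.coexisting_pairs = coexisting_fragment_pairs

    def __len__(self):
        return len(self.fragments)

    def __getitem__(self, index):
        # positive pair: two images from the same fragment (same animal)
        fragment = random.choice(self.fragments)
        pos_a, pos_b = random.choices(fragment, k=2)
        # negative pair: one image from each of two coexisting fragments
        i, j = random.choice(self.coexisting_pairs)
        neg_a = random.choice(self.fragments[i])
        neg_b = random.choice(self.fragments[j])
        return [pos_a, pos_b, neg_a, neg_b]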
