Feature Extraction and Analysis on Fish in MegaLab Camera¶

Happy New Year!

In a previous notebook, I was playing around with GStreamer and spectrogram generation on the MegaLab camera feed. I also played with the Community Fish Detector model to evaluate fish detections on the stream. The results were really interesting, but computationally slow (running on just my laptop).

I did manage to collect around four and a half hours of footage from the MegaLab camera during the afternoon and evening of December 12th, 2025, ranging from around 13:30 to 17:50 (fortunately the camera feed also contains the timestamps). But running inference on the entire footage on my Nvidia RTX 4070 laptop GPU was going to take some time. It was effectively taking 1.5 seconds per frame, and at 30 FPS for 4.5 hours, that would take around a week of continuous processing!
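
To put numbers on that: 4.5 hours × 3,600 seconds/hour × 30 FPS ≈ 486,000 frames, and at roughly 1.5 seconds per frame that works out to about 729,000 seconds of compute, a little over 8 days.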

So I got to thinking - what would I do with all these fish detections once I had them? Well...

  • I don't really know what fish are detected
  • I don't know how many unique types of fish there are
  • I don't know what other types of creatures or objects might be detected
  • I don't have any tracks of fish that were detected multiple times

What I do have is a large set of bounding boxes and confidence scores for each frame. And at the time of writing this, I was still waiting for the inference to finish running on the entire footage!

In [1]:
%load_ext autoreload
%autoreload 2

import os
import polars as pl
import warnings

warnings.filterwarnings("ignore")

df = pl.read_delta("../../../megalab_recordings/data/inference")


# extract the numeric video id from the filename, e.g. "megalab_0001.mp4" -> 1
df = df.with_columns(
    pl.col("video_file")
      .map_elements(lambda f: int(os.path.splitext(f)[0].split('_')[-1]), return_dtype=pl.Int64)
      .alias("video_id"),
)

# sort by video_id and frame_id
df = df.sort(["video_id", "frame_id"])

# assuming each video is 10 seconds long (300 frames at 30 FPS),
# derive a timestamp (in seconds) from the video_id and frame_id columns
df = df.with_columns(
    pl.struct(["video_id", "frame_id"])
      .map_elements(lambda x: (x['video_id'] * 10) + (x['frame_id'] * (10 / 300)), return_dtype=pl.Float64)
      .alias("timestamp")
)

df
Out[1]:
shape: (2_068_649, 6)
video_file         | frame_id | box                                    | score    | video_id | timestamp
str                | i64      | list[f64]                              | f64      | i64      | f64
"megalab_0000.mp4" | 0        | [49.625584, 308.454529, … 25.061401]   | 0.447875 | 0        | 0.0
"megalab_0000.mp4" | 1        | [111.323929, 160.203247, … 22.092972]  | 0.440812 | 0        | 0.033333
"megalab_0000.mp4" | 2        | [111.131042, 156.693939, … 23.211853]  | 0.505729 | 0        | 0.066667
"megalab_0000.mp4" | 2        | [45.196251, 304.073059, … 26.24408]    | 0.440835 | 0        | 0.066667
"megalab_0000.mp4" | 3        | [112.592758, 154.698364, … 22.159836]  | 0.559952 | 0        | 0.1
…                  | …        | …                                      | …        | …        | …
"megalab_1588.mp4" | 297      | [333.339081, 598.729156, … 170.905487] | 0.722037 | 1588     | 15889.9
"megalab_1588.mp4" | 298      | [330.434692, 592.485535, … 169.059097] | 0.750856 | 1588     | 15889.933333
"megalab_1588.mp4" | 298      | [324.720764, 244.884262, … 93.155594]  | 0.853287 | 1588     | 15889.933333
"megalab_1588.mp4" | 299      | [328.688354, 585.071289, … 168.657837] | 0.626334 | 1588     | 15889.966667
"megalab_1588.mp4" | 299      | [322.843872, 242.395752, … 97.742188]  | 0.849616 | 1588     | 15889.966667

So I thought - if I was a biologist studying fish in this area, what sort of analysis would I want to do on this data? Well I'd certainly like to know how many unique fish are in this footage, the density of the species, and perhaps notable behaviors like schooling or feeding.

To start to answer some of these questions we will have to extract some features from the fish detections and do some analysis. These features will help us cluster the fish into groups, and perhaps identify unique individuals.

To start, let's just peek at the data and see what we have been detecting so far.

In [ ]:
import cv2
import matplotlib.pyplot as plt

total = 0
for video_file, gdf in df.group_by("video_file"):
    video_file = video_file[0]

    capture = cv2.VideoCapture(f"../../../megalab_recordings/recordings/{video_file}")

    for frame_id, ggdf in gdf.group_by("frame_id"):
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_id[0])
        _, frame = capture.read()

        for row in ggdf.iter_rows(named=True):
            bbox = row["box"]

            x, y, w, h = bbox

            # inference was run on frames resized to 1024x1024,
            # so scale the bounding box back to the original frame size
            orig_h, orig_w, _ = frame.shape
            x = int(x * orig_w / 1024)
            y = int(y * orig_h / 1024)
            w = int(w * orig_w / 1024)
            h = int(h * orig_h / 1024)

            frame = cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)

        plt.figure(figsize=(10, 5))
        plt.title(f"Video: {video_file}, Frame: {frame_id[0]}, Detections: {len(ggdf)}")
        plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.show()

        break

    capture.release()
    
    total += 1
    if total >= 3:
        break
[Figures: three sample frames, one per video, with fish detection bounding boxes drawn]

Alright this is interesting! We can make a few observations from this initial exploration:

  • The camera is panning. This means that tracking fish across frames accurately will likely be more difficult, especially with ByteTracker, which relies on a Kalman filter with a constant-velocity motion model.
  • There is a wiper that periodically wipes the lens. This will certainly cause some missed detections and complications with tracking fish across frames.
  • Fish are being detected! But there are a lot of missed fish and also some false positives, like algae on the lens. Otherwise, we can clearly see there are multiple species of fish being detected.

Before we extract features, we know that we have over 2 million fish detections in this dataset. Since my laptop is not particularly powerful, we need to apply some initial filtering to reduce the dataset size to something more manageable for exploration.

Let's explore some of the statistics of the fish detections themselves. We can look at the distribution of bounding box sizes, aspect ratios, and confidence scores.

In [6]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

df = df.with_columns([
    pl.col("box").map_elements(lambda box: box[2] * box[3], return_dtype=pl.Float64).alias("area"),
    pl.col("box").map_elements(lambda box: box[2], return_dtype=pl.Float64).alias("width"),
    pl.col("box").map_elements(lambda box: box[3], return_dtype=pl.Float64).alias("height"),
])

median_conf = df.select(pl.col("score").median()).to_series()[0]
std_conf = df.select(pl.col("score").std()).to_series()[0]
median_area = df.select(pl.col("area").median()).to_series()[0]
std_area = df.select(pl.col("area").std()).to_series()[0]
med_width = df.select(pl.col("width").median()).to_series()[0]
std_width = df.select(pl.col("width").std()).to_series()[0]
med_height = df.select(pl.col("height").median()).to_series()[0]
std_height = df.select(pl.col("height").std()).to_series()[0]

plt.figure(figsize=(20, 8))
plt.subplot(2, 2, 1)

plt.hist(df.select("score").to_series(), bins=50, color='blue', alpha=0.7)
plt.axvline(median_conf, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(median_conf + std_conf, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(median_conf - std_conf, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title(f"Histogram of Confidence. Median: {median_conf:.2f}, Std: {std_conf:.2f}")
plt.xlabel("Confidence")
plt.ylabel("Frequency")

plt.subplot(2, 2, 2)
plt.hist(df.select("area").to_series(), bins=50, color='green', alpha=0.7)
plt.axvline(median_area, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(median_area + std_area, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(median_area - std_area, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title(f"Histogram of Box Areas. Median: {median_area:.2f}, Std: {std_area:.2f}")
plt.xlabel("Area (Width * Height)")
plt.ylabel("Frequency")

plt.subplot(2, 2, 3)
plt.hist(df.select("width").to_series(), bins=50, color='purple', alpha=0.7)
plt.axvline(med_width, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(med_width + std_width, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(med_width - std_width, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title(f"Histogram of Box Widths. Median: {med_width:.2f}, Std: {std_width:.2f}")
plt.xlabel("Width")
plt.ylabel("Frequency")

plt.subplot(2, 2, 4)
plt.hist(df.select("height").to_series(), bins=50, color='brown', alpha=0.7)
plt.axvline(med_height, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(med_height + std_height, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(med_height - std_height, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title(f"Histogram of Box Heights. Median: {med_height:.2f}, Std: {std_height:.2f}")
plt.xlabel("Height")
plt.ylabel("Frequency")

plt.tight_layout()
plt.show()
[Figure: histograms of detection confidence, box area, box width, and box height, with median ± 1 std marked]

Clearly we have some right skew in our data, with extremely long tails. We can safely filter out some of the extreme outliers in bounding box size to focus on the more common fish sizes. But we can also take a second to think about some static filters we can apply to remove some of the detections. We can see that there are some very small bounding boxes that, although probably valid fish detections, are likely too small to extract meaningful features from.

To start, let's try to be purely statistical about this. Drawing on knowledge of distributions, we know that the width and height distributions are both right-skewed, and will likely look closer to normal if we take the log of them. Then we can safely compute the median and standard deviation of the log-widths and log-heights, and filter out any bounding boxes that fall more than a certain number of standard deviations above or below the median.

Using some domain knowledge, we can also filter out any bounding boxes that are too small. Given the resolution, we don't expect to feasibly identify species or individuals from fish that are less than 32 pixels or so in width or height, so we can set that as a hard threshold. We could be more strict about this, but for now let's just use it as a starting point.

In [13]:
df = df.with_columns([
    pl.col("width").log().alias("log_width"),
    pl.col("height").log().alias("log_height"),
])

median_log_width = df.select(pl.col("log_width").median()).to_series()[0]
std_log_width = df.select(pl.col("log_width").std()).to_series()[0]
median_log_height = df.select(pl.col("log_height").median()).to_series()[0]
std_log_height = df.select(pl.col("log_height").std()).to_series()[0]

plt.figure(figsize=(20, 15))
plt.subplot(3, 2, 1)
plt.hist(df.select("log_width").to_series(), bins=50, color='purple', alpha=0.7)
plt.axvline(median_log_width, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(median_log_width + std_log_width, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(median_log_width - std_log_width, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title("Histogram of Log-Transformed Box Widths")
plt.xlabel("Log(Width)")
plt.ylabel("Frequency")

plt.subplot(3, 2, 2)
plt.hist(df.select("log_height").to_series(), bins=50, color='brown', alpha=0.7)
plt.axvline(median_log_height, color='red', linestyle='dashed', linewidth=1, label='Median')
plt.axvline(median_log_height + std_log_height, color='orange', linestyle='dashed', linewidth=1)
plt.axvline(median_log_height - std_log_height, color='orange', linestyle='dashed', linewidth=1)
plt.legend()
plt.title("Histogram of Log-Transformed Box Heights")
plt.xlabel("Log(Height)")
plt.ylabel("Frequency")

# filter out boxes more than 2 std from the median log-width/log-height, plus a 32 px minimum size
fdf = df.filter(
    (pl.col("log_width") >= (median_log_width - std_log_width*2)) &
    (pl.col("log_width") <= (median_log_width + std_log_width*2)) &
    (pl.col("log_height") >= (median_log_height - std_log_height*2)) &
    (pl.col("log_height") <= (median_log_height + std_log_height*2)) &
    (pl.col("width") >= 32) &
    (pl.col("height") >= 32)
)

print(f"Filtered out {df.height - fdf.height} / {((df.height - fdf.height) / df.height) * 100:.2f}% outliers")
print(f"Remaining detections: {fdf.height}")

# plot filtered distributions
plt.subplot(3, 2, 3)
plt.hist(fdf.select("width").to_series(), bins=50, color='purple', alpha=0.7)
plt.title("Filtered Histogram of Box Widths")
plt.xlabel("Width")
plt.ylabel("Frequency")

plt.subplot(3, 2, 4)
plt.hist(fdf.select("height").to_series(), bins=50, color='brown', alpha=0.7)
plt.title("Filtered Histogram of Box Heights")
plt.xlabel("Height")
plt.ylabel("Frequency")

# plot confidence and area distributions after filtering
plt.subplot(3, 2, 5)
plt.hist(fdf.select("score").to_series(), bins=50, color='blue', alpha=0.7)
plt.title("Filtered Histogram of Confidence")
plt.xlabel("Confidence")
plt.ylabel("Frequency")

plt.subplot(3, 2, 6)
plt.hist(fdf.select("area").to_series(), bins=50, color='green', alpha=0.7)
plt.title("Filtered Histogram of Box Areas")
plt.xlabel("Area")
plt.ylabel("Frequency")

plt.tight_layout()
plt.show()
Filtered out 1077047 / 52.07% outliers
Remaining detections: 991602
[Figure: log-transformed width/height histograms plus the filtered width, height, confidence, and area histograms]
In [14]:
fdf
Out[14]:
shape: (991_602, 11)
video_file         | frame_id | box                                    | score    | video_id | timestamp    | area         | width      | height     | log_width | log_height
str                | i64      | list[f64]                              | f64      | i64      | f64          | f64          | f64        | f64        | f64       | f64
"megalab_0000.mp4" | 11       | [170.581802, 162.269897, … 37.856995]  | 0.628561 | 0        | 0.366667     | 1227.526452  | 32.425354  | 37.856995  | 3.478941  | 3.633816
"megalab_0000.mp4" | 12       | [170.368607, 163.154022, … 37.43277]   | 0.608063 | 0        | 0.4          | 1241.068586  | 33.154602  | 37.43277   | 3.501182  | 3.622547
"megalab_0000.mp4" | 13       | [169.404541, 163.256226, … 37.132385]  | 0.599879 | 0        | 0.433333     | 1235.315859  | 33.267883  | 37.132385  | 3.504592  | 3.61449
"megalab_0000.mp4" | 14       | [169.029236, 163.517548, … 36.884888]  | 0.621387 | 0        | 0.466667     | 1243.674035  | 33.717712  | 36.884888  | 3.518023  | 3.607802
"megalab_0000.mp4" | 15       | [168.896088, 163.629883, … 36.281189]  | 0.631152 | 0        | 0.5          | 1216.057586  | 33.517578  | 36.281189  | 3.51207   | 3.591299
…                  | …        | …                                      | …        | …        | …            | …            | …          | …          | …         | …
"megalab_1588.mp4" | 295      | [670.556519, 686.834961, … 61.450317]  | 0.4827   | 1588     | 15889.833333 | 3838.912045  | 62.471802  | 61.450317  | 4.134715  | 4.118229
"megalab_1588.mp4" | 296      | [329.527618, 246.254089, … 97.992889]  | 0.855988 | 1588     | 15889.866667 | 10126.249896 | 103.336578 | 97.992889  | 4.637991  | 4.584895
"megalab_1588.mp4" | 297      | [327.502502, 245.356262, … 94.24176]   | 0.850686 | 1588     | 15889.9      | 9715.351083  | 103.089661 | 94.24176   | 4.635599  | 4.545863
"megalab_1588.mp4" | 298      | [324.720764, 244.884262, … 93.155594]  | 0.853287 | 1588     | 15889.933333 | 9772.12073   | 104.901062 | 93.155594  | 4.653018  | 4.534271
"megalab_1588.mp4" | 299      | [322.843872, 242.395752, … 97.742188]  | 0.849616 | 1588     | 15889.966667 | 10164.605843 | 103.994049 | 97.742188  | 4.644334  | 4.582333

Incredible! With those filters, a huge fraction of the outliers were removed. The distributions are still right-skewed, but their tails are certainly shorter.

Looking at the actual distributions, a huge number of detections are less than 64 pixels in size. Looking at the actual images, we see small black schooling fish in the background, which might be the culprit. These might be interesting when we run tracking over the detections, but they are probably not very useful for static image feature extraction. We will live with it and see what happens regardless.

Still, we have some 1 million fish detections even after filtering. This is STILL a lot of data to process. To reduce this down even further, we will run some stratified sampling on the data. We will stratify based on the video name to try and preserve detections across the duration of the footage.

In [16]:
fdf = fdf.group_by("video_file", maintain_order=True).map_groups(
    lambda group: group.sample(fraction=0.1, seed=42)
)

fdf.write_parquet("megalab_filtered_detections.parquet")

fdf
Out[16]:
shape: (98_454, 11)
video_file         | frame_id | box                                    | score    | video_id | timestamp    | area         | width      | height     | log_width | log_height
str                | i64      | list[f64]                              | f64      | i64      | f64          | f64          | f64        | f64        | f64       | f64
"megalab_0000.mp4" | 150      | [3.175149, 90.095612, … 48.685226]     | 0.799926 | 0        | 5.0          | 2462.767346  | 50.585517  | 48.685226  | 3.923665  | 3.885376
"megalab_0000.mp4" | 205      | [80.909744, 77.6297, … 50.73703]       | 0.787254 | 0        | 6.833333     | 2819.448614  | 55.569839  | 50.73703   | 4.017641  | 3.926656
"megalab_0000.mp4" | 180      | [235.98291, 355.891418, … 89.744873]   | 0.862818 | 0        | 6.0          | 8421.984475  | 93.843628  | 89.744873  | 4.54163   | 4.496971
"megalab_0000.mp4" | 172      | [202.524048, 375.623779, … 117.721046] | 0.882466 | 0        | 5.733333     | 11551.902197 | 98.129456  | 117.721046 | 4.586288  | 4.768318
"megalab_0000.mp4" | 144      | [91.562485, 49.712978, … 38.932774]    | 0.61799  | 0        | 4.8          | 1418.692543  | 36.439545  | 38.932774  | 3.595655  | 3.661836
…                  | …        | …                                      | …        | …        | …            | …            | …          | …          | …         | …
"megalab_1588.mp4" | 177      | [772.430511, 826.083618, … 54.022339]  | 0.507881 | 1588     | 15885.9      | 3987.642589  | 73.814697  | 54.022339  | 4.301558  | 3.989398
"megalab_1588.mp4" | 176      | [770.521362, 823.45816, … 54.20639]    | 0.414012 | 1588     | 15885.866667 | 3835.257619  | 70.752869  | 54.20639   | 4.259193  | 3.992799
"megalab_1588.mp4" | 64       | [416.594849, 797.126343, … 51.356079]  | 0.79985  | 1588     | 15882.133333 | 3023.835217  | 58.879791  | 51.356079  | 4.075498  | 3.938783
"megalab_1588.mp4" | 166      | [776.010132, 614.337189, … 66.82724]   | 0.842968 | 1588     | 15885.533333 | 6107.937948  | 91.398926  | 66.82724   | 4.515234  | 4.202111
"megalab_1588.mp4" | 174      | [830.628113, 615.067276, … 69.218124]  | 0.85362  | 1588     | 15885.8      | 6620.985671  | 95.653931  | 69.218124  | 4.560737  | 4.237263

Roughly 100,000 detections is a lot more manageable. Now we can start thinking about feature extraction!

We can start with something simple like an ImageNet-pretrained MobileNetV3 model with the final classification layer removed. This model is fast and efficient, and should be able to extract some useful features from the fish crops. What would have actually been interesting is to utilize the features from the object detection model itself. However, since it took me a week to process the footage, we will experiment with other networks for now.

A pretrained CNN is a good starting point, but there are a few problems we could run into that CNNs are not well suited to for this task.

  • Since these are bounding box crops, the background is still present and the CNN could learn to associate background features instead of fish features.
  • An ImageNet-pretrained model was trained specifically on ImageNet classes, which do not include the data we are trying to extract features for. If the goal is to extract features where a distance metric (like Euclidean distance or cosine similarity) between two feature descriptors is smaller when they come from the same type of fish, then this model might not be the best choice.
  • CNNs do not understand global context or capture relationships between different parts of the image. They simply apply convolutions to the input image and pass the result through a series of layers to produce a final output. Fish have specific shapes and colors that might be better captured with a different architecture like a Vision Transformer (ViT) or a model specifically trained on fish images.
  • Single image feature descriptors may not actually be enough to uniquely cluster fish species. Again, peering into our dataset, the same species of fish appear in wildly different poses, lighting conditions, occlusions, distances from the camera, and backgrounds. It may be beneficial to extract temporal features from a sequence of frames using an object tracker to get a better idea of the fish's movement and behavior. While we have enough data to accomplish this, that may be a project for another day!

To enhance feature extraction, we could do several things:

  • Center crop the fish bounding boxes to remove as much background as possible.
  • Use a model specifically trained on fish images, or fine-tune a pretrained model on a fish dataset.
  • Experiment with different architectures like Vision Transformers or models that capture global context better.
  • Add background removal or segmentation to isolate the fish from the background.

To save time, we will make a single pass and extract features using MobileNetV3-Large, ViT, and XCiT models, all from the timm library, to compare against later. I've compiled a module to handle the feature extraction in parallel using CUDA streams.
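
The megalab_clustering module itself isn't shown here, but a minimal single-model sketch of the idea looks something like the following. This is illustrative, not the exact code in the module: it assumes a recent timm (≥ 0.9), and the model name, preprocessing, and helper function are my own placeholders.

import timm
import torch
import numpy as np
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# num_classes=0 drops the classification head, so the model returns
# pooled feature vectors instead of class logits
model = timm.create_model("mobilenetv3_large_100", pretrained=True, num_classes=0)
model = model.eval().to(device)

# use the model's own default preprocessing (resize, crop, normalize)
cfg = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**cfg, is_training=False)

def extract_feature(crop_rgb: np.ndarray) -> np.ndarray:
    """Return a single pooled feature vector for one RGB fish crop."""
    tensor = transform(Image.fromarray(crop_rgb)).unsqueeze(0).to(device)
    with torch.no_grad():
        feature = model(tensor)  # shape (1, 1280) for MobileNetV3-Large
    return feature.squeeze(0).cpu().numpy()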

In [18]:
import polars as pl

fdf = pl.read_parquet("megalab_filtered_detections.parquet")
In [ ]:
import megalab_clustering

table_name = megalab_clustering.extract_features(fdf)

! ls -lh {table_name}
2026-01-19 07:31:53.935716: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-19 07:31:54.204261: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-19 07:31:55.373481: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 98454/98454 [2:14:41<00:00, 12.18it/s]  
total 674M
-rw-rw-r-- 1 jack jack  69M Jan 19 07:44 0-f4b09add-e608-4fe3-992e-c8916c89c112-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 07:58 1-e2fcee58-3ffc-47b5-94f2-38f48da98a09-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 08:12 2-f60eab9a-c21d-4254-970b-6e8d9342fd67-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 08:26 3-162e7b4b-2104-4ba7-a3d2-97809bf40bee-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 08:40 4-81f59260-9fea-49d2-8331-006609c28714-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 08:54 5-1d922747-0c50-4c92-8455-4171a3d764b5-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 09:08 6-01256b21-2ea0-4eb7-b6d8-55447cc5e1c6-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 09:21 7-5c1c68ff-0e2d-44f2-9af1-57602ad38de8-0.parquet
-rw-rw-r-- 1 jack jack  69M Jan 19 09:35 8-5ac2f5a1-2408-4513-83b5-f082c9e01919-0.parquet
-rw-rw-r-- 1 jack jack  56M Jan 19 09:46 9-c93d8a24-b37b-4c9c-84cc-b5c456520755-0.parquet
drwxrwxr-x 2 jack jack 4.0K Jan 19 09:46 _delta_log

That is quite a lot of data already! And this is just a small sample of the entire dataset that I will eventually process. As we can see, when we process the entire dataset we will have to move to distributed computing, like Ray, Dask, or Spark, and also distributed storage like Delta Lake or numerous Parquet files. Especially if we utilize a Vision Transformer!

But for now, we can work with this sample set and see what kind of insights we can get from it.

The first thing we should do is see how compressible these features are. If we can reduce the dimensionality of the features while retaining most of the information, that would be great for visualization and clustering. Let's try PCA for this.

In [19]:
rdf = pl.read_delta("megalab_features")
rdf
Out[19]:
shape: (98_454, 5)
video_file         | frame_id | xcit_feature                        | vit_feature                         | mnetv3_feature
str                | i64      | list[f64]                           | list[f64]                           | list[f64]
"megalab_0873.mp4" | 254      | [-1.370655, 0.251851, … -1.125934]  | [-1.973908, -6.005847, … -3.995991] | [2.291016, -0.32666, … -0.324219]
"megalab_0873.mp4" | 177      | [0.916819, -2.164785, … -0.959855]  | [-2.870665, -6.252252, … -0.819254] | [1.082031, -0.297852, … -0.297119]
"megalab_0873.mp4" | 96       | [-0.771788, -0.493543, … -0.633564] | [-0.463634, -6.275813, … -4.233441] | [2.027344, -0.27124, … -0.262207]
"megalab_0873.mp4" | 79       | [-0.446059, -2.936343, … -1.072671] | [-1.530079, -5.404996, … -0.751862] | [1.464844, -0.219238, … -0.270264]
"megalab_0873.mp4" | 4        | [-0.145698, -2.020963, … -1.313443] | [-3.347079, -6.476562, … 2.167943]  | [1.782227, -0.375, … -0.356445]
…                  | …        | …                                   | …                                   | …
"megalab_1294.mp4" | 92       | [0.137672, 0.819715, … -1.240862]   | [-1.632193, -4.162421, … 1.525466]  | [0.706055, -0.323242, … 0.129028]
"megalab_1294.mp4" | 170      | [0.196154, -1.0209, … 0.481639]     | [-1.646796, -2.136511, … 0.242005]  | [1.306641, -0.005623, … -0.134155]
"megalab_1294.mp4" | 177      | [-0.425422, 0.356714, … -1.511721]  | [-1.420421, -6.523434, … 2.322006]  | [1.390625, -0.37085, … 0.140137]
"megalab_1294.mp4" | 42       | [-0.126437, -1.722595, … -0.599533] | [-1.770346, -2.047084, … -3.004429] | [3.3125, -0.375, … -0.259277]
"megalab_1294.mp4" | 278      | [0.436486, -1.1367, … 0.239529]     | [-2.913258, -1.103637, … 2.995327]  | [1.012695, 0.250244, … 0.093201]

We first have to make sure that the features are scaled properly, as common dimensionality reduction techniques like PCA are sensitive to the scale of the features. We can use StandardScaler from sklearn to standardize the features before applying PCA.

In [20]:
from sklearn.preprocessing import StandardScaler
import numpy as np

for feature_col in ["mnetv3_feature", "vit_feature", "xcit_feature"]:
    features = np.vstack(rdf.select(feature_col).to_series().to_list())

    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    rdf = rdf.with_columns([
        pl.Series(feature_col, list(features))
    ])
In [21]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("whitegrid")

for feature_col in ["mnetv3_feature", "vit_feature", "xcit_feature"]:
    features = np.vstack(rdf.select(feature_col).to_series().to_list())

    pca = PCA(n_components=0.95)
    pca_features = pca.fit_transform(features)

    rdf = rdf.with_columns([
        pl.Series(f"pca_{feature_col}", list(pca_features))
    ])

    plt.figure(figsize=(5, 3))
    sns.lineplot(
        x=np.arange(1, len(pca.explained_variance_ratio_)+1),
        y=np.cumsum(pca.explained_variance_ratio_)
    )
    plt.title(f"PCA for {feature_col} | 95% variance | {pca.n_components_} components")
    plt.xlabel("Number of Components")
    plt.ylabel("Cumulative Explained Variance")
    plt.show()

    percent_reduction = (1 - pca.n_components_ / features.shape[1]) * 100
    print(f"Feature: {feature_col}")
    print(f"Original feature dimension: {features.shape[1]}")
    print(f"Reduced feature dimension: {pca.n_components_}")
    print(f"Percent reduction in feature dimension: {percent_reduction:.2f}%")
    print()
[Plot: cumulative explained variance vs. number of components for mnetv3_feature]
Feature: mnetv3_feature
Original feature dimension: 1280
Reduced feature dimension: 613
Percent reduction in feature dimension: 52.11%

[Plot: cumulative explained variance vs. number of components for vit_feature]
Feature: vit_feature
Original feature dimension: 384
Reduced feature dimension: 247
Percent reduction in feature dimension: 35.68%

[Plot: cumulative explained variance vs. number of components for xcit_feature]
Feature: xcit_feature
Original feature dimension: 384
Reduced feature dimension: 220
Percent reduction in feature dimension: 42.71%

Wow! That's a significant reduction in feature space. Across all of the model flavors, this supports the idea that a simpler feature extraction model might exist that can capture the essence of the fish detections.

Now that we have a suggested dimensionality, we can try some clustering algorithms on this data. KMeans is a good starting point, as it is very popular and efficient. We can try the elbow method to determine the optimal number of clusters. We will also calculate the silhouette score to evaluate the quality of the clusters. But, the silhouette score can be computationally expensive for large datasets, so we will, yet again, sample down to 10,000 points for this calculation.

In [23]:
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

def find_optimal_k(data, max_k):
    inertias = []
    scores = []
    for k in range(2, max_k + 1):
        kmeans = KMeans(n_clusters=k, random_state=42)
        kmeans = kmeans.fit(data)
        inertias.append(kmeans.inertia_)
        scores.append(silhouette_score(data, kmeans.labels_))
    return inertias, scores

for feature_col in ["pca_mnetv3_feature", "pca_vit_feature", "pca_xcit_feature"]:

    pca_features = np.vstack(rdf.select(feature_col).sample(10_000).to_series().to_list())
    inertias, scores = find_optimal_k(pca_features, max_k=40)

    fig, ax1 = plt.subplots(figsize=(10, 5))
    ax2 = ax1.twinx()
    ax1.plot(range(2, 41), inertias, 'b-', label='Inertia')
    ax2.plot(range(2, 41), scores, 'r-', label='Silhouette Score')
    ax1.set_xlabel('Number of Clusters (k)')
    ax1.set_ylabel('Inertia')
    ax2.set_ylabel('Silhouette Score')

    plt.title(f"KMeans Elbow Method on '{feature_col}' Features")
    plt.tight_layout()
    plt.show()
[Plots: KMeans inertia and silhouette score vs. number of clusters (k = 2–40) for pca_mnetv3_feature, pca_vit_feature, and pca_xcit_feature]

Hmmph. Interesting results. We don't necessarily see a clear elbow in the plots, and the silhouette scores are also quite low. The silhouette scores being low indicates that the clusters are not well defined and there is a lot of overlap between them. From this clustering, it's difficult to say what might be a good number of clusters to choose.

This is one of the challenges with KMeans as determining the optimal number of clusters is not straightforward and often requires domain knowledge or additional metrics. But this is also an early indication that the feature descriptors may not be distinct enough to separate the fish into clear clusters. We could try another clustering algorithm, like HDBSCAN, to see if we can get better results without having to specify the number of clusters. But in practice, I have found that if KMeans is struggling to find clear clusters, other algorithms often struggle as well.
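
For reference, a quick HDBSCAN pass over one of the PCA-reduced feature sets would look something like the sketch below. I haven't run this here; it assumes scikit-learn ≥ 1.3 (which ships sklearn.cluster.HDBSCAN), and min_cluster_size is an arbitrary starting point rather than a tuned value.

# Hedged sketch: HDBSCAN on the PCA-reduced ViT features
import numpy as np
from sklearn.cluster import HDBSCAN

features = np.vstack(rdf.select("pca_vit_feature").to_series().to_list())

hdb = HDBSCAN(min_cluster_size=50)
labels = hdb.fit_predict(features)

# HDBSCAN labels noise points as -1, so exclude them from the cluster count
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
n_noise = int((labels == -1).sum())
print(f"HDBSCAN found {n_clusters} clusters, {n_noise} points labeled as noise")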

What we can do, though, is try to visualize the clusters and their distributions in the feature space using a pairplot. We can also utilize a different dimensionality reduction technique like UMAP to visualize the clusters in 2D or 3D space. UMAP preserves non-linear structure in the data, so it might give us a better idea of how well separated the clusters are.

In [28]:
from umap import UMAP

for feature_col in ["mnetv3_feature", "vit_feature", "xcit_feature"]:
    features = np.vstack(rdf.select(feature_col).to_series().to_list())

    umap = UMAP(n_neighbors=15, n_components=4, metric='euclidean')
    umap_features = umap.fit_transform(features)

    rdf = rdf.with_columns([
        pl.Series(f"umap_{feature_col}", umap_features.tolist())
    ])

We could have used only 2 components, but by using more we retain more information and can inspect cluster separation across more dimensions.

To visualize a pairplot, we also want to assign cluster labels to each point. We will utilize KMeans with 4 clusters, since we saw some slight elbowing, and higher silhouette scores around that region.

In [23]:
for feature_col in ["pca_mnetv3_feature", "pca_vit_feature", "pca_xcit_feature"]:
    features = np.vstack(rdf.select(feature_col).to_series().to_list())

    kmeans = KMeans(n_clusters=4, random_state=42)
    cluster_labels = kmeans.fit_predict(features)
    
    rdf = rdf.with_columns([
        pl.Series(f"kmeans_{feature_col}", cluster_labels.tolist())
    ])
In [30]:
import pandas as pd

for feature_col in ["mnetv3_feature", "vit_feature", "xcit_feature"]:

    plot_samples = rdf\
        .select([f'umap_{feature_col}', f"kmeans_pca_{feature_col}"])\
        .to_pandas()

    plot_samples = plot_samples\
        .join(plot_samples[f'umap_{feature_col}'].apply(pd.Series).add_prefix(f"{feature_col}_"))\
        .drop(columns=[f'umap_{feature_col}'])

    sns.pairplot(
        plot_samples,
        hue=f"kmeans_pca_{feature_col}",
        palette="viridis",
        kind="hist",
        diag_kind="kde",
    )
    plt.suptitle(f"UMAP Components Pairplot of {feature_col} Features", y=1.02)
    plt.show()
[Pairplots: UMAP components of mnetv3_feature, vit_feature, and xcit_feature, colored by KMeans cluster label]

Well, this looks quite different than our elbow and silhouette analysis suggested. Visually, we can see a lot more separation between clusters in the pairplots. The distributions on the diagonal also show more separation between clusters.

There is clear separation between some clusters here, but not between all of them. Based on all the analysis so far:

  • We didn't find a clear elbow in the KMeans inertia plot.
  • The silhouette score is very low, regardless of whether there is a peak or not. It should be closer to 1 to indicate high separation between clusters.
  • The pairplots of the UMAP features show significant overlap between clusters.

And finally, the most important of all, 4 clusters is clearly not enough to capture the diversity of fish species that are truly present in this dataset.

Before we conclude, we have yet to even visualize what the clusters actually look like! From this, we can get a better idea of how KMeans is grouping the detections, and what kind of features the models are actually extracting.

In [ ]:
import cv2

def get_fish_crop(video_file, frame_id, box):
    capture = cv2.VideoCapture(f"../../../megalab_recordings/recordings/{video_file}")
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
    _, frame = capture.read()
    capture.release()
    x, y, w, h = box
    orig_h, orig_w, _ = frame.shape
    x = int(x * orig_w / 1024)
    y = int(y * orig_h / 1024)
    w = int(w * orig_w / 1024)
    h = int(h * orig_h / 1024)
    crop = frame[y:y+h, x:x+w]
    crop = cv2.cvtColor(crop, cv2.COLOR_BGR2RGB)

    # resize to max dimension 128
    max_dim = max(crop.shape[0], crop.shape[1])
    scale = 128 / max_dim
    new_w = int(crop.shape[1] * scale)
    new_h = int(crop.shape[0] * scale)
    crop = cv2.resize(crop, (new_w, new_h))

    return crop

for cluster_id in range(4):

    rows = 5
    cols = 10
    image = np.zeros((rows * 128, cols * 128, 3), dtype=np.uint8)

    cluster_df = rdf.filter(pl.col("kmeans_pca_vit_feature") == cluster_id)
    sample_rows = cluster_df.sample(n=min(rows*cols, cluster_df.height)).to_dicts()

    for i, row in enumerate(sample_rows):
        video_file = row["video_file"]
        frame_id = row["frame_id"]
        box = fdf.filter(
            (pl.col("video_file") == video_file) & (pl.col("frame_id") == frame_id)
        )\
        .select("box")\
        .to_series()\
        .to_list()[0]
        
        crop = get_fish_crop(video_file, frame_id, box)

        image[
            (i // cols) * 128 : (i // cols) * 128 + crop.shape[0],
            (i % cols) * 128 : (i % cols) * 128 + crop.shape[1],
        ] = crop
    
    plt.figure(figsize=(20, 10))
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Cluster {cluster_id} Sample Fish Crops")
    plt.show()
[Image grids: randomly sampled fish crops for clusters 0 through 3]

Aha! Look at that! Cluster 0 seems to have been picking up on a false positive: some strange algae growth on the lens. While there are some misclassifications, it's actually rather consistent.

Cluster 1 is very clear. Of course, there are some misclassifications, but a majority seem to be identifying a Yellow Tang fish. This is awesome! Looking closer, we can also see that a lot of the backgrounds are of the coral reef. This might also be a sign that the model is picking up on background features instead of fish features.

Cluster 2 is just a mess. From this, I can see many different species. There is no clear pattern here.

Finally, Cluster 3 looks like it is picking up more on background features again. Nearly all the images seem to be of the blue ocean background. Otherwise, there are several different species of fish present.

So it seems clear that our feature extraction and clustering pipeline did serve some purpose. But it didn't quite cluster on the type of features we had hoped for.

Conclusion and Next Steps¶

I spent a long time on this notebook, nearly a month of my off time from work, wrangling this data together and trying to extract meaningful features. While I didn't achieve my goal, I'm far from done. There are several steps I will be taking in the future to improve this analysis:

  • Experiment with different feature extraction models, especially Vision Transformers or models specifically trained on fish images. The model I am using to gather the detections, the Community Fish Detector, was already trained on a massive dataset. Will its features be useful here? Or will their cosine similarities be too similar (as this model is a binary classifier, after all)?

  • Experiment with different clustering algorithms like HDBSCAN or DBSCAN that do not require specifying the number of clusters upfront. These algorithms can also handle noise and outliers better than KMeans, but they are quite slow on large datasets. There are also deep learning methods like DeepDPM that learn feature representations and cluster assignments simultaneously.

  • Experiment with temporal feature extraction using an object tracker to gather sequences of fish images. This could help capture more information about the fish's movement and behavior, potentially even enhancing resolution of the fish images by using super resolution techniques that work across multiple frames.

  • Perform background removal or automatic segmentation of the detections to retrieve only the fish itself. All my detections include background, which appears to be causing the clustering to group based on background features instead of fish features. Can I solve this with some image processing techniques (there's a rough sketch after this list), or do I need to train a segmentation model?

  • Perform few-shot learning to do semi-supervised classification. Unfortunately, this isn't totally unsupervised and would require some labeled data, but could help improve the clustering results significantly. Fortunately, I know a few marine biologists who might be willing to help label some data!
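
On the background-removal point above, a rough, untested sketch of the pure image-processing route using OpenCV's GrabCut, initialized from a rectangle just inside each crop, might look like this. The margin and iteration count are guesses, not tuned values.

import cv2
import numpy as np

def remove_background(crop: np.ndarray, margin: float = 0.05) -> np.ndarray:
    """Zero out (probable) background pixels in a 3-channel uint8 fish crop."""
    h, w = crop.shape[:2]
    # assume the fish roughly fills the box; shrink the init rectangle a little
    # so the border pixels are treated as definite background
    rect = (int(w * margin), int(h * margin),
            int(w * (1 - 2 * margin)), int(h * (1 - 2 * margin)))

    mask = np.zeros((h, w), np.uint8)
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)

    cv2.grabCut(crop, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

    # keep definite and probable foreground pixels, zero everything else
    fg = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 1, 0).astype("uint8")
    return crop * fg[:, :, None]

# e.g. masked = remove_background(get_fish_crop(video_file, frame_id, box))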

So many avenues to explore, so little time to do it. But! I hope you enjoyed this little exploration into the underwater world.