5  Classifying wings

In this quick lesson, we try to classify wings using the tools seen in the previous lessons.

using MetricSpaces
using Images
using DataFrames
using GLMakie
using Ripserer, PersistenceDiagrams
import Plots
using ProgressMeter

using Clustering
import StatsPlots

using MultivariateStats
using Chain
using ImageFiltering

We prepare a dataframe with the files and classes of each image

# One row per file found under "wings/". The class of a file is the name of
# the folder that contains it.
ds = DataFrame();

for (folder, _, filenames) in walkdir("wings/")
    classe = basename(folder)

    for filename in filenames
        row = Dict(
            :Classe => classe,
            :Caminho => filename,
            :Caminho_completo => joinpath(folder, filename),
        )

        # `cols = :union` lets the first push create the columns of the empty frame
        push!(ds, row, cols = :union)
    end
end

ds;
ds_split = groupby(ds, :Classe) |> collect;
"""
    plot_mosaic(s)

Show a mosaic of the images listed in `s.Caminho_completo`.

At most the first 21 paths are used. Each image is loaded from disk, resized to
`(150, 300)` and placed in a 3-column grid; empty cells are filled with
`RGB24(1)`.
"""
function plot_mosaic(s)
    paths = s.Caminho_completo[1:min(end, 21)]
    thumbnails = [imresize(load(path), (150, 300)) for path in paths]

    return mosaicview(thumbnails, ncol = 3, fillvalue = RGB24(1))
end;

f = ds.Caminho_completo[1]
"wings/Asilidae/Asilidae 11.png"

5.1 The dataset

The dataset consists of several wing images from 3 different families of insects:

5.1.1 Asilidae

plot_mosaic(ds_split[1])

5.1.2 Ceratopogonidae

plot_mosaic(ds_split[2])

5.1.3 Tipulidae

plot_mosaic(ds_split[3])

We load all images as matrices

# Load every file, convert its pixels to `Gray` and take the `channelview`
# so that each image can be handled as a matrix.
images = [load(img) .|> Gray |> channelview for img in ds.Caminho_completo];

We can see that the image is indeed correct:

images[1] |> image

5.2 Matrix to \(\mathbb{R}^2\)

As before, we need to transform each image into a set of points in the plane.

"""
    img_to_points(img)

Convert an image matrix into a point cloud.

The image is filtered with `Kernel.gaussian(1)`, and every position whose
filtered value is `<= 0.8` is kept. Each kept position becomes the two-element
vector `[first index, second index]`; the collection is wrapped in an
`EuclideanSpace`.
"""
function img_to_points(img)
    blurred = float.(imfilter(img, Kernel.gaussian(1)))
    dark = findall(<=(0.8), blurred)

    return EuclideanSpace([[idx[1], idx[2]] for idx in dark])
end;

We convert each image to points

pts = img_to_points.(images);

and normalize the coordinates, since each image has a different size:

"""
    normalize!(pts)

Return the points of `pts` divided by the spread (maximum minus minimum) of
their last coordinates.

Despite the `!` suffix, `pts` itself is not modified: the scaled points are
returned as a new collection.
"""
function normalize!(pts)
    lo, hi = extrema(map(last, pts))
    span = hi - lo

    return pts ./ span
end

wings = normalize!.(pts);

We can plot a scatter to check that it is indeed ok:

scatter(wings[1])

In order to apply the Vietoris-Rips filtration, we need to reduce the number of points in each wing. Farthest point sampling comes to our rescue again!

# Keep only the 400 points of each wing selected by `farthest_points_sample`
wings_short = @showprogress map(wings) do wing
    wing[farthest_points_sample(wing, 400)]
end;

Now we calculate each barcode using the Vietoris-Rips filtration:

# One list of persistence diagrams per subsampled wing.
# NOTE(review): `cutoff` presumably discards intervals shorter than 0.008 —
# confirm against the Ripserer documentation.
pds = @showprogress map(wings_short) do wing
    ripserer(wing; cutoff = 0.008)
end
24-element Vector{Vector{PersistenceDiagram}}:
 [400-element 0-dimensional PersistenceDiagram, 38-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 28-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 39-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 30-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 37-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 35-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 30-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 29-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 10-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 11-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 14-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 18-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 23-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 14-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 25-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 19-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 18-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 19-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 25-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 27-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 22-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 22-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 26-element 1-dimensional PersistenceDiagram]
 [400-element 0-dimensional PersistenceDiagram, 23-element 1-dimensional PersistenceDiagram]

We can now see the metric space

scatter(wings_short[1])

and the corresponding 1-dimensional persistence diagram

Plots.plot(pds[1][2])

Now we calculate the pairwise 1-dimensional bottleneck distance between each wing:

"""
    barcode_to_distance(pds)

Build the symmetric matrix of pairwise bottleneck distances between the
diagrams stored at position 2 of every element of `pds` (the 1-dimensional
diagram in the output of `ripserer`).

Returns an `n × n` `Matrix{Float64}` with zeros on the diagonal, where
`n = length(pds)`.
"""
function barcode_to_distance(pds)
    n = length(pds)
    DB = zeros(n, n)          # the diagonal is already 0
    bottleneck = Bottleneck() # build the distance object once, not once per pair

    @showprogress for i in 1:n
        # Only the upper triangle is computed; the value is mirrored below.
        for j in (i + 1):n
            d = bottleneck(pds[i][2], pds[j][2])
            DB[i, j] = d
            DB[j, i] = d
        end
    end

    return DB
end
barcode_to_distance (generic function with 1 method)
DB = barcode_to_distance(pds)
24×24 Matrix{Float64}:
 0.0         0.0158869  0.0135061  …  0.0462039   0.0469204   0.0283849
 0.0158869   0.0        0.025868      0.0585658   0.0592823   0.0407468
 0.0135061   0.025868   0.0           0.0326978   0.0334143   0.020432
 0.00953908  0.0139792  0.0218124     0.0545102   0.0552267   0.0366912
 0.0235501   0.035912   0.0142476     0.0226538   0.0233703   0.0116482
 0.0230513   0.0182236  0.021989   …  0.0481824   0.0488989   0.0337146
 0.0487406   0.0611025  0.0352345     0.0121581   0.0141773   0.0203557
 0.014005    0.0263669  0.0123086     0.0321989   0.0329154   0.0143799
 0.0245204   0.0388475  0.0299852     0.0494071   0.0501236   0.0338243
 0.0341926   0.0285978  0.032716      0.0563412   0.05317     0.0411305
 0.0189494   0.0260279  0.0276124  …  0.0603102   0.0610267   0.0424912
 0.0276814   0.0249402  0.0411875     0.0738853   0.0746018   0.0560664
 0.0305476   0.0283709  0.0314908     0.0641885   0.064905    0.0463696
 0.02062     0.030228   0.023998      0.0552553   0.0559719   0.0374364
 0.0250578   0.0200642  0.0235812     0.0472065   0.0440353   0.0343454
 0.0683335   0.0559717  0.0818397  …  0.114537    0.115254    0.0967185
 0.0334725   0.0458344  0.0204434     0.0127314   0.0134479   0.0178701
 0.0211525   0.0335143  0.0116835     0.0284245   0.0312107   0.0190844
 0.0288523   0.0412142  0.0153462     0.0173516   0.0180681   0.0145953
 0.0518992   0.0642611  0.0383931     0.0139367   0.0114259   0.0235143
 0.0395003   0.0518622  0.0259942  …  0.0153055   0.0159943   0.0122653
 0.0462039   0.0585658  0.0326978     0.0         0.00844394  0.0213066
 0.0469204   0.0592823  0.0334143     0.00844394  0.0         0.0185355
 0.0283849   0.0407468  0.020432      0.0213066   0.0185355   0.0

and see if the classes are well separated:

"""
    mds_plot(D)

Embed the distance matrix `D` in two dimensions with `MDS` and draw one
scatter group per class of the global data frame `ds`.

Column `k` of the embedding is matched to row `k` of `ds`, so `D` must be
ordered like the rows of `ds`.

Side effect: the column `ds.Row` is (re)written with `1:nrow(ds)`.

The palette is built with 8 categorical colours, so at most 8 classes can be
drawn.

Returns the `Figure`.
"""
function mds_plot(D)
    M = fit(MDS, D; distances = true, maxoutdim = 2)
    Y = predict(M)

    # Store each image's position so the grouped sub-frames can pick their
    # own columns of Y.
    ds.Row = 1:nrow(ds)
    dfs = collect(groupby(ds, :Classe))

    fig = Figure()
    ax = Makie.Axis(fig[1, 1])

    colors = cgrad(:tableau_10, 8, categorical = true)

    for (i, df) in enumerate(dfs)
        scatter!(
            ax, Y[:, df.Row];
            label = df.Classe[1], markersize = 15, color = colors[i],
        )
    end

    # Attach the legend explicitly to the axis that was drawn on
    axislegend(ax)

    return fig
end;
mds_plot(DB)

5.3 Slicing it sideways

As we did with the hand-written digits dataset, we can do some sideways slicing on the wings.

# Replace `x` by `value` when `x` is below 0.99; otherwise return `x` unchanged.
function set_value(x, value)
    if x < 0.99
        return value
    end

    return x
end

"""
    side_filtration(img, axis = 1, invert = false)

Turn `img` into a matrix whose values increase or decrease linearly along one
axis.

Steps:
1. `img` is resized to `(100, 200)` and filtered with `Kernel.gaussian(0.4)`.
2. Every value below 0.99 is set to 0 (see `set_value`).
3. The extent `a:b` of the points returned by `img_to_points` is measured along
   the first index (`axis == 1`) or the last index (any other `axis`).
4. For each index `i` in `a:b`, the values below 0.99 in that row/column are
   replaced by `(b - i) / (b - a)`, or by one minus that when `invert == true`.

Returns the matrix with `float` applied to every element.

Note: when `a == b` the ramp value is `0 / 0`, i.e. `NaN`.
"""
function side_filtration(img, axis = 1, invert = false)
    img2 = imresize(img, (100, 200))
    m = imfilter(img2, Kernel.gaussian(0.4))
    m = set_value.(m, 0)

    pts = img_to_points(m)

    # Range of indices covered by the points along the chosen axis
    a, b = if axis == 1
        extrema(pts .|> first)
    else
        extrema(pts .|> last)
    end

    for i in a:b
        # Linear ramp: 1 at index a, 0 at index b
        v = (b - i) / (b - a)

        if invert == true
            v = 1.0 - v
        end

        if axis == 1
            m[i, :] = set_value.(m[i, :], v)
        else
            m[:, i] = set_value.(m[:, i], v)
        end
    end

    return m .|> float
end;

We can visualize the filtrations as follows:

img = images[5]
img2 = side_filtration(img, 1)
heatmap(img2)

img2 = side_filtration(img, 2)
heatmap(img2)

img2 = side_filtration(img, 1, true)
heatmap(img2)

img2 = side_filtration(img, 2, true)
heatmap(img2)

And calculate each barcode:

# Barcodes of every image under one sideways filtration. `axis` and `invert`
# are forwarded unchanged to `side_filtration`; the filtered matrix is wrapped
# in `Cubical` and passed to `ripserer` with `cutoff = 0.1`.
function side_barcodes(imgs, axis = 1, invert = false)
    barcodes = @showprogress map(imgs) do img
        ripserer(Cubical(side_filtration(img, axis, invert)), cutoff = 0.1)
    end

    return barcodes
end

pds_x = side_barcodes(images)

pds_y = side_barcodes(images, 2)

pds_x2 = side_barcodes(images, 1, true)

pds_y2 = side_barcodes(images, 2, true)
24-element Vector{Vector{PersistenceDiagram}}:
 [6-element 0-dimensional PersistenceDiagram, 17-element 1-dimensional PersistenceDiagram]
 [3-element 0-dimensional PersistenceDiagram, 21-element 1-dimensional PersistenceDiagram]
 [3-element 0-dimensional PersistenceDiagram, 13-element 1-dimensional PersistenceDiagram]
 [2-element 0-dimensional PersistenceDiagram, 20-element 1-dimensional PersistenceDiagram]
 [2-element 0-dimensional PersistenceDiagram, 16-element 1-dimensional PersistenceDiagram]
 [4-element 0-dimensional PersistenceDiagram, 18-element 1-dimensional PersistenceDiagram]
 [3-element 0-dimensional PersistenceDiagram, 21-element 1-dimensional PersistenceDiagram]
 [5-element 0-dimensional PersistenceDiagram, 21-element 1-dimensional PersistenceDiagram]
 [10-element 0-dimensional PersistenceDiagram, 40-element 1-dimensional PersistenceDiagram]
 [4-element 0-dimensional PersistenceDiagram, 17-element 1-dimensional PersistenceDiagram]
 [21-element 0-dimensional PersistenceDiagram, 18-element 1-dimensional PersistenceDiagram]
 [7-element 0-dimensional PersistenceDiagram, 15-element 1-dimensional PersistenceDiagram]
 [22-element 0-dimensional PersistenceDiagram, 21-element 1-dimensional PersistenceDiagram]
 [39-element 0-dimensional PersistenceDiagram, 12-element 1-dimensional PersistenceDiagram]
 [8-element 0-dimensional PersistenceDiagram, 15-element 1-dimensional PersistenceDiagram]
 [19-element 0-dimensional PersistenceDiagram, 13-element 1-dimensional PersistenceDiagram]
 [26-element 0-dimensional PersistenceDiagram, 10-element 1-dimensional PersistenceDiagram]
 [36-element 0-dimensional PersistenceDiagram, 11-element 1-dimensional PersistenceDiagram]
 [35-element 0-dimensional PersistenceDiagram, 6-element 1-dimensional PersistenceDiagram]
 [23-element 0-dimensional PersistenceDiagram, 18-element 1-dimensional PersistenceDiagram]
 [38-element 0-dimensional PersistenceDiagram, 14-element 1-dimensional PersistenceDiagram]
 [29-element 0-dimensional PersistenceDiagram, 9-element 1-dimensional PersistenceDiagram]
 [26-element 0-dimensional PersistenceDiagram, 5-element 1-dimensional PersistenceDiagram]
 [17-element 0-dimensional PersistenceDiagram, 16-element 1-dimensional PersistenceDiagram]
barcode(pds_x[5])
barcode(pds_y[5])
barcode(pds_x2[5])
barcode(pds_y2[5])

Let’s see why some figures have so many generators in dimension 1:

img = images[9]
img2 = side_filtration(img, 2, true)
heatmap(img2)

barcode(pds_y2[9])

The respective distance matrices are obtained with

DB_x = barcode_to_distance(pds_x)
DB_y = barcode_to_distance(pds_y)
DB_x2 = barcode_to_distance(pds_x2)
DB_y2 = barcode_to_distance(pds_y2)
24×24 Matrix{Float64}:
 0.0       0.325713  0.298343  0.164898  …  0.502762  0.612723  0.289092
 0.325713  0.0       0.380486  0.30512      0.599516  0.850059  0.316092
 0.298343  0.380486  0.0       0.227273     0.472488  0.635195  0.289641
 0.164898  0.30512   0.227273  0.0          0.465906  0.58974   0.278277
 0.269424  0.324795  0.28      0.248831     0.523943  0.791427  0.282226
 0.236258  0.381064  0.149272  0.133401  …  0.490336  0.653043  0.376954
 0.337017  0.247476  0.491706  0.293634     0.552486  0.813468  0.364641
 0.357143  0.214901  0.541667  0.326569     0.592159  0.773393  0.357143
 0.630172  0.580911  0.675477  0.602126     0.734369  0.768504  0.640938
 0.216896  0.497411  0.243819  0.196304     0.441745  0.735056  0.460388
 0.30455   0.252874  0.535287  0.324666  …  0.612982  0.82798   0.340557
 0.18034   0.333329  0.398941  0.252595     0.60662   0.704994  0.186046
 0.386805  0.312664  0.62529   0.451256     0.800758  0.85196   0.436765
 0.337017  0.482759  0.102273  0.232538     0.505535  0.610378  0.313947
 0.243924  0.333329  0.468204  0.269936     0.699314  0.762496  0.188937
 0.310581  0.362069  0.340765  0.289988  …  0.594937  0.835443  0.284884
 0.414365  0.573856  0.448857  0.363636     0.463877  0.645959  0.589754
 0.370159  0.482759  0.357955  0.329545     0.756637  0.873414  0.456036
 0.558401  0.835113  0.573447  0.545455     0.484143  0.65033   0.79324
 0.387342  0.317418  0.679487  0.491404     0.772375  0.820506  0.410405
 0.318275  0.344828  0.25791   0.297683  …  0.633987  0.841502  0.238372
 0.502762  0.599516  0.472488  0.465906     0.0       0.561036  0.744185
 0.612723  0.850059  0.635195  0.58974      0.561036  0.0       0.804219
 0.289092  0.316092  0.289641  0.278277     0.744185  0.804219  0.0

And we can see that none of the tools we used before can separate the classes well:

mds_plot(DB)

mds_plot(DB_x)

mds_plot(DB_y)

mds_plot(DB_x2)

mds_plot(DB_y2)

Even if we sum all these distances, we still can’t correctly cluster any class:

# Add up the five distance matrices, each divided by its own maximum so that
# they all contribute on a comparable scale.
# Written as a single `sum`: reassigning a global inside a top-level `for`
# only works under the REPL/notebook soft-scope rules.
DB_final = sum(d ./ maximum(d) for d in [DB, DB_x, DB_y, DB_x2, DB_y2])
mds_plot(DB_final)