Create surface-based mask using KNN and geodesic distance around MNI coordinates

Hi! Inspired by the insightful discussion in this previous post, I’ve adapted a script that uses scikit-learn’s k-nearest neighbors (KNN) to generate a binary mask within nilearn’s grey-matter (GM) mask. Each seed is given as an MNI coordinate, and cluster_size sets how many of the nearest in-mask voxels are assigned to each seed. The script snippets are provided below.

Now I’m seeking help with the following: how can I create a binary mask in which each KNN cluster extends at least 1 cm around its seed, measured as geodesic distance? The current script measures straight-line (Euclidean) distance between voxel centres in volume space. Since geodesic distance is required, I assume the mask has to be created in surface space, but I am already struggling with mapping MNI coordinates onto the surface.

Any insights and guidance are highly appreciated!

# Binary mask delimited to GM created via scikit-learn nearest-neighbors
################################################################################

## Function: Get in-mask coordinates ###########################################
def _get_mask_coords(mask_img):
    '''Get world (affine-mapped) coordinates of every in-mask voxel.

    Parameters
    ----------
    mask_img : Niimg-like
        Mask image; anything ``nl.masking._load_mask_img`` accepts.

    Returns
    -------
    numpy.ndarray, shape (n_voxels, 3)
        One row of (x, y, z) world coordinates per in-mask voxel, obtained by
        mapping the voxel indices through the mask's affine.
    '''
    # NOTE(review): _load_mask_img is a private nilearn helper; its name or
    # behaviour may change between nilearn releases.
    mask, affine = nl.masking._load_mask_img(mask_img)

    # Voxel indices of all '1' voxels, one array per axis. np.nonzero walks the
    # array in C order, the same order as boolean indexing ``data[mask]``.
    # Presumably this matches the feature order of a NiftiMasker built on the
    # same mask, so row n here lines up with feature n there. TODO confirm.
    i, j, k = np.nonzero(mask)

    # Report only the count. Printing every index triple of a whole-brain mask
    # is slow and floods the output.
    print('Number of in-mask voxels :', i.size)

    # Map voxel indices to world coordinates through the mask affine.
    x, y, z = nl.image.coord_transform(i, j, k, affine)

    return np.column_stack((x, y, z))

mask_coords = _get_mask_coords(gm_mask)

## Loop for nearest-neighbors ##################################################
# NOTE: NearestNeighbors uses its default Minkowski p=2 metric, i.e. Euclidean
# distance between voxel centres in volume space, not geodesic distance on the
# cortical surface.

# The KNN index and the masker depend only on gm_mask, so build them once
# instead of once per seed / per cluster size.
nn_search = skl.neighbors.NearestNeighbors().fit(mask_coords)
masker = nl.maskers.NiftiMasker(mask_img=gm_mask)
masker.fit()

# One row of (x, y, z) MNI coordinates per seed.
seed_coords = np.asarray(seeds, dtype=float).reshape(-1, 3)
if seed_coords.shape[0] == 0:
    raise ValueError('No seeds given: at least one MNI coordinate is required.')

for cluster_size in cluster_sizes:
    # Sparse (n_seeds, n_voxels) connectivity matrix: each row has
    # `cluster_size` ones marking that seed's nearest in-mask voxels. Summing
    # over seeds counts how many clusters claim each voxel, so overlapping
    # clusters give values > 1. The sparse sum replaces the `.A` densify,
    # which SciPy >= 1.14 no longer provides.
    neighbors_graph = nn_search.kneighbors_graph(
        seed_coords, n_neighbors=cluster_size, mode='connectivity')
    cluster_arrays_summed = np.asarray(neighbors_graph.sum(axis=0)).ravel()

    # inverse-transform cluster arrays (feature order follows gm_mask)
    clusters_mask_img = masker.inverse_transform(cluster_arrays_summed)

    # Show results
    nl.plotting.plot_roi(clusters_mask_img, title=f'Mask with clusters of {cluster_size} voxels')
    clusters_data = clusters_mask_img.get_fdata()
    print(f'Cluster {cluster_size} voxels, non-zero voxels', np.count_nonzero(clusters_data))
    print(f'Cluster {cluster_size} voxels, unique values', np.unique(clusters_data))

    # Binarize image (overlap counts > 1 collapse to 1)
    clusters_mask_img_bin = nl.image.binarize_img(clusters_mask_img)

    # Save volume. Use the binarized image's own affine: its voxel grid is
    # gm_mask's, so pairing this data with another image's affine would
    # misplace it whenever template and gm_mask grids differ.
    img2save = nib.Nifti1Image(clusters_mask_img_bin.get_fdata(),
                               clusters_mask_img_bin.affine,
                               template.header)
    nib.save(img2save, os.path.join(basedir, f'mask_clusters{cluster_size}vox.nii.gz'))