I’m not an expert, but I recently went through this for another project. I first converted the images and labels into a set of TFRecords. The format is tricky to work with, but it’s (apparently) optimized for loading with tf.data.Dataset. The script below shows an example of getting this to work.
Note that the network I was using worked with 3D images, so the batches of data had shape (batch_size, x_dim, y_dim, z_dim, 1), rather than what is typical for 2D images: (batch_size, x_dim, y_dim, n_channels).
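For example (a sketch; the 182×218×182 dimensions are just an assumption roughly matching a 1 mm MNI template, substitute your own), a single volume picks up a trailing channel axis before batching:

import numpy as np
vol = np.zeros((182, 218, 182), dtype=np.float32)  # stand-in for one loaded volume
vol_with_channel = vol[..., np.newaxis]             # shape (182, 218, 182, 1)
# stacking a batch of these gives (batch_size, 182, 218, 182, 1)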
Also, preprocessing images on the fly with functions like tf.image.per_image_standardization can slow things down substantially. In the example below, preprocessing is instead done when the TFRecord files are written, inside the function serialize_example.
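If you did want the standardization instead, the swap inside serialize_example might look roughly like this (a sketch, assuming nii0 is the NumPy array loaded there):

# instead of the min-max rescaling
nii = tf.image.per_image_standardization(
    nii0.astype(np.float32)  # zero mean, unit variance over the whole volume
)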
Sorry to dump a pile of code. Follow up if you have any questions.
import math
from numbers import Number
import numpy as np
import tensorflow as tf
import nibabel as nib
from nilearn import plotting

def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def serialize_example(img: str, label: Number) -> tf.train.Example:
    # convert to numpy array
    nii0 = np.asanyarray(nib.load(img).dataobj)
    # rescale to 0-1 and cast to float32 (the dtype expected when parsing)
    # if you want to use tf.image.per_image_standardization(), this would be the place
    # to do so instead of the rescaling
    nii = ((nii0 - nii0.min()) / (nii0.max() - nii0.min())).astype(np.float32)
    feature = {
        'label': _float_feature(label),
        'image_raw': _bytes_feature(tf.io.serialize_tensor(nii).numpy())
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

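# optional sanity check (uses the same example file as below): confirm that a
# serialized volume parses back to a float32 tensor
# ex = serialize_example(img='MNI152_T1_1mm_brain.nii.gz', label=0.0)
# raw = ex.features.feature['image_raw'].bytes_list.value[0]
# restored = tf.io.parse_tensor(raw, out_type=tf.float32)
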
def write_records(niis, labels, n_per_record: int, outfile: str) -> None:
    """
    Store a list of niftis (and their associated labels) in tfrecords for use as a dataset.
    """
    n_niis = len(niis)
    n_records = math.ceil(n_niis / n_per_record)
    for i, shard in enumerate(range(0, n_niis, n_per_record)):
        print(f"writing record {i} of {n_records-1}")
        with tf.io.TFRecordWriter(
            f"{outfile}_{i:0>3}-of-{n_records-1:0>3}.tfrecords",
            options=tf.io.TFRecordOptions(compression_type="GZIP")
        ) as writer:
            for nii, label in zip(niis[shard:shard+n_per_record], labels[shard:shard+n_per_record]):
                example = serialize_example(img=nii, label=label)
                writer.write(example.SerializeToString())

def parse_1_example(example) -> tuple:
    # deserialize the image and add a trailing channel dimension
    X = tf.io.parse_tensor(example['image_raw'], out_type=tf.float32)
    return tf.expand_dims(X, 3), example['label']

def decode_example(record_bytes) -> dict:
    example = tf.io.parse_example(
        record_bytes,
        features={
            'label': tf.io.FixedLenFeature([], dtype=tf.float32),
            'image_raw': tf.io.FixedLenFeature([], dtype=tf.string)
        }
    )
    return example

def get_batched_dataset(files, batch_size: int = 32, shuffle_size: int = 1024) -> tf.data.Dataset:
    dataset = (
        tf.data.Dataset.list_files(files)  # note shuffling of filenames is on by default
        .flat_map(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP", num_parallel_reads=8))
        .map(decode_example, num_parallel_calls=tf.data.AUTOTUNE)
        .map(parse_1_example, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()  # remove if all examples don't fit in memory (note interaction with shuffling of files, above)
        .shuffle(shuffle_size)
        .batch(batch_size, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataset

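# note: with flat_map the record files are read one at a time; with many record files,
# one possible variant (a sketch, not required) is to overlap reads across files:
#   tf.data.Dataset.list_files(files).interleave(
#       lambda f: tf.data.TFRecordDataset(f, compression_type="GZIP"),
#       cycle_length=4,                       # number of files read concurrently
#       num_parallel_calls=tf.data.AUTOTUNE,
#   )
# followed by the same map/cache/shuffle/batch/prefetch chain as above
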
# store an example set of images with labels 0-9
# (e.g., put the nifti MNI152_T1_1mm_brain.nii.gz in the working directory)
mni_nii = ['MNI152_T1_1mm_brain.nii.gz']*10

# the number of examples stored in each tfrecord is configurable;
# aim for enough examples per record that each file ends up larger than ~100 MB
write_records(mni_nii, list(range(10)), 10, "tmp")

# read the records back. this will be the list of files generated by write_records()
# a full dataset will have a list with many records
list_of_records = ['tmp_000-of-000.tfrecords']
ds = get_batched_dataset(list_of_records, batch_size=2, shuffle_size=10)
# ds can now be passed to model.fit
# but first!!!!!
# the serialization involves quite a few steps, so it is a good idea to verify that
# the images look okay when loaded back
(Xs, Ys) = next(ds.as_numpy_iterator())
# (batch_size, )
# order will depend on shuffle (turn off all shuffling to verify order)
Ys.shape
# (batch_size, x_dim, y_dim, z_dim, 1)
Xs.shape
# convert the first element in the batch to an in-memory nibabel Nifti1Image for display
nii = nib.Nifti1Image(np.squeeze(Xs[0,]), affine=np.eye(4)*2)
# this should look like a typical brain
plotting.plot_anat(nii)
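# note: when running this as a plain script (rather than in a notebook), you may also
# need plotting.show() to actually display the figure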