Source code for now_2023.datasets._mri

import torch
import pandas as pd
from pathlib import Path

from torch.utils.data import Dataset, DataLoader, sampler
from typing import Optional, Callable
from dataclasses import dataclass


@dataclass
class Sample:
    """Class representing a sample from an MRIDataset.

    Attributes
    ----------
    image : torch.Tensor
        Image tensor loaded from disk, after the dataset's transform (if any).
    label : int
        Encoded diagnosis; MRIDataset maps "AD" to 1 and "CN" to 0.
    participant_id : str
        Participant identifier taken from the metadata table.
    session_id : str
        Session identifier taken from the metadata table.
    hemi : str
        Value of the metadata `hemi` column (presumably the hemisphere);
        it is part of the image filename.
    """

    image: torch.Tensor
    # MRIDataset stores the integer code from its label_code mapping here,
    # not the raw diagnosis string.
    label: int
    participant_id: str
    session_id: str
    hemi: str
class MRIDataset(Dataset):
    """Dataset of preprocessed MRI tensors stored in a CAPS directory tree.

    Each item is read from disk on demand (one ``.pt`` file per row of
    ``data_df``) and returned as a ``dict`` with the fields of ``Sample``:
    ``image``, ``label``, ``participant_id``, ``session_id`` and ``hemi``.
    """

    def __init__(
        self,
        img_dir: str,
        data_df: pd.DataFrame,
        transform: Optional[Callable] = None,
    ):
        """
        Parameters
        ----------
        img_dir : str
            Path to the CAPS directory containing preprocessed images.
        data_df : DataFrame
            Metadata of the population, one row per sample. Must contain the
            columns `participant_id`, `session_id`, `hemi` and `diagnosis`;
            `diagnosis` must be "AD" or "CN". Rows are addressed by position,
            so the frame's index does not need to be 0..n-1.
        transform : Callable, optional
            List of transforms applied on-the-fly, chained with
            torchvision.transforms.Compose. To use ``train()``/``eval()``,
            the transform must itself expose ``train()``/``eval()`` methods.

        Notes
        -----
        The first sample is loaded eagerly to record its image shape in
        ``self.size``. Construction therefore fails if ``data_df`` is empty
        or the first image file cannot be loaded.
        """
        self.img_dir: Path = Path(img_dir)
        self.transform = transform
        self.data_df = data_df
        # Diagnosis -> integer target; any other diagnosis raises KeyError
        # when the sample is loaded.
        self.label_code = {"AD": 1, "CN": 0}
        # Shape of the first sample's (transformed) image.
        self.size = self[0]["image"].shape

    def __len__(self) -> int:
        """Return the number of rows in ``data_df``."""
        return len(self.data_df)

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at position ``idx`` as a dict of Sample fields."""
        # __dict__ is a shallow view of a freshly built Sample, so the image
        # tensor is not copied.
        return self._get_sample(idx).__dict__

    def _get_sample(self, idx: int) -> "Sample":
        """Load, optionally transform, and wrap the sample at position ``idx``.

        Raises
        ------
        KeyError
            If the row's diagnosis is not a key of ``label_code``.
        """
        # Positional lookup: samplers request indices 0..len-1, which match
        # .loc labels only when data_df has a default RangeIndex (not the
        # case after filtering or splitting the frame).
        row = self.data_df.iloc[idx]
        diagnosis = row["diagnosis"]
        label = self.label_code[diagnosis]
        participant_id = row["participant_id"]
        session_id = row["session_id"]
        hemi = row["hemi"]

        image_filename = f"{participant_id}_{session_id}_T1w_segm-graymatter_space-Ixi549Space_modulated-off_probability_{hemi}.pt"
        image_folder = (
            self.img_dir
            / "subjects"
            / participant_id
            / session_id
            / "deeplearning_prepare_data"
            / "image_based"
            / "custom"
        )
        # NOTE(review): torch.load unpickles the file, so only point img_dir
        # at trusted data (or pass weights_only=True where torch supports it).
        image = torch.load(image_folder / image_filename)
        if self.transform is not None:
            image = self.transform(image)

        return Sample(image, label, participant_id, session_id, hemi)

    def train(self):
        """Put the transform (if any) in training mode."""
        # Requires a transform exposing train(), e.g. a torch.nn.Module;
        # torchvision.transforms.Compose has no such method.
        if self.transform is not None:
            self.transform.train()

    def eval(self):
        """Put the transform (if any) in evaluation mode."""
        # Same requirement as train(): the transform must expose eval().
        if self.transform is not None:
            self.transform.eval()