Source code for now_2023.utils._images

import random
from dataclasses import dataclass
from typing import Optional


@dataclass(init=True, repr=True, eq=True)
class Point:
    """Integer (x, y, z) coordinate triple, e.g. a bounding-box corner."""

    x: int  # first coordinate
    y: int  # second coordinate
    z: int  # third coordinate


class BaseCrop:
    """Crop a fixed bounding box from the last three axes of an image.

    ``bottom_left_point`` and ``upper_right_point`` (set by subclasses or
    callers) give the box corners; their x, y and z coordinates bound the
    third-to-last, second-to-last and last axes respectively. The lower
    corner is inclusive and the upper corner exclusive (plain slice
    semantics). In train mode each axis is shifted independently by a random
    integer in ``[-random_shift, random_shift]``; in eval mode no shift is
    applied.

    Args:
        random_shift: Maximum absolute per-axis shift, in array-index units,
            applied in train mode. Must be non-negative.

    Raises:
        ValueError: If ``random_shift`` is negative.
    """

    def __init__(self, random_shift: int = 0):
        if random_shift < 0:
            # random.randint(-s, s) would otherwise fail later with an
            # unhelpful "empty range" error, and only in train mode.
            raise ValueError(
                f"random_shift must be non-negative, got {random_shift}"
            )
        self.random_shift = random_shift
        self.train_mode = True
        # Bounding-box corners; both must be set before the crop is applied.
        self.bottom_left_point: Optional[Point] = None
        self.upper_right_point: Optional[Point] = None

    def _get_shift(self) -> int:
        """Return a random shift in train mode and no shift otherwise."""
        return (
            random.randint(-self.random_shift, self.random_shift)
            if self.train_mode
            else 0
        )

    def _get_slice(self, left_offset: int, right_offset: int) -> slice:
        """Return ``slice(left_offset, right_offset)`` moved by one shift.

        The same shift is applied to both ends, so the slice length is
        preserved.

        Raises:
            ValueError: If the shifted start is negative. A negative start
                would wrap around to the end of the axis and silently yield
                a wrong (typically empty) crop.
        """
        shift = self._get_shift()
        start = left_offset + shift
        if start < 0:
            raise ValueError(
                f"Shifted crop start {start} is negative "
                f"(offset {left_offset}, shift {shift}); "
                "reduce random_shift or move the bounding box."
            )
        return slice(start, right_offset + shift)

    def _get_bbox(self) -> tuple:
        """Return the index tuple selecting the box on the last three axes.

        Leading axes (e.g. channels) are kept whole via ``Ellipsis``.

        Raises:
            ValueError: If either bounding-box corner is not configured.
        """
        if self.bottom_left_point is None or self.upper_right_point is None:
            raise ValueError("No bounding box configured")
        return (
            Ellipsis,
            self._get_slice(self.bottom_left_point.x, self.upper_right_point.x),
            self._get_slice(self.bottom_left_point.y, self.upper_right_point.y),
            self._get_slice(self.bottom_left_point.z, self.upper_right_point.z),
        )

    def __call__(self, img):
        """Crop ``img`` and return a copy that does not share its memory.

        ``img`` may be a torch tensor (copied with ``clone``) or a NumPy
        array (copied with ``copy``).
        """
        cropped = img[self._get_bbox()]
        # Basic slicing returns a view; copy so the crop is independent.
        if hasattr(cropped, "clone"):
            return cropped.clone()
        return cropped.copy()

    def train(self):
        """Enable random shifting of the bounding box."""
        self.train_mode = True

    def eval(self):
        """Disable random shifting; crops become deterministic."""
        self.train_mode = False


class CropLeftHC(BaseCrop):
    """Crops the left hippocampus of an MRI non-linearly registered to MNI.

    Args:
        random_shift: Maximum absolute per-axis shift applied in train mode.
    """

    def __init__(self, random_shift: int = 0):
        super().__init__(random_shift=random_shift)
        # Hard-coded 30 x 40 x 30 box; the upper corner is exclusive.
        self.bottom_left_point = Point(25, 50, 27)
        self.upper_right_point = Point(55, 90, 57)
class CropRightHC(BaseCrop):
    """Crops the right hippocampus of an MRI non-linearly registered to MNI.

    Args:
        random_shift: Maximum absolute per-axis shift applied in train mode.
    """

    def __init__(self, random_shift: int = 0):
        super().__init__(random_shift=random_shift)
        # Hard-coded 30 x 40 x 30 box; the upper corner is exclusive.
        self.bottom_left_point = Point(65, 50, 27)
        self.upper_right_point = Point(95, 90, 57)