from operator import index
from os import times
from time import time
from typing import *
from . import ImageIO
import hashlib
import pathlib
import numpy as np
import random

class BaseDataSet():
    """Common base class of the m2aia dataset classes.

    Stores the list of images, an (initially empty) element list and the
    buffering strategy. Subclasses must implement ``__len__``,
    ``__getitem__`` and ``getitems``.

    Args:
        images: Images the dataset draws its data from.
        buffer_type: Buffering strategy string. If it contains ``'disk'`` it
            must have the form ``'disk:<path>'`` where ``<path>`` is an
            existing directory; everything after the first ``':'`` is the
            path. Any other value (e.g. ``'memory'`` or ``None``) is stored
            unchanged and interpreted by the subclass.

    Raises:
        ValueError: If a disk buffer is requested without a path.
        FileNotFoundError: If the disk buffer directory does not exist.
    """

    def __init__(self, images: "List[ImageIO.ImzMLReader]", buffer_type: str) -> None:
        self.images = images
        self.elements = []
        self.buffer_type = buffer_type
        # buffer_type may be None (buffering disabled) -> guard the substring test
        if self.buffer_type and 'disk' in self.buffer_type:
            # partition on the FIRST ':' only, so paths that themselves contain
            # ':' (e.g. 'disk:C:\\cache') are kept intact
            _, separator, path = self.buffer_type.partition(':')
            if not separator or not path:
                raise ValueError(
                    f"Disk buffering requires the form 'disk:<path>', got {self.buffer_type!r}"
                )
            self.buffer_path = pathlib.Path(path)
            if not self.buffer_path.exists():
                raise FileNotFoundError(f"The path {str(self.buffer_path)} does not exist!")

    def __len__(self) -> int:
        """Return the number of accessible elements (implemented by subclasses)."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return the element at ``index`` (implemented by subclasses)."""
        raise NotImplementedError

    def getitems(self, indexes: List[int]):
        """Return a batch for the given element indices (implemented by subclasses)."""
        raise NotImplementedError
class SpectrumDataset(BaseDataSet):
    """Dataset for accessing individual spectra and (optional) class labels of
    one or more images (m2aia.ImzMLReader objects).

    The SpectrumDataset provides convenient access to the spectra of single or
    multiple ImzMLReaders. Two access strategies exist:

    1) Spectral: a single spectrum is returned.
    2) Spatio-spectral: a central spectrum and its neighbours are returned.

    All images must share an identical x axis (same m/z values), so every
    spectrum has the same depth.

    A label image can be provided and is used to return a label for each
    accessed element.

    To use the spatio-spectral approach, pass a ``shape`` whose last two
    entries are the (odd) patch height and width, e.g. ``shape=(5, 5)``. A
    data entry then has shape [B, C, H, W] with batch size B = 1, spectrum
    depth C, height H = 5 and width W = 5. Without a shape (or with a
    one-element shape such as ``(-1,)``) a single spectrum of shape [B=1, C]
    is returned.

    To query several elements at once, ``SpectrumDataset.getitems(list_of_indices)``
    returns a batch: [B=len(list_of_indices), C] without a neighbourhood shape and
    [B=len(list_of_indices), C, H, W] with one. Elements in a batch are grouped
    by source image (in image order); data and labels stay aligned. This is used
    in m2aia.BatchGenerator.

    Example usage::

        import m2aia as m2

        I = m2.ImzMLReader("path/to/imzMl/file.imzML")
        I.SetNormalization(m2.m2NormalizationTIC)
        I.SetIntensityTransformation(m2.m2IntensityTransformationSquareRoot)
        I.Execute()

        dataset = m2.SpectrumDataset([I], shuffle=True)
        for i in range(len(dataset)):
            X, Y = dataset[i]
            print("Spectrum", X.shape, "Class Labels", Y.shape)
            do_something(X, Y)
    """

    def find_nearest_indices(self, centroids: np.array, xaxis: np.array):
        """Return, for each value in ``centroids``, the index of the closest ``xaxis`` entry."""
        return np.array([np.argmin(np.abs(xaxis - mz)) for mz in centroids])

    def find_subrange_indices(self, xs, center_index, tolerance, is_ppm):
        """Return the half-open index range ``[left, right)`` of ``xs`` whose values
        lie within ``tolerance`` of ``xs[center_index]``.

        The search walks outwards from ``center_index`` and stops at the first
        value outside the window, so ``xs`` is assumed to be sorted ascending.

        Args:
            xs: The x axis (e.g. m/z values).
            center_index: Index of the window centre in ``xs``.
            tolerance: Window half-width, in ppm of the centre value if
                ``is_ppm`` is true, otherwise in the units of ``xs``.
            is_ppm: Whether ``tolerance`` is given in parts per million.

        Returns:
            Tuple ``(left, right)`` usable as a slice ``xs[left:right]``.
        """
        center_value = xs[center_index]
        if is_ppm:
            # 1 ppm == 1e-6 of the centre value
            tol = center_value * tolerance * 1e-6
        else:
            tol = tolerance
        lower_bound = center_value - tol
        upper_bound = center_value + tol

        left_index = center_index
        right_index = center_index
        # Move the left pointer to find the lower bound
        while left_index > 0 and xs[left_index - 1] >= lower_bound:
            left_index -= 1
        # Move the right pointer to find the upper bound
        while right_index < len(xs) - 1 and xs[right_index + 1] <= upper_bound:
            right_index += 1
        return left_index, right_index + 1

    def __init__(self, images: "List[ImageIO.ImzMLReader]", labeled_images: List[np.array] = None,
                 sampling_masks: List[np.array] = None, spectrum_mask_indices: List[int] = None,
                 tolerance: np.float32 = None, is_tolerance_in_ppm: bool = True,
                 label_map: Dict = None, shape: Tuple = None, transform_data=None,
                 transform_labels=None, buffer_type: str = 'memory', reduce_function=np.mean,
                 shuffle=False, quiet_init=True) -> None:
        """Index all accessible spectra of the given images.

        Args:
            images (List[ImageIO.ImzMLReader]): Images to draw spectra from. All
                images must share an identical x axis.
            labeled_images (List[np.array], optional): One label array per image,
                indexed ``[z, y, x]``. Positions with a label < 0 are excluded.
                If None, ``ImzMLReader.GetMaskArray()`` of each image is used.
                Defaults to None.
            sampling_masks (List[np.array], optional): One array per image, indexed
                ``[z, y, x]``; only positions with a value > 0 are sampled. If None,
                ``ImzMLReader.GetMaskArray()`` of each image is used. Defaults to None.
            spectrum_mask_indices (List[int], optional): Indices along the x axis
                (indices of m/z values) to extract. If None, the whole spectrum is
                returned. Defaults to None.
            tolerance (float, optional): Only used together with
                ``spectrum_mask_indices``: ``reduce_function`` is applied over the
                window ``±tolerance`` around each index instead of picking the single
                value. Defaults to None (no windows).
            is_tolerance_in_ppm (bool, optional): Interpret ``tolerance`` in ppm of
                the window centre; otherwise in x-axis units. Defaults to True.
            label_map (Dict, optional): Stored as ``self.label_map``; not used by
                this class itself. Defaults to None.
            shape (Tuple, optional): Neighbourhood shape. If it has two or more
                entries, the last two are the (odd) patch height and width, e.g.
                ``(5, 5)`` or ``(-1, 5, 5)`` sample a 5x5 neighbourhood around each
                queried pixel. None or a one-element tuple such as ``(-1,)`` returns
                single spectra. Defaults to None.
            transform_data (Callable, optional): Applied to each returned data batch.
                Defaults to None.
            transform_labels (Callable, optional): Applied to each returned label
                batch. Defaults to None.
            buffer_type (str, optional): ``'memory'`` buffers queried spectra in RAM
                for fast access on later queries (e.g. the next epoch). Any other
                value, including None, disables buffering for this dataset.
                Defaults to 'memory'.
            reduce_function (Callable, optional): Reduction applied along axis 1 of
                each tolerance window. Defaults to numpy.mean.
            shuffle (bool, optional): Shuffle the element order once after
                indexing. Defaults to False.
            quiet_init (bool, optional): Currently unused. Defaults to True.

        Raises:
            ValueError: If the images do not share an identical x axis.
            Exception: If an even neighbourhood size is requested.
        """
        super().__init__(images, buffer_type)

        # track member variables
        self.shape = shape
        self.spectrum_mask_indices = spectrum_mask_indices
        self.tolerance = tolerance
        self.is_ppm = is_tolerance_in_ppm
        self.x_hws = 0
        self.y_hws = 0
        self.xs = self.images[0].GetXAxis()
        self.reduce_function = reduce_function
        self.ranges = None
        self.transform_data = transform_data
        self.transform_labels = transform_labels
        self.label_map = label_map
        self.labels = set()

        if self.tolerance and self.spectrum_mask_indices is not None:
            self.ranges = [self.find_subrange_indices(self.xs, index, self.tolerance, self.is_ppm)
                           for index in self.spectrum_mask_indices]
        else:
            self.tolerance = None
            self.ranges = None

        # validate explicitly (assert is stripped under python -O)
        for handle in self.images:
            if not np.array_equal(self.xs, handle.GetXAxis()):
                raise ValueError("All images must share an identical x axis.")

        if self.spectrum_mask_indices is None:
            # complete spectrum data
            self.spectrum_depth = self.images[0].GetXAxisDepth()
        else:
            # len() works for both lists (as annotated) and numpy arrays
            self.spectrum_depth = len(self.spectrum_mask_indices)

        if self._uses_neighborhood():
            if self.shape[-2] % 2 == 0 or self.shape[-1] % 2 == 0:
                raise Exception("We only support odd neighborhood sizes!")
            # half window sizes
            self.x_hws = self.shape[-1] // 2
            self.y_hws = self.shape[-2] // 2

        if self.buffer_type == "memory":
            # per image: (is_buffered flags, buffered spectra)
            self.buffer = []
            for handle in self.images:
                n_spectra = handle.GetNumberOfSpectra()
                # builtin bool: the np.bool alias was removed in NumPy 1.24
                self.buffer.append((np.zeros((n_spectra,), dtype=bool),
                                    np.zeros((n_spectra, self.spectrum_depth), dtype=np.float32)))

        self.index_images = []
        self.hit_counter = [0] * len(self.images)

        for imageID, handle in enumerate(self.images):
            image_elements = []
            # maps [z, y, x] -> spectrumID of accessible spectra, -1 elsewhere
            index_image = -np.ones(handle.GetShape()[::-1], dtype=np.int32)
            mask_image = handle.GetMaskArray() if labeled_images is None else labeled_images[imageID]
            sampling_mask = handle.GetMaskArray() if sampling_masks is None else sampling_masks[imageID]

            for spectrumID in range(handle.GetNumberOfSpectra()):
                x, y, z = handle.GetSpectrumPosition(spectrumID)
                if sampling_mask[z, y, x] <= 0 or mask_image[z, y, x] < 0:
                    continue
                index_image[z, y, x] = spectrumID
                label = mask_image[z, y, x]
                image_elements.append((imageID, spectrumID, (x, y, z), label))
                self.labels.add(label)

            self.elements.extend(image_elements)
            self.index_images.append(index_image)

        if shuffle:
            random.shuffle(self.elements)

    def __len__(self) -> int:
        """Return the number of accessible spectra.

        With multiple images this is the sum of accessible elements of all images.

        Returns:
            int: Number of accessible elements of this Dataset.
        """
        return len(self.elements)

    def __getitem__(self, index):
        """Return ``(data, labels)`` for a single element as a batch of size 1."""
        return self.getitems([index])

    def _uses_neighborhood(self) -> bool:
        """Whether ``shape`` requests a spatial (H, W) neighbourhood."""
        return bool(self.shape) and len(self.shape) >= 2

    def _neighborhood_indices(self, imageID, x, y, z):
        """Return the flattened row-major (H, W) window of spectrum indices around (x, y, z).

        Window positions outside the image or without an accessible spectrum are
        filled with randomly drawn valid indices from the same window; the other
        entries keep their spatial position.
        """
        index_image = self.index_images[imageID]
        _, height, width = index_image.shape
        oy = y - self.y_hws
        ox = x - self.x_hws
        window = -np.ones((2 * self.y_hws + 1, 2 * self.x_hws + 1), dtype=index_image.dtype)

        y0, y1 = max(oy, 0), min(y + self.y_hws + 1, height)
        x0, x1 = max(ox, 0), min(x + self.x_hws + 1, width)
        window[y0 - oy:y1 - oy, x0 - ox:x1 - ox] = index_image[z, y0:y1, x0:x1]

        indices = window.ravel()
        invalid = indices < 0
        if invalid.any():
            # the centre spectrum is always valid, so the candidate pool is non-empty
            indices[invalid] = np.random.choice(indices[~invalid], int(invalid.sum()))
        return indices

    def _reduce_spectra(self, spectra):
        """Apply the optional m/z selection (indices or tolerance windows) to raw spectra."""
        if self.spectrum_mask_indices is None:
            return spectra
        if self.ranges:
            # one reduced column per tolerance window
            return np.stack([self.reduce_function(spectra[:, l:u], axis=1) for l, u in self.ranges],
                            axis=1)
        return spectra[..., self.spectrum_mask_indices]

    def _query_buffer(self, imageID, image, indices):
        """Return spectra for ``indices`` from the memory buffer, loading misses first."""
        is_buffered, buffer_data = self.buffer[imageID]
        # unique: load each missing spectrum once even if it occurs in several windows
        miss_indices = np.unique(indices[~is_buffered[indices]])
        if miss_indices.size:
            spectra = image.GetSpectra(miss_indices).astype(np.float32)
            buffer_data[miss_indices] = self._reduce_spectra(spectra)
            # mark as buffered only after the data was stored successfully
            is_buffered[miss_indices] = True
        return buffer_data[indices]

    def getitems(self, dataset_query_indices: List[int]):
        """Return ``(data, labels)`` for the given dataset indices.

        Returns:
            data: ``[B, C]`` or, with a neighbourhood shape, ``[B, C, H, W]``
                (after ``transform_data`` if set). Elements are grouped by source image.
            labels: ``[B]`` int32 labels aligned with ``data`` (after
                ``transform_labels`` if set).
        """
        neighborhood = self._uses_neighborhood()
        win_h = 2 * self.y_hws + 1
        win_w = 2 * self.x_hws + 1

        image_query_indices = {imageID: [] for imageID in range(len(self.images))}
        image_query_labels = {imageID: [] for imageID in range(len(self.images))}
        for index in dataset_query_indices:
            imageID, spectrumID, (x, y, z), label = self.elements[index]
            image_query_labels[imageID].append(label)
            if neighborhood:
                image_query_indices[imageID].append(self._neighborhood_indices(imageID, x, y, z))
            else:
                image_query_indices[imageID].append(spectrumID)

        data_parts = []
        label_parts = []
        for imageID, image in enumerate(self.images):
            if not image_query_labels[imageID]:
                # nothing requested from this image: skip needless I/O
                continue
            # flatten: each neighbourhood query contributes win_h * win_w indices
            indices = np.asarray(image_query_indices[imageID], dtype=np.int32).ravel()
            labels = np.asarray(image_query_labels[imageID], dtype=np.int32)

            if self.buffer_type == "memory":
                batch_data = self._query_buffer(imageID, image, indices)
            else:
                batch_data = self._reduce_spectra(image.GetSpectra(indices))

            if neighborhood:
                # [B*H*W, C] -> [B, H, W, C] -> [B, C, H, W]
                batch_data = np.reshape(batch_data, (-1, win_h, win_w, self.spectrum_depth))
                batch_data = np.transpose(batch_data, (0, 3, 1, 2))

            data_parts.append(batch_data)
            label_parts.append(labels)

        if data_parts:
            result_data = np.concatenate(data_parts)
            result_labels = np.concatenate(label_parts)
        else:
            trailing = (self.spectrum_depth, win_h, win_w) if neighborhood else (self.spectrum_depth,)
            result_data = np.empty((0,) + trailing, dtype=np.float32)
            result_labels = np.empty((0,), dtype=np.int32)

        if self.transform_data is not None:
            result_data = self.transform_data(result_data)
        if self.transform_labels is not None:
            result_labels = self.transform_labels(result_labels)
        return result_data, result_labels
class IonImageDataset(BaseDataSet):
    """Dataset of ion images, one per (image, centroid) pair.

    Index ``i`` maps to centroid ``centroids[i % len(centroids)]`` of image
    ``images[i // len(centroids)]``.

    Args:
        images: Images to extract ion images from.
        centroids: Centre x values (e.g. m/z) of the ion images.
        tolerance: Window half-width passed to ``GetArray``; in ppm of the
            centroid if ``tolerance_type == 'ppm'``, otherwise used as-is.
        tolerance_type: ``'ppm'`` or any other value for an absolute tolerance.
        buffer_type: ``'memory'`` keeps ion images in RAM, ``'disk:<path>'``
            stores them as ``.npy`` files below ``<path>``, anything else
            (e.g. ``'none'`` or None) disables buffering.
        transforms: Optional callable applied to every returned ion image.
    """

    def __init__(self, images: "List[ImageIO.ImzMLReader]", centroids: List[float], tolerance: float,
                 tolerance_type: str = 'ppm', buffer_type='memory', transforms=None) -> None:
        super().__init__(images, buffer_type)
        self.elements = centroids
        self.tolerance = tolerance
        self.tolerance_type = tolerance_type
        self.transforms = transforms
        # per image: centroid -> ion image (memory) or file path stem (disk)
        self.buffer = [{} for _ in images]

    def get_tolerance(self, c):
        """Return the absolute tolerance for centroid ``c``."""
        if self.tolerance_type == "ppm":
            # 1 ppm == 1e-6 of the centroid value
            return c * self.tolerance * 1e-6
        return self.tolerance

    def _buffer_file(self, image_id, c):
        """Return the path stem (without ``.npy``) under which an ion image is cached on disk."""
        digest = hashlib.sha224("{:.8f}".format(c).encode())
        digest.update("{:.8f}".format(self.get_tolerance(c)).encode())
        directory = self.buffer_path.joinpath(self.images[image_id].GetImageName())
        directory.mkdir(exist_ok=True)
        return directory.joinpath(digest.hexdigest())

    def make_buffered_image(self, index):
        """Return the (optionally transformed) ion image for dataset ``index``.

        The ion image is created with ``GetArray(c, tolerance, squeeze=False)``
        on first access and buffered according to ``buffer_type``.
        """
        n_centroids = len(self.elements)
        c = self.elements[index % n_centroids]
        image_id = index // n_centroids

        use_memory = self.buffer_type == 'memory'
        use_disk = bool(self.buffer_type) and 'disk' in self.buffer_type
        cache = self.buffer[image_id]

        if (use_memory or use_disk) and c in cache:
            # load buffered entry
            ii = cache[c] if use_memory else np.load(str(cache[c]) + ".npy")
        else:
            # create ion image
            ii = self.images[image_id].GetArray(c, self.get_tolerance(c), squeeze=False)
            if use_memory:
                cache[c] = ii
            elif use_disk:
                path = self._buffer_file(image_id, c)
                # np.save appends '.npy', matching the load above
                np.save(str(path), ii)
                cache[c] = path

        if self.transforms is not None:
            ii = self.transforms(ii)
        return ii

    def __len__(self) -> int:
        """Return the number of (image, centroid) pairs."""
        return len(self.elements) * len(self.images)

    def __getitem__(self, index):
        """Return the ion image for ``index``."""
        return self.make_buffered_image(index)

    def getitems(self, indexes: List[int]):
        """Return the ion images for ``indexes`` stacked along a new first axis."""
        return np.stack([self.__getitem__(index) for index in indexes])