from typing import List
from . import ImageIO
import hashlib
import pathlib
import numpy as np
class BaseDataSet:
    """Base class for mass spectrometry imaging datasets.

    buffer_type selects the caching strategy: 'memory' keeps loaded data
    in RAM, 'disk:<path>' caches entries below an existing <path>, and
    'none' disables caching.
    """
    def __init__(self, images: List[ImageIO.ImzMLReader], buffer_type: str) -> None:
        self.images = images
        self.elements = []
        self.buffer_type = buffer_type
        if 'disk' in self.buffer_type:  # disk buffering requires an existing target path
            self.buffer_path = pathlib.Path(self.buffer_type.split(':')[1])
            if not self.buffer_path.exists():
                raise FileNotFoundError(f"The path {self.buffer_path} does not exist!")
    def __len__(self) -> int:
        raise NotImplementedError

    def __getitem__(self, index):
        raise NotImplementedError
    def getitems(self, indexes: List[int]):
        raise NotImplementedError
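
# Both subclasses implement the __len__/__getitem__ protocol, so they can be
# wrapped by a map-style PyTorch DataLoader. A minimal sketch (torch is not
# imported by this module; the import and the imzML path are illustrative):
#
#   from torch.utils.data import DataLoader
#   reader = ImageIO.ImzMLReader("example.imzML")
#   loader = DataLoader(SpectrumDataset([reader]), batch_size=32, shuffle=True)
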
class SpectrumDataset(BaseDataSet):
    """Dataset for accessing individual spectra.

    __len__ returns the size of the dataset, which equals the total number
    of spectra N summed over all images.

    __getitem__ supports indexing such that dataset[i] returns the i'th
    sample: indices 0, ..., p-1 address image 1, p, ..., q-1 address
    image 2, and so on up to N, with p = #spectra of image 1,
    q = #spectra of image 2, etc.

    neighborhood_size controls the spatial context: the window size is
    2*neighborhood_size+1.

    Parameters
    ----------
    images : List[ImageIO.ImzMLReader]
    """
    def __init__(self, images: List[ImageIO.ImzMLReader], neighborhood_size: int = 0, transforms=None, buffer_type='memory') -> None:
        super().__init__(images, buffer_type)
        self.spectrum_depth = self.images[0].GetXAxisDepth()
        if self.buffer_type == 'memory':
            # per image: (query flags, intensity buffer); a flag of None marks an entry not yet loaded
            self.buffer = [(np.array([None] * self.images[k].GetNumberOfSpectra()),
                            np.zeros((self.images[k].GetNumberOfSpectra(), self.spectrum_depth), dtype=np.float32))
                           for k in range(len(self.images))]
        self.neighborhood_size = neighborhood_size
        self.index_images = []
        self.footprint = np.ones([2 * neighborhood_size + 1, 2 * neighborhood_size + 1, 1])
        self.transforms = transforms
        self.hit_counter = [0] * len(self.images)  # counts batches that required loading from file (buffer misses)
        for imageID, handle in enumerate(self.images):
            assert self.spectrum_depth == handle.GetXAxisDepth()
            imageElements = []
            # GetArray returns an image volume indexed as [z, y, x]
            index_image = handle.GetArray(handle.GetXAxis()[self.spectrum_depth // 2], 1, squeeze=False).astype(np.int32)
            index_image.fill(-1)
            for spectrumID in range(handle.GetNumberOfSpectra()):
                (x, y, z) = handle.GetSpectrumPosition(spectrumID)
                imageElements.append((imageID, spectrumID, (x, y, z)))
                index_image[z, y, x] = spectrumID
            self.elements.extend(imageElements)
            self.index_images.append(index_image)
def __len__(self) -> int:
return len(self.elements)
def __getitem__(self, index):
return self.getitems([index])
    def getitems(self, indexes: List[int]):
        ids_split_to_images = {}
        # group the requested indices by image and collect image-local spectrum ids
        for index in indexes:
            imageID, spectrumID, (x, y, z) = self.elements[index]
            shape = self.images[imageID].GetShape()
            if imageID not in ids_split_to_images:
                ids_split_to_images[imageID] = []
            if self.neighborhood_size > 0:
                # clip the neighborhood so that it does not cross the image borders
                left = np.clip(x - self.neighborhood_size, 0, shape[0])
                right = np.clip(x + self.neighborhood_size + 1, 0, shape[0])
                bottom = np.clip(y - self.neighborhood_size, 0, shape[1])
                top = np.clip(y + self.neighborhood_size + 1, 0, shape[1])
                # collect all spectrum indices for the requested position and its
                # neighbors (2D data is assumed, i.e. the z = 0 plane)
                indices = self.index_images[imageID][0, bottom:top, left:right]
                indices = indices.flatten()
                # replace invalid positions (marked by -1 in the index image) with random valid neighbors
                if np.any(indices < 0):
                    choices = np.random.choice(indices[indices >= 0], len(indices[indices < 0]))
                    indices[indices < 0] = choices
                # pad with random neighbors if the window was clipped at the border
                expected_indices = (self.neighborhood_size * 2 + 1) ** 2
                if len(indices) < expected_indices:
                    missing_indices = expected_indices - len(indices)
                    indices = np.concatenate([indices, np.random.choice(indices, missing_indices)])
                indices = indices.tolist()
            else:
                indices = [spectrumID]
            ids_split_to_images[imageID].extend(indices)
        result = None
        BUFFER_QUERY = 0
        BUFFER_DATA = 1
        for imageID, indices in ids_split_to_images.items():
            # use the in-memory buffer (only allocated for buffer_type == 'memory')
            if self.buffer_type == 'memory':
                indices = np.array(indices)
                interim = np.zeros((len(indices), self.spectrum_depth))
                # mask buffer entries which are not yet filled with data
                query_mask = self.buffer[imageID][BUFFER_QUERY][indices] == None
                # load missing data from the imzML file
                if query_mask.any():
                    self.hit_counter[imageID] = self.hit_counter[imageID] + 1
                    print(self.hit_counter, end='\r')  # lightweight progress output
                    # fetch the spectra for those indices from the image
                    interim[query_mask] = self.images[imageID].GetSpectra(indices[query_mask])
                    # mark the buffer query structure to prevent repeated queries
                    self.buffer[imageID][BUFFER_QUERY][indices[query_mask]] = True
                    # copy the loaded data into the buffer
                    self.buffer[imageID][BUFFER_DATA][indices[query_mask]] = interim[query_mask]
                # for all entries in the BUFFER_QUERY structure which already
                # have data, copy the buffered data into the interim structure
                if (~query_mask).any():
                    interim[~query_mask] = self.buffer[imageID][BUFFER_DATA][indices[~query_mask]]
            else:
                interim = self.images[imageID].GetSpectra(indices)
            if result is None:
                result = interim
            else:
                result = np.concatenate([result, interim])
        if self.transforms is not None:
            result = self.transforms(result)
        return result
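
# A minimal usage sketch (the imzML path is illustrative; it assumes an
# ImzMLReader constructed from an imzML file):
#
#   reader = ImageIO.ImzMLReader("example.imzML")
#   dataset = SpectrumDataset([reader], neighborhood_size=1)
#   batch = dataset[0]   # pixel 0 plus its 3x3 neighborhood -> shape (9, spectrum_depth)
#   n = len(dataset)     # total number of spectra over all images
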
class IonImageDataset(BaseDataSet):
    """Dataset for accessing ion images, one per (image, centroid) pair.

    tolerance is interpreted according to tolerance_type: 'ppm' for a
    relative tolerance, any other value for an absolute one.
    """
    def __init__(self, images: List[ImageIO.ImzMLReader],
                 centroids: List[float],
                 tolerance: float,
                 tolerance_type: str = 'ppm',
                 buffer_type='memory',
                 transforms=None) -> None:
        super().__init__(images, buffer_type)
        self.elements = centroids
        self.tolerance = tolerance
        self.tolerance_type = tolerance_type
        self.transforms = transforms
        self.buffer = [{} for _ in images]
    def get_tolerance(self, c):
        # ppm is relative to the centroid: c * tolerance * 1e-6 (parts per million)
        if self.tolerance_type == "ppm":
            return c * self.tolerance * 1e-6
        else:
            return self.tolerance
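    # Worked example (values are illustrative): with tolerance=20 and
    # tolerance_type='ppm', a centroid at m/z 885.55 yields a half-window of
    # 885.55 * 20 * 1e-6 ≈ 0.0177 m/z units.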
    def make_buffered_image(self, index):
        c = self.elements[index % len(self.elements)]
        image_id = index // len(self.elements)
        if c not in self.buffer[image_id] or self.buffer_type == 'none':
            # create the ion image
            ii = self.images[image_id].GetArray(c, self.get_tolerance(c), squeeze=False)
            if self.buffer_type == 'memory':  # buffer the image in memory
                self.buffer[image_id][c] = ii
            elif 'disk' in self.buffer_type:  # buffer the image on disk
                # hash centroid and tolerance to obtain a stable, unique file name
                digest = hashlib.sha224("{:.8f}".format(c).encode())
                digest.update("{:.8f}".format(self.get_tolerance(c)).encode())
                path = self.buffer_path.joinpath(self.images[image_id].GetImageName())
                path.mkdir(exist_ok=True)
                path = path.joinpath(digest.hexdigest())
                np.save(path, ii)
                self.buffer[image_id][c] = path
        else:  # load the buffered entry
            if self.buffer_type == 'memory':
                ii = self.buffer[image_id][c]
            else:
                ii = np.load(str(self.buffer[image_id][c]) + ".npy")
        if self.transforms is not None:
            ii = self.transforms(ii)
        return ii
def __len__(self) -> int:
return len(self.elements) * len(self.images)
    def __getitem__(self, index):
        # index layout: index = image_id * len(centroids) + centroid_index
        return self.make_buffered_image(index)
def getitems(self, indexes: List[int]):
return np.stack([self.__getitem__(index) for index in indexes])
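
# A minimal usage sketch (file name and m/z values are illustrative; it
# assumes GetArray returns a (z, y, x) volume when squeeze=False):
#
#   reader = ImageIO.ImzMLReader("example.imzML")
#   dataset = IonImageDataset([reader], centroids=[885.55, 886.55], tolerance=20)
#   ii = dataset[0]                   # ion image of the first centroid
#   batch = dataset.getitems([0, 1])  # stacked ion images, shape (2, z, y, x)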