Source code for gridlib.resample

import multiprocessing as mp
import concurrent.futures
from itertools import chain
import signal
import math
import time
from typing import Dict, List, Tuple, Union, Callable

import numpy as np
import psutil
from tqdm import tqdm

from . import compute, data_utils, fit


def _resample_data(
    data: Dict[str, Dict[str, np.ndarray]], perc: float = 0.8
) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Draw a random resample of every survival function in ``data``.

    For each time-lapse condition the track lifes (in frames) are drawn with
    replacement, weighted by how often each track life occurs in the original
    survival function, and a new survival function is computed from the draw.

    Parameters
    ----------
    data : Dict[str, Dict[str, np.ndarray]]
        A dictionary mapping keys (time-lapse conditions) to the corresponding time and
        value arrays of the survival functions. For example::

            {
                "0.05s": {
                    "time": array([0.05, 0.1, 0.15, ...]),
                    "value": array([1.000e+04, 8.464e+03, 7.396e+03, ...]),
                },
                "1s": {
                    "time": array([1., 2., 3., 4., ...]),
                    "value": array([1.000e+04, 6.925e+03, 5.541e+03, 4.756e+03, ...]),
                },
            }

        The first entry of every ``"value"`` array is taken as the total number of
        data points of that condition.

    perc : float, optional
        Fraction of the total number of data points to draw for each condition
        (e.g. 0.8 for 80 %). The resulting count is rounded up. Defaults to 0.8.

    Returns
    -------
    Dict[str, Dict[str, np.ndarray]]
        A dictionary with the same keys and the same ``"time"``/``"value"`` layout as
        ``data``, holding the output of ``compute.compute_survival_function`` for the
        resampled track lifes. According to the original author's note these
        survival functions are not normalized.
    """
    generator = np.random.default_rng()

    resampled = {}
    for condition, curve in data.items():
        # Time-lapse interval of this condition, in seconds
        interval_s = data_utils.get_time_sec(condition)

        survival = curve["value"]

        # The survival function starts at the total number of data points, so its
        # first entry tells how many points the complete dataset contains.
        n_draws = math.ceil(survival[0] * perc)

        # Express the time axis in whole frames; the survival function computation
        # needs an integer array.
        frames = np.rint(curve["time"] / interval_s).astype(np.int64)

        # The survival function is cumulative, so it cannot be used as sampling
        # weights directly. The number of tracks ending at a given track life is
        # the drop to the next entry; the last entry drops to zero and is therefore
        # its own count.
        counts = np.concatenate((survival[:-1] - survival[1:], survival[-1:]))
        weights = counts / np.sum(counts)

        drawn_frames = generator.choice(
            frames,
            size=n_draws,
            replace=True,
            p=weights,
        )

        new_time, new_value = compute.compute_survival_function(
            drawn_frames, interval_s
        )
        resampled[condition] = {"time": new_time, "value": new_value}

    return resampled


def _resample_data_and_fit_worker(
    fn: Callable[
        [Dict, Dict[str, Dict[str, np.ndarray]], bool],
        Dict[str, Dict[str, Union[np.ndarray, float]]],
    ],
    parameters,
    data,
    perc: float = 0.8,
):
    """
    Resample the dataset once and fit the resampled data with ``fn``.

    This is the unit of work of the resampling procedure: one call produces the
    fit results for one resampled dataset.

    Parameters
    ----------
    fn : Callable
        The fitting function. It is called as
        ``fn(parameters, resampled_data, disp=False)``.

    parameters : Dict
        The fit parameters; handed to ``fn`` unchanged.

    data : Dict[str, Dict[str, np.ndarray]]
        A dictionary mapping keys (time-lapse conditions) to the corresponding time and
        value arrays of the survival functions. For example::

            {
                "0.05s": {
                    "time": array([0.05, 0.1, 0.15, ...]),
                    "value": array([1.000e+04, 8.464e+03, 7.396e+03, ...]),
                },
                "1s": {
                    "time": array([1., 2., 3., 4., ...]),
                    "value": array([1.000e+04, 6.925e+03, 5.541e+03, 4.756e+03, ...]),
                },
            }

    perc : float, optional
        Fraction of the data points to include in the resampled dataset; forwarded
        to ``_resample_data``. Defaults to 0.8.

    Returns
    -------
    Dict[str, Dict[str, Union[np.ndarray, float]]]
        The return value of ``fn``. Per the annotated type this is a dictionary
        mapping keys (fitting procedure) to the corresponding fit results, for
        example::

            {
                "grid": {
                    "k": array([1.00000000e-03, 1.04737090e-03, ...]),
                    "s": array([3.85818587e-17, 6.42847878e-18, ...]),
                    "a": 0.010564217803906671,
                    "loss": 0.004705659331508584,
                },
                "1-exp": {
                    "k": array([0.02563639]),
                    "s": array([1.]),
                    "a": 0.08514936433699753,
                    "loss": 1.2825570522448484
                },
                "2-exp": {
                    "k": array([0.03715506, 1.7248619]),
                    "s": array([0.17296989, 0.82703011]),
                    "a": 0.011938572088673213,
                    "loss": 0.2868809590425386
                },
                "3-exp": {
                    "k": array([0.0137423 , 0.27889073, 3.6560956]),
                    "s": array([0.06850312, 0.23560175, 0.69589513]),
                    "a": 0.011125323730424764,
                    "loss": 0.0379697542735324
                },
            }

    """
    return fn(parameters, _resample_data(data, perc=perc), disp=False)


def _resample_and_fit_sequential(fn, parameters, data, n, perc):
    """
    Fit the full dataset and then ``n`` resampled datasets, one after another.

    Progress and timing information is printed to stdout, and a ``tqdm`` progress
    bar tracks the resampling runs.

    Parameters
    ----------
    fn : Callable
        The fitting function. It is called as ``fn(parameters, data, disp=False)``
        for the full dataset and handed to ``_resample_data_and_fit_worker`` for
        every resampling run.

    parameters : Dict
        The fit parameters; handed to ``fn`` unchanged.

    data : Dict[str, Dict[str, np.ndarray]]
        A dictionary mapping keys (time-lapse conditions) to the corresponding time and
        value arrays of the survival functions.

    n : int
        The number of resampling runs.

    perc : float
        Fraction of the data points to include in each resampled dataset.

    Returns
    -------
    fit_results_full
        The return value of ``fn`` for the full dataset.

    fit_results_resampled : list
        The fit results of the ``n`` resampling runs, in the order they were run.
    """
    t0 = time.time()

    # First perform fitting on the full dataset
    # -----------------------------------------
    print("----------------------------------------")
    print("Fitting of all data starts now...")
    # Determine the full fit results based on all the data.
    t1 = time.time()
    fit_results_full = fn(parameters, data, disp=False)
    t2 = time.time()
    print("Fitting of all data finished now.")
    print("----------------------------------------")

    # The duration of the full fit serves as the estimate for a single run
    print(f"Estimated time for resampling: ~{((t2-t1)*n)/60.0:.0f} min")

    print("Start resampling...")
    print("----------------------------------------")

    # Determine the fit results for the resampled data
    fit_results_resampled = []
    with tqdm(total=n) as pbar:
        # Just start at 1 for the print statement later
        for i in range(1, n + 1):
            # The worker needs the fit function as its first argument
            fit_results_temp = _resample_data_and_fit_worker(
                fn, parameters, data, perc=perc
            )
            fit_results_resampled.append(fit_results_temp)
            pbar.update(1)
            print()
            print(f"{i}-th resampling is finished.")

    t3 = time.time()
    print("Resampling is finished.")
    print(f"Total time: {(t3-t0)/60.0:.0f} min")
    print("----------------------------------------")

    return fit_results_full, fit_results_resampled


def init_pool():
    """
    Make the calling process ignore SIGINT (CTRL+C).

    Intended as the initializer of the worker processes, so that a keyboard
    interrupt is not handled inside a worker.
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)


# TODO: add docstring
def _resample_and_fit_multiprocess(
    fn: Callable[
        [Dict, Dict[str, Dict[str, np.ndarray]], bool],
        Dict[str, Dict[str, Union[np.ndarray, float]]],
    ],
    parameters: Dict,
    data: Dict[str, Dict[str, np.ndarray]],
    n: int = 10,
    perc: float = 0.8,
    max_workers: int = None,
):

    # First determine the max number of workers
    # -----------------------------------------
    # The limit for the max number of workers is number of logical cores minus 1 (there
    # needs to be one logical core for the background tasks)
    max_workers_limit = psutil.cpu_count(logical=False) - 1

    if max_workers is None:
        max_workers = max_workers_limit
    elif max_workers < 1:
        raise ValueError("max_workers should be >=1")
    elif max_workers > max_workers_limit:
        max_workers = max_workers_limit
        print(
            f"WARNING: max_workers is set to {max_workers}, since initial value",
            "was set higher than the allowed number of workers ({max_workers_limit}).",
        )

    t0 = time.time()

    # Then perform fitting on the full dataset
    # ----------------------------------------
    print("----------------------------------------")
    print("Fitting of all data starts now...")
    # Determine the full fit results based on all the data.
    t1 = time.time()
    fit_results_full = fn(parameters, data, disp=False)
    t2 = time.time()
    print("Fitting of all data finished now.")
    print("----------------------------------------")

    print(f"Estimated time for resampling: ~{((t2-t1)*n)/60.0/max_workers:.0f} min")

    print("Start resampling...")
    print("----------------------------------------")

    # Determine the GRID fit results for the resampled data
    fit_results_resampled = []
    i = 1
    with tqdm(total=n) as pbar:
        # Parallelization
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=max_workers, initializer=init_pool
        ) as executor:
            try:

                # Submit the tasks
                futures = [
                    executor.submit(
                        _resample_data_and_fit_worker,
                        fn,
                        parameters,
                        data,
                        perc,
                    )
                    for _ in range(n)
                ]
                # Check if a task is finished
                for future in concurrent.futures.as_completed(futures):
                    fit_results_temp = future.result()
                    fit_results_resampled.append(fit_results_temp)
                    pbar.update(1)
                    print()
                    print(f"{i}-th resampling is finished.")
                    i += 1
            except KeyboardInterrupt:
                print("Caught keyboardinterrupt")

    t3 = time.time()
    print("Resampling is finished.")
    print(f"Total time: {(t3-t0)/60.0:.0f} min")
    print("----------------------------------------")

    return fit_results_full, fit_results_resampled


def resample_and_fit(
    parameters: Dict,
    data: Dict[str, Dict[str, np.ndarray]],
    n: int = 200,
    perc: float = 0.8,
    fit_mode: str = "grid",
    multiprocess_flag: bool = True,
    max_workers: Union[int, None] = None,
) -> Tuple[
    Dict[str, Dict[str, Union[np.ndarray, float]]],
    List[Dict[str, Dict[str, Union[np.ndarray, float]]]],
]:
    """
    This function performs a resampling procedure to fit survival time and value data
    with a given set of parameters. It first fits the full dataset with the
    parameters, then resamples the data and fits the resampled datasets with the
    parameters. The function returns the fit results for the full dataset and the
    resampled datasets.

    Parameters
    ----------
    parameters : Dict
        Dictionary containing all the parameters needed to perform the GRID and/or
        multi-exponential fitting.

    data : Dict[str, Dict[str, np.ndarray]]
        A dictionary mapping keys (time-lapse conditions) to the corresponding time and
        value arrays of the survival functions. For example::

            {
                "0.05s": {
                    "time": array([0.05, 0.1, 0.15, ...]),
                    "value": array([1.000e+04, 8.464e+03, 7.396e+03, ...]),
                },
                "1s": {
                    "time": array([1., 2., 3., ...]),
                    "value": array([1.000e+04, 6.925e+03, 5.541e+03, ...]),
                },
            }

    n : int, optional
        The number of times you want to resample from the full dataset. Defaults to
        200.

    perc : float, optional
        The percentage of data points to include in each resampled dataset. Defaults
        to 0.8.

    fit_mode : {"grid", "multi-exp", "all"}, optional
        The fitting procedure to perform:

        * "grid": perform GRID fitting procedure on the data.
        * "multi-exp": perform multi-exponential fitting procedure on the data.
        * "all": perform both GRID fitting and multi-exponential fitting procedure
          on the data.

        Defaults to "grid".

    multiprocess_flag : bool, optional
        A flag indicating whether to use parallel processing when fitting the
        resampled datasets. Defaults to True.

    max_workers : int, optional
        The maximum number of worker processes to use for the multiprocessing. This
        number is only used if multiprocess_flag is set to True. The value has to
        be at least 1 and is capped by the multiprocessing helper based on the
        available CPU cores. If the value is set to None, then that cap is used.
        Default is None.

    Returns
    -------
    fit_results_full: Dict[str, Dict[str, Union[np.ndarray, float]]]
        A dictionary mapping keys (fitting procedure) to the corresponding fit results.
        For example::

            {
                "grid": {
                    "k": array([1.00000000e-03, 1.04737090e-03, ...]),
                    "s": array([3.85818587e-17, 6.42847878e-18, ...]),
                    "a": 0.010564217803906671,
                    "loss": 0.004705659331508584,
                },
            }

    fit_results_resampled: List[Dict[str, Dict[str, Union[np.ndarray, float]]]]
        A list of dictionaries, each a dictionary mapping keys (fitting procedure) to
        the corresponding fit results. For example::

            [
                {
                    "grid": {
                        "k": array([1.00000000e-03, 1.04737090e-03, ...]),
                        "s": array([3.85818587e-17, 6.42847878e-18, ...]),
                        "a": 0.010564217803906671,
                        "loss": 0.004705659331508584,
                    },
                },
                {
                    "grid": {
                        "k": array([1.00000000e-03, 1.04737090e-03, ...]),
                        "s": array([3.85818587e-17, 6.42847878e-18, ...]),
                        "a": 0.010564217803906671,
                        "loss": 0.004705659331508584,
                    },
                },
                ...
            ]

    Raises
    ------
    ValueError
        If ``fit_mode`` is not one of the valid modes.
    """
    # Names of the fit functions in the ``fit`` module, per fit mode. Only the
    # selected function is looked up, and only after the mode has been validated.
    fit_function_names = {
        "grid": "fit_grid",
        "multi-exp": "fit_multi_exp",
        "all": "fit_all",
    }

    if fit_mode not in fit_function_names:
        raise ValueError("fit_mode is not a valid, see docstring for valid modes.")

    # Set the correct function for the fitting
    fn = getattr(fit, fit_function_names[fit_mode])

    if multiprocess_flag:
        return _resample_and_fit_multiprocess(
            fn, parameters, data, n=n, perc=perc, max_workers=max_workers
        )

    return _resample_and_fit_sequential(fn, parameters, data, n=n, perc=perc)