# Source code for GPmix.projector

import skfda
from skfda.representation.grid import FDataGrid
from skfda.representation.basis import FDataBasis, FourierBasis, BSplineBasis
from skfda.misc import inner_product, inner_product_matrix
from skfda.preprocessing.dim_reduction import FPCA
from skfda.exploratory.visualization import FPCAPlot
from skfda.misc.covariances import Exponential
from skfda.datasets import make_gaussian_process

import numpy as np
import pywt

import seaborn as sns
import matplotlib.pyplot as plt

import warnings

"""
projector.py
============

This module provides the `Projector` class, which transforms functional data into a set of univariate data
by projecting the sample curves onto a chosen family of projection functions. Supported projection functions
are Fourier, B-spline, eigenfunction (FPC), wavelet, Ornstein-Uhlenbeck process realizations, and random
linear combinations of eigenfunctions.

Classes
-------
Projector
    Transforms functional data to univariate data by projecting onto specified basis functions.

Dependencies
------------
- scikit-fda (skfda)
- numpy
- pywavelets (pywt)
- seaborn
- matplotlib

Example
-------
>>> from skfda.datasets import make_gaussian_process
>>> from GPmix.projector import Projector
>>> fdata = make_gaussian_process(n_samples=10, n_features=100)
>>> proj = Projector(basis_type='fourier', n_proj=3)
>>> coeffs = proj.fit(fdata)
>>> proj.plot_basis()
>>> proj.plot_projection_coeffs()

"""


class Projector:
    """
    Transform functional data to a set of univariate data by projection onto
    specified projection functions.

    Parameters
    ----------
    basis_type : str
        Specifies the type of projection function. Supported types include:
        - 'fourier': Fourier basis functions.
        - 'fpc': Eigenfunctions from functional principal component analysis.
        - 'wavelet': Discrete wavelet basis.
        - 'bspline': B-spline basis functions.
        - 'ou': Ornstein-Uhlenbeck process realizations.
        - 'rl-fpc': Random linear combinations of eigenfunctions.
    n_proj : int, default=3
        Number of projection functions to use (i.e., number of univariate
        projections).
    basis_params : dict, optional
        Dictionary of hyperparameters for basis generation. Supported keys:
        - 'period': Period for Fourier basis (default: length of the domain).
        - 'order': Order for B-spline basis (default: 3).
        - 'wv_name': Name of the discrete wavelet (default: 'db5').
        - 'resolution': Number of intervals at the lowest wavelet resolution
          (default: 1).

    Attributes
    ----------
    domain_range : tuple
        Domain (first dimension) of the data passed to `fit`.
    grid_points : ndarray
        Grid points of the data passed to `fit`; the projection functions are
        evaluated on this same grid.
    basis : skfda.FDataGrid
        The projection functions used.
    coefficients : ndarray of shape (n_proj, N)
        Projection coefficients for N samples.

    Notes
    -----
    Fourier and B-spline bases are Gram-Schmidt orthogonalized when their
    Gram matrix is not (numerically) diagonal. The 'ou', 'wavelet', 'fpc' and
    'rl-fpc' projection functions are used as generated.
    """

    # Keys accepted in `basis_params`.
    _BASIS_PARAM_KEYS = ('period', 'order', 'wv_name', 'resolution')

    # Upper bound on Gram-Schmidt passes in `_compute_coefficients`. One
    # correct pass orthogonalizes in exact arithmetic; further passes only
    # remove round-off, so the loop must never be unbounded.
    _MAX_ORTHOGONALIZATION_PASSES = 10

    def __init__(self, basis_type: str, n_proj: int = 3,
                 basis_params: dict | None = None) -> None:
        """
        Initialize the Projector.

        Parameters
        ----------
        basis_type : str
            Type of projection basis ('fourier', 'fpc', 'wavelet', 'bspline',
            'ou', 'rl-fpc').
        n_proj : int, default=3
            Number of projection functions.
        basis_params : dict, optional
            Parameters for basis generation. ``None`` means no parameters.

        Raises
        ------
        ValueError
            If `basis_params` contains keys other than 'period', 'order',
            'wv_name' and 'resolution'.
        """
        self.basis_type = basis_type
        self.n_proj = n_proj
        # Fresh dict per instance instead of a shared mutable default.
        self.basis_params = {} if basis_params is None else basis_params

        if any(key not in self._BASIS_PARAM_KEYS for key in self.basis_params):
            raise ValueError('basis_params contains some unknown keys. '
                             'Ensure that the dict keys are limited to the following: '
                             "'period', 'order', 'wv_name', 'resolution'.")

    def get_wavelet_signal(self, wavelet_name):
        """
        Retrieve the scaling and wavelet functions for a given discrete wavelet.

        Samples whose magnitude is at most 0.1 are trimmed from both ends of
        each function.

        Parameters
        ----------
        wavelet_name : str
            Name of the discrete wavelet.

        Returns
        -------
        scaling_function : ndarray
            The scaling (father) wavelet function.
        wavelet_function : ndarray
            The wavelet (mother) function.

        Raises
        ------
        ValueError
            If the wavelet is continuous or unknown. Any other ValueError from
            ``pywt.Wavelet`` is re-raised unchanged.
        """
        try:
            wavelet = pywt.Wavelet(wavelet_name)
        except ValueError as e:
            message = str(e.args[0]) if e.args else ''
            if 'Use pywt.ContinuousWavelet instead' in message:
                raise ValueError(f"The `Projector` class only works with discrete wavelets, {wavelet_name} is a continuous wavelet.") from e
            if 'Unknown wavelet name' in message:
                raise ValueError(f"Unknown wavelet name {wavelet_name}, check pywt.wavelist(kind = 'discrete') for the list of available builtin wavelets.") from e
            # Any other failure is re-raised rather than continuing with
            # `wavelet` unbound.
            raise

        wavefuns = wavelet.wavefun()
        # The first two outputs of `wavefun()` are taken as the scaling and
        # wavelet functions respectively.
        tail_threshold = 1e-1
        scaling_function = self._trim_small_tails(wavefuns[0], tail_threshold)
        wavelet_function = self._trim_small_tails(wavefuns[1], tail_threshold)
        return scaling_function, wavelet_function

    @staticmethod
    def _trim_small_tails(values, threshold):
        """
        Drop leading and trailing samples whose magnitude is <= `threshold`.

        If no sample exceeds `threshold`, `values` is returned unchanged
        (instead of failing on an empty index).
        """
        above = np.flatnonzero(np.abs(values) > threshold)
        if above.size == 0:
            return values
        return values[above[0]: above[-1] + 1]

    def dilate_translate_signal(self, signal, n_trans):
        """
        Generate dilated and translated versions of a signal over the domain.

        The domain is split into `n_trans` equal intervals. On each interval, a
        copy of `signal` is placed; it evaluates to zero outside that interval
        (``extrapolation='zeros'``) and is scaled to unit L2 norm.

        Parameters
        ----------
        signal : ndarray
            The signal to dilate and translate.
        n_trans : int
            Number of translations (intervals).

        Returns
        -------
        list of skfda.FDataGrid
            Normalized, dilated and translated signals, one per interval.
        """
        knots = np.linspace(self.domain_range[0], self.domain_range[1], n_trans + 1)
        normalized = []
        for left, right in zip(knots[:-1], knots[1:]):
            piece = skfda.FDataGrid(signal,
                                    grid_points=np.linspace(left, right, len(signal)),
                                    extrapolation='zeros')
            normalized.append(piece / np.sqrt(inner_product(piece, piece)))
        return normalized

    def get_wavelet_basis(self, wavelet_name, n):
        """
        Construct a wavelet basis using scaling and wavelet functions.

        The basis starts with `n` father and `n` mother wavelets. It then keeps
        adding mother wavelets at doubled resolution (2n, 4n, ... intervals)
        until at least `n_proj` functions exist. The first `n_proj` functions
        are kept.

        Parameters
        ----------
        wavelet_name : str
            Name of the discrete wavelet.
        n : int
            Number of intervals at the lowest resolution.

        Returns
        -------
        skfda.FDataGrid
            The wavelet basis evaluated on `self.grid_points`.
        """
        scaling_signal, wavelet_signal = self.get_wavelet_signal(wavelet_name)

        # Lowest resolution: father wavelets followed by mother wavelets.
        basis = self.dilate_translate_signal(scaling_signal, n)
        basis += self.dilate_translate_signal(wavelet_signal, n)

        # Higher resolutions: double the number of intervals each round.
        remaining = self.n_proj - 2 * n
        while remaining > 0:
            n *= 2
            basis += self.dilate_translate_signal(wavelet_signal, n)
            remaining -= n

        basis_grid = [skfda.FDataGrid(fun(self.grid_points).squeeze(),
                                      grid_points=self.grid_points)
                      for fun in basis[:self.n_proj]]
        return skfda.concatenate(basis_grid)

    def _generate_basis(self) -> FDataGrid:
        """
        Generate projection functions from the specified basis type.

        Handles 'fourier', 'bspline', 'ou' and 'wavelet'. Side effect: sets
        `self.period` (Fourier) or `self.order` (B-spline).

        Returns
        -------
        skfda.FDataGrid
            The generated basis functions, evaluated on `self.grid_points`.

        Raises
        ------
        ValueError
            If the basis type is not one of the types handled here.
        """
        if self.basis_type == 'fourier':
            self.period = self.basis_params.get('period', self.domain_range[1] - self.domain_range[0])
            # Build an odd number of Fourier functions, then keep the first n_proj.
            nb = self.n_proj + 1 if self.n_proj % 2 == 0 else self.n_proj
            basis = FourierBasis(domain_range=self.domain_range, n_basis=nb, period=self.period)
            return FDataBasis(basis, np.eye(nb)).to_grid(self.grid_points)[:self.n_proj]

        if self.basis_type == 'bspline':
            self.order = self.basis_params.get('order', 3)
            basis = BSplineBasis(domain_range=self.domain_range, n_basis=self.n_proj, order=self.order)
            return FDataBasis(basis, np.eye(self.n_proj)).to_grid(self.grid_points)

        if self.basis_type == 'ou':
            # Ornstein-Uhlenbeck process: mean = 0, k(x, y) = exp(-|x - y|).
            # Sampled on a grid twice as fine, then evaluated on the data grid.
            return make_gaussian_process(start=self.domain_range[0],
                                         stop=self.domain_range[1],
                                         n_samples=self.n_proj,
                                         n_features=2 * len(self.grid_points),
                                         mean=0,
                                         cov=Exponential(variance=1, length_scale=1),
                                         ).to_grid(self.grid_points)

        if self.basis_type == 'wavelet':
            wavelet_name = self.basis_params.get('wv_name', 'db5')
            n = self.basis_params.get('resolution', 1)
            return self.get_wavelet_basis(wavelet_name, n)

        raise ValueError(f"Projection functions for basis_type {self.basis_type!r} are not generated here; "
                         "`_generate_basis` supports 'fourier', 'bspline', 'ou' and 'wavelet'.")

    def _compute_fpc_combination(self, fdata):
        """
        Construct projection functions as random linear combinations of
        eigenfunctions explaining at least 95% of the variation in the sample
        data.

        Each weight on eigenfunction j is drawn from N(0, s2_j). Here s2_j is
        the sample variance of the data's inner products with that
        eigenfunction. Draws use the global NumPy random state.

        Parameters
        ----------
        fdata : skfda.FDataGrid
            The functional data.

        Returns
        -------
        skfda.FDataGrid
            `n_proj` random linear combinations of eigenfunctions.
        """
        # As many components as min(n_samples, n_grid_points).
        fpca_ = FPCA(n_components=min(fdata.data_matrix.shape[:2]))
        fpca_.fit(fdata)

        lambdas_sq = np.square(fpca_.singular_values_)
        explained = np.cumsum(lambdas_sq / lambdas_sq.sum())
        # Smallest number of components reaching 95% explained variation.
        jn = int(np.argmax(explained >= 0.95)) + 1

        ej = fpca_.components_[:jn]
        s2 = [inner_product(fpca_.components_[i], fdata).var() for i in range(jn)]
        # Row j holds the n_proj weights for eigenfunction j.
        gammas = np.array([np.random.normal(0, np.sqrt(s2_), self.n_proj) for s2_ in s2])

        combinations = [(gammas[:, v] * ej).sum() for v in range(self.n_proj)]
        return skfda.concatenate(combinations)

    def _is_orthogonal(self, basis: FDataGrid, tol: float | None = None) -> bool:
        """
        Check the orthogonality of a given set of projection functions.

        Parameters
        ----------
        basis : skfda.FDataGrid
            The basis functions to check.
        tol : float, optional
            Largest allowed absolute off-diagonal Gram-matrix entry. If None,
            1e-10 is used.

        Returns
        -------
        bool
            True if every off-diagonal inner product is within `tol`.
        """
        threshold = 1e-10 if tol is None else tol
        gram = inner_product_matrix(basis)
        off_diagonal = gram - np.diag(np.diagonal(gram))
        return not np.any(np.abs(off_diagonal) > threshold)

    def _gram_schmidt(self, funs: FDataGrid) -> FDataGrid:
        """
        Perform (modified) Gram-Schmidt orthogonalization on a set of functions.

        Each function has its components along all previously orthogonalized
        functions removed:
        ``u_i = v_i - sum_j <v_i, u_j> / <u_j, u_j> * u_j``.
        The functions are not normalized, so the first function is returned
        unchanged. Functions with zero norm are skipped as projection targets.

        Parameters
        ----------
        funs : skfda.FDataGrid
            Functions to orthogonalize.

        Returns
        -------
        skfda.FDataGrid
            Orthogonalized functions, in the same order.
        """
        orthogonal = []
        for i in range(len(funs)):
            fun_ = funs[i]
            for prev in orthogonal:
                sq_norm = float(np.squeeze(inner_product(prev, prev)))
                if sq_norm > 0:
                    coef = float(np.squeeze(inner_product(fun_, prev))) / sq_norm
                    fun_ = fun_ - coef * prev
            orthogonal.append(fun_)
        return skfda.concatenate(orthogonal)

    def _compute_coefficients(self, fdata: FDataGrid):
        """
        Orthogonalize the basis functions if necessary and compute projection
        coefficients.

        Parameters
        ----------
        fdata : skfda.FDataGrid
            Functional data to project.

        Returns
        -------
        tuple
            (coefficients, basis): `coefficients` is the array of inner
            products between each basis function and each sample; `basis` is
            the (possibly orthogonalized) basis functions.

        Warns
        -----
        RuntimeWarning
            If the basis is still not orthogonal after
            `_MAX_ORTHOGONALIZATION_PASSES` Gram-Schmidt passes.
        """
        basis = self._generate_basis()
        # Internal invariants: the basis is generated on the data's own grid.
        assert all(basis.grid_points[0].shape == grid.shape for grid in fdata.grid_points), 'Set the appropriate sample_points for basis functions; number of sample points for both objects, the basis and the functional sample data, must be equal.'
        assert all((basis.grid_points[0] == grid).all() for grid in fdata.grid_points), 'Set the appropriate sample_points for basis functions; sample points for both objects, the basis and the functional sample data, must be equal.'

        # Enforce orthogonality where necessary.
        if self.basis_type not in ('ou', 'wavelet'):
            passes = 0
            while not self._is_orthogonal(basis):
                if passes >= self._MAX_ORTHOGONALIZATION_PASSES:
                    warnings.warn(f"Projection functions are still not orthogonal after {passes} "
                                  "Gram-Schmidt passes; using the current basis.",
                                  RuntimeWarning)
                    break
                basis = self._gram_schmidt(basis)
                passes += 1

        return inner_product_matrix(basis, fdata), basis

    def _compute_fpc(self, fdata):
        """
        Construct the eigenfunctions (principal components) from the data.

        Parameters
        ----------
        fdata : skfda.FDataGrid
            Functional data.

        Returns
        -------
        skfda.FDataGrid
            The first `n_proj` principal component functions.
        """
        return FPCA(n_components=self.n_proj).fit(fdata).components_

    def fit(self, fdata: FDataGrid):
        """
        Compute the projection coefficients of sample functions.

        The data are centered by their sample mean before projection. Side
        effects: sets `domain_range`, `grid_points`, `basis` and
        `coefficients`.

        Parameters
        ----------
        fdata : skfda.FDataGrid
            Functional data to project.

        Returns
        -------
        ndarray of shape (n_proj, N)
            Projection coefficients.

        Raises
        ------
        ValueError
            If `basis_type` is not a supported option.
        """
        self.domain_range = fdata.domain_range[0]
        self.grid_points = fdata.grid_points[0]

        # Center data
        fdata = fdata - fdata.mean()

        if self.basis_type in ['fourier', 'ou', 'wavelet', 'bspline']:
            self.coefficients, self.basis = self._compute_coefficients(fdata)
        elif self.basis_type == 'fpc':
            self.basis = self._compute_fpc(fdata)
            self.coefficients = inner_product_matrix(self.basis, fdata)
        elif self.basis_type == 'rl-fpc':
            self.basis = self._compute_fpc_combination(fdata)
            self.coefficients = inner_product_matrix(self.basis, fdata)
        else:
            raise ValueError(f"Unknown basis_type: {self.basis_type}. Choose from the supported options: 'fourier', 'bspline', 'ou', 'rl-fpc', 'wavelet', 'fpc'.")

        return self.coefficients

    def plot_basis(self, **kwargs):
        """
        Plot the projection basis functions (requires a prior call to `fit`).

        Parameters
        ----------
        **kwargs
            Additional keyword arguments passed to the plot function.
        """
        self.basis.plot(group=range(1, len(self.basis) + 1), **kwargs)
        plt.xlabel('t')
        plt.ylabel('$\\beta_v(t)$')

    def plot_projection_coeffs(self, ncols=4, **kwargs):
        """
        Plot the distribution of univariate projection coefficients.

        One density histogram is drawn per projection (requires a prior call
        to `fit`). Unused subplot cells are hidden.

        Parameters
        ----------
        ncols : int, optional
            Number of columns in the subplot grid. Default is 4.
        **kwargs
            Additional keyword arguments passed to seaborn.histplot.
        """
        # squeeze=False always yields a 2-D array of Axes, so `.ravel()` also
        # works when n_proj == 1 (a single Axes object has no `.ravel()`).
        if self.n_proj >= ncols:
            fig, axes = plt.subplots(int(np.ceil(self.n_proj / ncols)), ncols,
                                     figsize=(15, 15), squeeze=False)
        else:
            fig, axes = plt.subplots(1, self.n_proj, figsize=(10, 5), squeeze=False)
        axes = axes.ravel()

        for i, (coeffs, ax) in enumerate(zip(self.coefficients, axes)):
            sns.histplot(data=coeffs, stat='density', ax=ax, **kwargs)
            ax.set_xlabel(rf'$\alpha_{{i{i + 1}}}$')

        for ax in axes[len(self.coefficients):]:
            ax.set_axis_off()

        fig.tight_layout()