# Source code for biomechzoo.processing.rep_trial_data

import copy
from typing import Dict, List, Union, Tuple

import numpy as np

from biomechzoo.processing.explodechannel_data import explodechannel_data
from biomechzoo.biomech_ops.normalize_line import normalize_line
from biomechzoo.statistics.rmse import rmse
from biomechzoo.processing.removechannel_data import removechannel_data

def reptrial_data(gdata: Dict, channels: Union[List[str], str], method: str = 'mean',
                  nlength: int = 101) -> Tuple[Dict, Union[int, str]]:
    """
    Compute a representative trial from a set of trials for a subject/condition.

    This function can operate in two modes:

    1. 'mean': Computes the pointwise mean of each specified channel across
       all trials, producing a synthetic representative trial.
    2. 'rmse': Selects the existing trial whose waveforms are closest to the
       mean waveforms in the root-mean-squared error (RMSE) sense, averaged
       over all specified channels, and returns that trial as the
       representative.

    Every selected channel is time-normalized to ``nlength`` samples with
    ``normalize_line`` before averaging or comparing. This lets trials of
    different lengths be combined. Two-dimensional channels (e.g. n x 3 data)
    are processed column by column.

    Neither ``gdata`` nor ``channels`` is modified; the returned trial is a
    deep copy.

    Parameters
    ----------
    gdata : dict
        Dictionary of zoo data. Each key corresponds to a trial
        (e.g., 'data1', 'data2', ...). Each trial maps channel names to
        dicts holding a 'line' array, plus a 'zoosystem' entry.
    channels : list of str, str, or 'all'
        Channel names to include in the representative trial computation.
        A single channel name may be given as a plain string. If 'all',
        every channel of the first trial except 'zoosystem' is used.
    method : {'mean', 'rmse'}, optional
        Method to compute the representative trial (case-insensitive).
        Default is 'mean'.

        - 'mean' : Synthetic trial from pointwise mean.
        - 'rmse' : Select existing trial closest to mean waveform.
    nlength : int, optional
        Number of samples each channel is normalized to. Default is 101.

    Returns
    -------
    rep : dict
        Representative trial, in the same format as a single trial in gdata.
        For 'mean', the selected channels hold the ``nlength``-sample mean and
        all other channels are copied unchanged from the first trial. For
        'rmse', the chosen trial is copied unchanged. In both cases
        ``rep['zoosystem']['CompInfo']['Reptrials']`` records the number of
        trials and the method used.
    file_index : int or str
        Position of the selected trial in ``list(gdata)`` for the 'rmse'
        method, or the string 'mean' if method='mean'.

    Raises
    ------
    ValueError
        If gdata contains no trials, if no channels are selected, if the
        method is unknown, or (for 'rmse') if the mean waveform of a channel
        contains NaN values.

    Notes
    -----
    Events are not handled here. Rather, the user could run event detection
    for the representative trial.
    """
    # in case upper case RMSE
    method = method.lower()
    # Fail fast, before any (potentially expensive) normalization work.
    if method not in ('mean', 'rmse'):
        raise ValueError('Unknown method {}, choose mean or rmse'.format(method))

    trials = list(gdata.keys())
    if not trials:
        raise ValueError('gdata contains no trials')

    channels = _select_channels(gdata[trials[0]], channels)
    if not channels:
        raise ValueError('No channels selected for the representative trial')

    if method == 'mean':
        # Deep copy: writing the mean into gdata[trials[0]] itself would
        # silently overwrite the caller's first trial.
        rep = copy.deepcopy(gdata[trials[0]])
        for ch in channels:
            rep[ch]['line'] = _stack_channel(gdata, trials, ch, nlength).mean(axis=0)
        file_index = 'mean'
    else:
        per_trial_rmse = _rmse_to_mean(gdata, trials, channels, nlength)
        # np.argmin returns the first trial on ties.
        file_index = int(np.argmin(per_trial_rmse))
        # Copy so the metadata written below does not leak into gdata.
        rep = copy.deepcopy(gdata[trials[file_index]])

    # metadata
    rep['zoosystem']['CompInfo']['Reptrials'] = {
        'Trials': len(trials),
        'Method': method,
    }
    return rep, file_index


def _select_channels(trial: Dict, channels: Union[List[str], str]) -> List[str]:
    """Resolve the ``channels`` argument into a fresh list of channel names."""
    if isinstance(channels, str):
        if channels == 'all':
            return [ch for ch in trial if ch != 'zoosystem']
        # A single channel name; iterating it would yield characters.
        return [channels]
    # Copy so the caller's list is never mutated.
    return list(channels)


def _stack_channel(gdata: Dict, trials: List, ch: str, nlength: int) -> np.ndarray:
    """Time-normalize channel ``ch`` of every trial and stack the results.

    The result is (n_trials, nlength) for 1-D channels, or
    (n_trials, nlength, n_cols) for 2-D channels, whose columns are normalized
    independently. Assumes normalize_line returns a 1-D array of length
    nlength for 1-D input -- TODO confirm against normalize_line.
    """
    per_trial = []
    for t in trials:
        line = np.asarray(gdata[t][ch]['line'])
        if line.ndim == 2:
            cols = [normalize_line(line[:, k], nlength) for k in range(line.shape[1])]
            per_trial.append(np.column_stack(cols))
        else:
            per_trial.append(normalize_line(line, nlength))
    return np.stack(per_trial)


def _rmse_to_mean(gdata: Dict, trials: List, channels: List[str],
                  nlength: int) -> np.ndarray:
    """Return, for each trial, its RMSE to the mean waveform averaged over all channel components."""
    # One entry per channel component (each column of a 2-D channel counts
    # separately), holding the RMSE of every trial to that component's mean.
    columns = []
    for ch in channels:
        stk = _stack_channel(gdata, trials, ch, nlength)
        mean_val = stk.mean(axis=0)
        if np.isnan(mean_val).any():
            raise ValueError('NaNs found in channel {}'.format(ch))
        if stk.ndim == 2:
            # Treat a 1-D channel as a single-column 2-D one.
            stk = stk[:, :, np.newaxis]
            mean_val = mean_val[:, np.newaxis]
        for k in range(stk.shape[2]):
            columns.append([rmse(mean_val[:, k], trial_stk[:, k]) for trial_stk in stk])
    # Shape (n_trials, n_components) -> average across components per trial.
    rms_stack = np.array(columns, dtype=float).T
    return rms_stack.mean(axis=1)
if __name__ == "__main__":
    from biomechzoo.utils.set_zoosystem import set_zoosystem

    # Demo: build a small synthetic zoo-style dataset in which every trial
    # holds a noisy copy of the same linear ramp for each channel.
    np.random.seed(0)
    n_samples, n_trials = 50, 5
    channels = ['HipFlexion', 'KneeAngle']

    gdata = {}
    for trial_num in range(1, n_trials + 1):
        trial = {}
        for ch in channels:
            ramp = np.linspace(0, 30, n_samples)
            trial[ch] = {'line': ramp + np.random.randn(n_samples), 'event': {}}
        trial['zoosystem'] = set_zoosystem()
        gdata[f"data{trial_num}"] = trial

    # Only the rmse method is exercised here; pass method='mean' to try the
    # synthetic-average variant instead.
    rep_rmse, idx_rmse = reptrial_data(gdata, channels=channels, method='rmse')