import numpy as np
import pandas as pd

from biomechzoo.utils.engine import engine
from biomechzoo.utils.zload import zload
from biomechzoo.ensembler.helpers import match_condition, extract_subject_id, extract_events, ZooEvent, ConditionSource, ConditionSpec
class DataStore:
    """
    Loads, indexes and extracts relevant data and information from the zoo files.

    The file list is obtained once from ``engine(fld)`` at construction time.
    Channel lines and events are parsed lazily on first access and then cached
    per ``(channel, condition)`` / ``(channel, condition, event_name)`` key.

    Parameters
    ----------
    fld :
        Folder passed to ``engine`` to obtain the files to process.
    condition_spec : ConditionSpec or None
        How conditions are encoded. ``None`` means a BETWEEN spec with an empty
        condition list.
    events :
        Stored as ``event_list``; not otherwise used by this class.
    subj_list :
        Forwarded to ``extract_subject_id`` as ``subj_list``.
    str_match :
        Forwarded to ``extract_subject_id`` as ``str_pattern``.
    """

    def __init__(self, fld, condition_spec: "ConditionSpec | None" = None, events=None, subj_list=None, str_match=None):
        self.fld = fld
        # explicit None check: a user-supplied spec is never replaced based on truthiness
        if condition_spec is None:
            condition_spec = ConditionSpec(source=ConditionSource.BETWEEN, conditions=[])
        self.condition_spec = condition_spec
        self.conditions = self.condition_spec.conditions
        self.subj_list = subj_list
        self.event_list = events
        self.str_match = str_match
        self._fl = engine(self.fld)
        self.subjects = self._resolve_subjects()
        # lazy caches — populated on first access
        self._extracted: set[tuple[str, str]] = set()
        self._lines: dict[tuple[str, str], list[np.ndarray]] = {}
        self._events: dict[tuple[str, str, str], list[ZooEvent]] = {}
        self._subj_index: dict[tuple[str, str], list[str]] = {}
        self._event_subj_index: dict[tuple[str, str, str], list[str]] = {}

    # ------------------------------------------------------------------ lines

    def _ensure_extracted(self, channel: str, condition: str) -> None:
        """Run the line pass for (channel, condition) unless it already ran."""
        key = (channel, condition)
        if key not in self._extracted:
            self._extract(channel, condition)
            self._extracted.add(key)

    def get_lines(self, channel, condition):
        """Return the extracted line arrays for (channel, condition)."""
        self._ensure_extracted(channel, condition)
        return self._lines.get((channel, condition), [])

    def get_subject_ids(self, channel, condition):
        """Return subject ids aligned index-by-index with ``get_lines``."""
        self._ensure_extracted(channel, condition)
        return self._subj_index.get((channel, condition), [])

    def _extract(self, channel, condition):
        """Parse all zoo files for one (channel, condition) pair."""
        key = (channel, condition)
        # resolve before touching the caches so a bad channel_map leaves no partial entry
        zoo_channel = self._resolve_zoo_channel(channel, condition)
        lines, subjects = [], []
        for subj, ch_data in self._iter_channel_data(zoo_channel, condition):
            raw = ch_data.get("line")
            if raw is not None:
                lines.append(np.asarray(raw, dtype=float).squeeze())
                subjects.append(subj)
        self._lines[key] = lines
        self._subj_index[key] = subjects

    # ----------------------------------------------------------------- events

    def _ensure_events_extracted(self, channel, condition, event_name):
        """Run the event pass for this key unless it is already cached."""
        if (channel, condition, event_name) not in self._events:
            self._extract_events(channel, condition, event_name)

    def get_events(self, channel, condition, event_name):
        """Return the extracted events for (channel, condition, event_name)."""
        # events have their own pass; there is no need to load the lines first
        self._ensure_events_extracted(channel, condition, event_name)
        return self._events.get((channel, condition, event_name), [])

    def get_event_values(self, channel: str, condition: str, event_name: str) -> list[float]:
        """Convenience — y-only, for violin/stats renderers."""
        return [ev.y for ev in self.get_events(channel, condition, event_name)]

    def get_event_subject_ids(self, channel, condition, event_name):
        """Return subject ids aligned index-by-index with ``get_events``."""
        self._ensure_events_extracted(channel, condition, event_name)
        return self._event_subj_index.get((channel, condition, event_name), [])

    def _extract_events(self, channel, condition, event_name):
        """Separate pass for events. Only runs when events are needed."""
        event_key = (channel, condition, event_name)
        # resolve first: if the channel_map lookup raises, no empty cache entry is
        # left behind (otherwise later calls would silently return [])
        zoo_channel = self._resolve_zoo_channel(channel, condition)
        events, subjects = [], []
        for subj, ch_data in self._iter_channel_data(zoo_channel, condition):
            val = extract_events(ch_data, event_name)
            if val is not None:
                events.append(val)
                subjects.append(subj)
        self._events[event_key] = events
        self._event_subj_index[event_key] = subjects

    # ---------------------------------------------------------------- helpers

    def _subject_for_file(self, f, condition):
        """
        Return the subject id of file *f* if it contributes to *condition*, else None.

        Only filename-based checks are performed here, so nothing is loaded.
        """
        if self.condition_spec.source == ConditionSource.BETWEEN:
            matched = match_condition(f, self.conditions)
            # fail-safe: keep the file only if it matches "__all__" or the
            # condition currently being extracted
            if matched != "__all__" and matched != condition:
                return None
        return extract_subject_id(f, subj_list=self.subj_list, str_pattern=self.str_match)

    def _iter_channel_data(self, zoo_channel, condition):
        """Yield ``(subject, data[zoo_channel])`` for each file contributing to *condition*."""
        for f in self._fl:
            subj = self._subject_for_file(f, condition)
            if subj is None:
                continue
            # load only after the cheap filename filters have passed
            data = zload(f)
            # fail-safe: the channel key must be present in this file
            if zoo_channel not in data:
                continue
            yield subj, data[zoo_channel]

    def _resolve_zoo_channel(self, channel, condition):
        """
        Returns the actual key to look up in the zoo dict.
        - BETWEEN source → channel name is used as-is
        - WITHIN source → look up from channel_map

        Raises
        ------
        KeyError
            WITHIN source and no channel_map entry for (condition, channel).
        """
        if self.condition_spec.source == ConditionSource.WITHIN:
            cond_map = self.condition_spec.channel_map.get(condition, {})
            resolved = cond_map.get(channel)
            if resolved is None:
                raise KeyError(f"No channel_map entry for base channel {channel!r} "
                               f"under condition {condition!r}. "
                               f"Available: {list(cond_map.keys())}")
            return resolved
        return channel

    def _resolve_subjects(self):
        """
        Return the unique subject ids, in first-seen file order.

        For a BETWEEN source, only files matching one of ``self.conditions``
        (or ``"__all__"``) are considered.
        """
        seen, result = set(), []
        for f in self._fl:
            if self.condition_spec.source == ConditionSource.BETWEEN:
                matched = match_condition(f, self.conditions)
                if matched != "__all__" and matched not in self.conditions:
                    continue
            subj = extract_subject_id(f, subj_list=self.subj_list, str_pattern=self.str_match)
            if subj is None or subj in seen:
                continue
            seen.add(subj)
            result.append(subj)
        return result

    # ------------------------------------------------------------- dataframes

    def to_events_dataframe(self, channels: list[str], event_names: list[str]):
        """
        Returns a long-format DataFrame of all scalar events specified.

        Columns: ``subject``, ``condition``, ``channel``, ``event``, ``value``
        (always present, even when there are no rows).
        """
        rows = []
        for channel in channels:
            for condition in self.conditions:
                for event_name in event_names:
                    values = self.get_event_values(channel, condition, event_name)
                    subjects = self.get_event_subject_ids(channel, condition, event_name)
                    for subj, val in zip(subjects, values):
                        rows.append({"subject": subj,
                                     "condition": condition,
                                     "channel": channel,
                                     "event": event_name,
                                     "value": val})
        return pd.DataFrame(rows, columns=["subject", "condition", "channel", "event", "value"])

    def to_lines_dataframe(self, channels: list[str]):
        """
        Returns a long-format DataFrame of all line data, one row per sample.

        Columns: ``subject``, ``condition``, ``channel``, ``frame``, ``value``
        (always present, even when there are no rows). ``frame`` is the sample
        position as a percentage (0-100), spread evenly over each line's own
        length. Lines should therefore be time-normalized to a common length
        for frames of different trials to line up.
        """
        rows = []
        for channel in channels:
            for condition in self.conditions:
                arrays = self.get_lines(channel, condition)
                subjects = self.get_subject_ids(channel, condition)
                for arr, subj in zip(arrays, subjects):
                    # a single-sample line is squeezed to 0-d during extraction
                    arr = np.atleast_1d(arr)
                    # derive the frame axis from the line itself instead of assuming
                    # 100 samples, so e.g. 101-sample lines are not truncated
                    frames = np.linspace(0, 100, len(arr))
                    for frame, val in zip(frames, arr):
                        rows.append({"subject": subj,
                                     "condition": condition,
                                     "channel": channel,
                                     "frame": frame,
                                     "value": val})
        return pd.DataFrame(rows, columns=["subject", "condition", "channel", "frame", "value"])