from abc import ABC, abstractmethod
import plotly.graph_objs as go
import plotly.express as px
import numpy as np
# to test my bland-altman plot
import pyCompare
from biomechzoo.ensembler.data_store import DataStore
from biomechzoo.ensembler.style_content import StyleContext
from biomechzoo.ensembler.helpers import compute_ensemble, _compute_bandwidth, align_by_subject, resolve_shade
#### Plot options to add
# Scatter plot/correlation plots --> regression line option?
# Violinplots
# MeanSD
# RainCloud?
class Renderer(ABC):
"""
Knows how to add traces for ONE (channel, condition) into a subplot
Receives the DataStore and shared style context - Nothing else.
"""
@abstractmethod
def render(self,
fig: go.Figure,
store: DataStore,
style: StyleContext,
spec: "PlotSpec",
row: int,
col: int,
) -> None: ...
[docs]
class IndividualLinesRenderer(Renderer):
[docs]
def render(self, fig, store, style, spec, row, col):
for condition in spec.all_conditions:
arrays = store.get_lines(spec.channel, condition)
subjects = store.get_subject_ids(spec.channel, condition)
x_norm = np.linspace(0, 100, max((len(a) for a in arrays), default=1))
for arr, subj in zip(arrays, subjects):
color = style.subject_color(subj)
dash = style.condition_dash(condition)
show_leg = style.should_show_legend("subj", subj)
fig.add_trace(go.Scatter(
x=x_norm, y=arr,
mode="lines",
name=subj,
legendgroup=subj,
line=dict(color=color, dash=dash, width=1.2),
opacity=0.45,
showlegend=show_leg,
hovertemplate=f"<b>{subj}</b><br>%{{x:.1f}}% | %{{y:.2f}}<extra></extra>",
), row=row, col=col)
[docs]
class MeanSDRenderer(Renderer):
[docs]
def render(self, fig, store, style, spec, row, col):
for condition in spec.all_conditions:
arrays = store.get_lines(spec.channel, condition)
if not arrays:
return
n = len(arrays[0])
x = np.linspace(0, 100, n)
mean, upper, lower = compute_ensemble(arrays)
color = style.condition_color(condition)
dash = style.condition_dash(condition)
shade_color = resolve_shade(color)
# Standard deviation ribbon lower limit
fig.add_trace(go.Scatter(
x=x,
y=lower,
fillcolor=shade_color,
line=dict(color="rgba(0,0,0,0)"),
showlegend=False,
), row=row, col=col)
# Standard deviation ribbon upper limit
fig.add_trace(go.Scatter(
x=x,
y=upper,
fill="tonexty",
fillcolor=shade_color,
line=dict(color="rgba(0,0,0,0)"),
showlegend=False,
), row=row, col=col)
# mean line
show_leg = style.should_show_legend("mean", condition)
fig.add_trace(go.Scatter(
x=x, y=mean,
name=f"Mean_{condition}",
legendgroup=f"Mean_{condition}",
line=dict(color=color, width=3, dash=dash),
hovertemplate=f"<b>Mean – {condition}</b><br>%{{x:.1f}}% | %{{y:.2f}}<extra></extra>",
showlegend=show_leg,
), row=row, col=col)
[docs]
class EventOverlayRenderer(Renderer):
[docs]
def render(self, fig, store, style, spec, row, col):
if not spec.events:
return
for event_name in spec.events:
for condition in spec.all_conditions:
evs = store.get_events(spec.channel, condition, event_name) # list[ZooEvent]
subjects = store.get_event_subject_ids(spec.channel, condition, event_name)
for ev, subj in zip(evs, subjects):
color = style.subject_color(subj)
show_leg = style.should_show_legend("event", f"{subj}_{event_name}")
fig.add_trace(go.Scatter(
x=[ev.x], y=[ev.y],
mode="markers",
name=f"{subj} – {event_name}",
legendgroup=subj,
marker=dict(color=color, size=8),
showlegend=show_leg,
hovertemplate=(
f"<b>{subj} – {event_name}</b><br>"
f"x: %{{x:.1f}} | y: %{{y:.2f}}<extra></extra>"
),
), row=row, col=col)
class ViolinRenderer(Renderer):
def __init__(self, show_points: bool = True, bandwidth: float | None = None):
self.show_points = show_points
self.bandwidth = bandwidth
def render(self, fig, store, style, spec, row, col):
if not spec.events:
return
for event_name in spec.events:
for condition in spec.all_conditions:
values = store.get_event_values(spec.channel, condition, event_name) # y-only
subjects = store.get_event_subject_ids(spec.channel, condition, event_name)
if not values:
continue
if spec.group_by and spec.group_map:
groups = [spec.group_map.get(s, "Unknown") for s in subjects]
else:
groups = [condition] * len(values)
unique_groups = dict.fromkeys(groups)
for grp in unique_groups:
grp_values = [v for v, g in zip(values, groups) if g == grp]
bw = self.bandwidth if self.bandwidth is not None else _compute_bandwidth(grp_values)
color = style.condition_color(condition)
label = f"{condition} – {event_name} – {grp}" if spec.group_by else f"{condition} – {event_name}"
show_leg = style.should_show_legend("violin", grp)
fig.add_trace(go.Violin(
x = [f"{grp}"] * len(grp_values),
y = grp_values,
name=grp,
legendgroup=label,
line_color=color,
fillcolor=color,
opacity=0.6,
box_visible=True,
points="all" if self.show_points else False,
bandwidth=bw,
showlegend=show_leg,
), row=row, col=col)
class BlandAltmanRenderer(Renderer):
"""
Plots a BlandAltman agreement plot between two conditions
Requires exactly two conditions via spec.all_conditions.
Computes:
mean = (method_A - method_B) / 2
diff = methodA - method_B
bias = mean(diff)
LoA = bias +/- 1.96 * std(diff)
Works with either:
- line data (Uses a scaler per trial, e.g. mean of the line)
- event data (Uses event scaler directly, e.g. "max")
"""
def __init__(self, use_lines: bool = False, show_subjects: bool = False, loa_multiplier: float = 1.96, line_scaler : str = "mean"):
if line_scaler not in ("mean", "max", "min", "median"):
raise ValueError("line_scaler must be one of 'mean', 'max', 'min', or 'median'")
self.use_lines = use_lines
self.show_subjects = show_subjects
self.loa_multiplier = loa_multiplier
self.line_scaler = line_scaler
def render(self, fig, store, style, spec, row, col):
if len(spec.all_conditions) != 2:
raise ValueError(f"BlandAltmanRenderer requires exactly two conditions, "
f"got {spec.all_conditions}. Use companions= to specify the second")
cond_a, cond_b = spec.all_conditions
if not self.use_lines:
if not spec.events:
raise ValueError(f"BlandAltmanRenderer with use_events=True requires events to be specified ")
event_name = spec.events[0]
vals_a = store.get_event_values(spec.channel, cond_a, event_name)
vals_b = store.get_event_values(spec.channel, cond_b, event_name)
# pyCompare.blandAltman(vals_a, vals_b)
subjects_a = store.get_event_subject_ids(spec.channel, cond_a, event_name)
subjects_b = store.get_event_subject_ids(spec.channel, cond_b, event_name)
vals_a, vals_b, subjects = align_by_subject(vals_a, subjects_a, vals_b, subjects_b)
if not vals_a:
return
arr_a = np.asarray(vals_a)
arr_b = np.asarray(vals_b)
means = np.mean([arr_a, arr_b], axis=0)
diffs = arr_a - arr_b
bias = float(np.mean(diffs))
sd = float(np.std(diffs))
loa_upper = bias + self.loa_multiplier * sd
loa_lower = bias - self.loa_multiplier * sd
x_min, x_max = np.min(arr_a), np.max(arr_a)
x_pad = (x_max - x_min) * 0.1
x_range = [x_min - x_pad, x_max + x_pad]
for mean_val, diff_val, subj in zip(means, diffs, subjects):
color = style.subject_color(subj) if self.show_subjects else "#1f77b4"
show_leg = style.should_show_legend("ba_subj", subj) if self.show_subjects else False
# subject scatter
fig.add_trace(go.Scatter(
x = [mean_val], y=[diff_val],
mode = "markers", name=subj,
marker=dict(color=color, size=8,),
legendgroup=subj,
showlegend=show_leg,
), row=row, col=col)
# bias, loa, and reference lines
fig.add_hline(y = bias, line_color="black", line_dash="dash",
annotation_text=f"Bias: {bias:.2f}",
annotation_position="bottom right", row=row, col=col)
fig.add_hline(y = loa_upper, line_color="red", line_dash="dash",
annotation_text = f"LoA: {loa_upper:.2f}",
annotation_position = "top right", row=row, col=col)
fig.add_hline(y = loa_lower, line_color="red", line_dash="dash",
annotation_text = f"LoA: {loa_lower:.2f}",
annotation_position = "bottom right", row=row, col=col)
fig.add_hline(y=0, line_color="grey", line_dash="dash", row=row, col=col)
class ScatterRenderer(Renderer):
"""
"""
def __init__(self, regression_line: bool = False, show_subjects: bool = False, identity_line: bool = True,):
self.regression_line = regression_line
self.show_subjects = show_subjects
self.identity_line = identity_line
def render(self, fig, store, style, spec, row, col):
if len(spec.all_conditions) != 2:
raise ValueError(f"ScatterRenderer requires exactly two conditions, "
f"got {spec.all_conditions}. Use companions= to specify the second")
cond_a, cond_b = spec.all_conditions
if not spec.events:
raise ValueError(f"ScatterRenderer requires events to be specified ")
event_name = spec.events[0]
vals_a = store.get_event_values(spec.channel, cond_a, event_name)
vals_b = store.get_event_values(spec.channel, cond_b, event_name)
# pyCompare.blandAltman(vals_a, vals_b)
subjects_a = store.get_event_subject_ids(spec.channel, cond_a, event_name)
subjects_b = store.get_event_subject_ids(spec.channel, cond_b, event_name)
vals_a, vals_b, subjects = align_by_subject(vals_a, subjects_a, vals_b, subjects_b)
if not vals_a:
return
arr_a = np.asarray(vals_a)
arr_b = np.asarray(vals_b)
for a, b, subj in zip(arr_a, arr_b, subjects):
color = style.subject_color(subj) if self.show_subjects else style.condition_color(cond_a)
show_leg = style.should_show_legend("scatter_subj", subj) if self.show_subjects else False
# subject scatter
fig.add_trace(go.Scatter(
x=[a], y=[b],
mode="markers", name=subj,
marker=dict(color=color, size=8, ),
legendgroup=subj,
showlegend=show_leg,
), row=row, col=col)
# Plot the identity line
if self.identity_line:
all_vals = np.concatenate([arr_a, arr_b])
lim = [float(all_vals.min()), float(all_vals.max())]
fig.add_trace(go.Scatter(
x=lim, y=lim,
mode="lines", name="Identity (y=x)",
line=dict(color="grey", width=1.5, dash="dot"),
showlegend=True,
), row=row, col=col)
# Get the OLS regression line
if self.regression_line:
coeffs = np.polyfit(arr_a, arr_b, 1)
x_line = np.linspace(arr_a.min(), arr_a.max(), 100)
y_line = np.polyval(coeffs, x_line)
r_sq = np.corrcoef(arr_a, arr_b)[0, 1] ** 2
fig.add_trace(go.Scatter(
x=x_line, y=y_line,
mode="lines", name=f"OLS (R²={r_sq:.2f})",
line=dict(color="#333", width=2.5),
showlegend=True,
), row=row, col=col)
#==================================================
#All future renders be placed right above this line
#==================================================
# Compose renderers freely
[docs]
class CompositeRenderer(Renderer):
"""Run multiple renderers on the same subplot"""
def __init__(self, *renderers: Renderer):
self._renderers = renderers
[docs]
def render(self, fig, store, style, spec, row, col):
for r in self._renderers:
r.render(fig, store, style, spec, row, col)