Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions spikeinterface_gui/backend_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ def on_active_view_updated(self, param):
view._panel_view_is_active = False

def on_unit_color_changed(self, param):
if not self._active:
return
# In this case we send it also if the view is not active, because we want to update colors anyways
for view in self.controller.views:
if param.obj.view == view:
continue
Expand Down
31 changes: 24 additions & 7 deletions spikeinterface_gui/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def __init__(
curation_data = json.load(f)

elif self.analyzer.format == "zarr":
import zarr
zarr_root = zarr.open(self.analyzer.folder, mode='r')
from spikeinterface.core.zarrextractors import super_zarr_open
zarr_root = super_zarr_open(self.analyzer.folder, mode='r')
if "spikeinterface_gui" in zarr_root.keys() and "curation_data" in zarr_root["spikeinterface_gui"].attrs.keys():
curation_data = zarr_root["spikeinterface_gui"].attrs["curation_data"]

Expand Down Expand Up @@ -548,22 +548,39 @@ def get_information_txt(self):

return txt

def get_divergent_unit_colors(self, colormap="tab10"):
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

unit_locations = self.analyzer.get_extension("unit_locations").get_data()
cmap = plt.get_cmap(colormap)
if not isinstance(cmap, ListedColormap):
raise ValueError(f"Colormap {colormap} is not a qualitative colormap")
num_entries = len(cmap.colors)
# lexsort by x and y
sorted_inds = np.lexsort((unit_locations[:, 0], unit_locations[:, 1]))
# now assign colors with sequentially to sorted units
colors = {}
for i, unit_ind in enumerate(sorted_inds):
unit_id = self.unit_ids[unit_ind]
colors[unit_id] = cmap.colors[i % num_entries]
Comment thread
alejoe91 marked this conversation as resolved.
Outdated
return colors


def refresh_colors(self):
if self.backend == "qt":
self._cached_qcolors = {}
elif self.backend == "panel":
pass

if self.main_settings['color_mode'] == 'color_by_unit':
self.colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
shuffle=True, seed=42)
self.colors = self.get_divergent_unit_colors(colormap="tab10")
elif self.main_settings['color_mode'] == 'color_only_visible':
unit_colors = get_unit_colors(self.analyzer.sorting, color_engine='matplotlib', map_name='gist_ncar',
shuffle=True, seed=42)
unit_colors = self.get_divergent_unit_colors(colormap="tab10")
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
for unit_id in self.get_visible_unit_ids():
self.colors[unit_id] = unit_colors[unit_id]
elif self.main_settings['color_mode'] == 'color_by_visibility':
elif self.main_settings['color_mode'] == 'color_by_visibility':
self.colors = {unit_id: (0.3, 0.3, 0.3, 1.) for unit_id in self.unit_ids}
import matplotlib.pyplot as plt
cmap = plt.colormaps['tab10']
Expand Down
14 changes: 14 additions & 0 deletions spikeinterface_gui/correlogramview.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def _compute(self):
# clear cache
self.figure_cache = {}

def on_unit_color_changed(self):
# clear cache
self.figure_cache = {}

## Qt ##

def _qt_make_layout(self):
Expand Down Expand Up @@ -145,6 +149,16 @@ def _panel_refresh(self):

if (unit1, unit2) in self.figure_cache:
fig = self.figure_cache[(unit1, unit2)]
# for the color_by_visibility
if self.controller.main_settings["color_mode"] == 'color_by_visibility':
# Update color in cached figure
if r == c:
unit_id = visible_unit_ids[r]
color = colors[unit_id]
for renderer in fig.renderers:
if hasattr(renderer, 'glyph') and hasattr(renderer.glyph, 'fill_color'):
renderer.glyph.fill_color = color
renderer.glyph.line_color = color
Comment thread
chrishalcrow marked this conversation as resolved.
else:
# create new figure
i = unit_ids.index(unit1)
Expand Down
1 change: 0 additions & 1 deletion spikeinterface_gui/mainsettingsview.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def on_max_visible_units_changed(self):
self.notify_unit_visibility_changed()

def on_change_color_mode(self):

self.controller.main_settings['color_mode'] = self.main_settings['color_mode']
self.controller.refresh_colors()
self.notify_unit_color_changed()
Expand Down
1 change: 0 additions & 1 deletion spikeinterface_gui/mergeview.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,6 @@ def accept_group_merge(self, group_ids):
)
return
self.notify_manual_curation_updated()
self.refresh()
Comment thread
chrishalcrow marked this conversation as resolved.

### QT
def _qt_get_selected_group_ids(self):
Expand Down
15 changes: 12 additions & 3 deletions spikeinterface_gui/utils_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
except ImportError:
from typing_extensions import NotRequired

import re
import numpy as np
import time
import panel as pn
Expand Down Expand Up @@ -478,10 +479,18 @@ def _on_sort_change(self, event):
ascending=(self.direction_dropdown.value == "↑")
)
else:
df = self.tabulator.value.sort_values(
by=self.sort_dropdown.value,
ascending=(self.direction_dropdown.value == "↑")
import pandas.api.types as ptypes

col = self.sort_dropdown.value
sort_kwargs = dict(
by=col,
ascending=(self.direction_dropdown.value == "↑"),
)
if ptypes.is_string_dtype(self.tabulator.value[col]):
sort_kwargs["key"] = lambda x: x.map(
lambda v: [int(c) if c.isdigit() else c.lower() for c in re.split(r'(\d+)', str(v))]
)
df = self.tabulator.value.sort_values(**sort_kwargs)
Comment thread
chrishalcrow marked this conversation as resolved.
self.tabulator.value = df

def _on_selection_change(self, event):
Expand Down
Loading