Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 5 additions & 32 deletions lumispy/utils/signals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import numpy as np
from hyperspy.axes import FunctionalDataAxis
from scipy.ndimage import center_of_mass
from scipy.interpolate import interp1d


def com(spectrum_intensities, signal_axis, **kwargs):
Expand All @@ -15,8 +13,6 @@ def com(spectrum_intensities, signal_axis, **kwargs):
signal_axis: hyperspy.axes.BaseDataAxis subclass
A HyperSpy signal axis class containing an array with the wavelength/
energy for each intensity/signal value.
kwargs : dictionary
For the scipy.interpolate.interp1d function.

Returns
-------
Expand All @@ -35,37 +31,13 @@ def com(spectrum_intensities, signal_axis, **kwargs):
>>> print(center_of_mass) # Outputs: [400.0]
"""

def _interpolate_signal(axis_array, index, **kwargs):
"""
Wrapper for `hs.axes.index2value` that linearly interpolates between
values should the index passed not be a integer. Using the kwargs, the
interpolation method can be changed.
"""
rem = index % 1
index = int(index // 1)
if rem == 0:
return axis_array[int(index)]
else:
y = [axis_array[index], axis_array[index + 1]]
x = [0, 1]
fx = interp1d(x, y, **kwargs)
return float(fx(rem))

# Find center of mass wrt array index
index_com = float(center_of_mass(spectrum_intensities)[0])

# Check for the type of hyperspy.axis
if type(signal_axis) == FunctionalDataAxis:
# Calculate value y from x[index_com]
x = _interpolate_signal(signal_axis.x.axis, index_com)
kwargs = {}
for kwarg in signal_axis.parameters_list:
kwargs[kwarg] = getattr(signal_axis, kwarg)
com_val = signal_axis._function(x, **kwargs)
xs = signal_axis.axis

elif hasattr(signal_axis, "axis"):
# Calculate value interpolating between index_com 0 and 1
com_val = _interpolate_signal(signal_axis.axis, index_com)
xs = signal_axis.axis

elif type(signal_axis) in (list, np.ndarray, tuple):
# Check for dimensionality
if len(spectrum_intensities) != len(signal_axis):
Expand All @@ -74,8 +46,9 @@ def _interpolate_signal(axis_array, index, **kwargs):
"the length of the wavelength array {len(signal_axis)}."
)
# Calculate value interpolating between index_com 0 and 1
com_val = _interpolate_signal(np.array(signal_axis), index_com)
xs = np.array(signal_axis)
else:
raise ValueError("The parmeter `signal_axis` must be a HyperSpy Axis object.")

com_val = np.sum(xs * spectrum_intensities) / np.sum(spectrum_intensities)
return com_val