# -*- coding: utf-8 -*-
#
# Copyright © 2009-2010 CEA
# Pierre Raybaut
# Licensed under the terms of the CECILL License
# (see guiqwt/__init__.py for details)
# pylint: disable=C0103
"""
guiqwt.widgets.fit
------------------

The `fit` module provides an interactive curve fitting widget/dialog allowing:

    * to fit data manually (by moving sliders)
    * or automatically (with standard optimization algorithms
      provided by :py:mod:`scipy`).

Example
~~~~~~~

.. literalinclude:: /../guiqwt/tests/fit.py
   :start-after: SHOW
   :end-before: Workaround for Sphinx v0.6 bug: empty 'end-before' directive

.. image:: /images/screenshots/fit.png

Reference
~~~~~~~~~

.. autofunction:: guifit

.. autoclass:: FitDialog
   :members:
   :inherited-members:
.. autoclass:: FitParam
   :members:
   :inherited-members:
.. autoclass:: AutoFitParam
   :members:
   :inherited-members:
"""
from qtpy.QtWidgets import (
QGridLayout,
QLabel,
QSlider,
QPushButton,
QLineEdit,
QDialog,
QVBoxLayout,
QHBoxLayout,
QWidget,
QDialogButtonBox,
)
from qtpy.QtCore import Qt
from qtpy import PYQT5
import numpy as np
from numpy import inf # Do not remove this import (used by optimization funcs)
import guidata
from guidata.utils import update_dataset, restore_dataset
from guidata.qthelpers import create_groupbox, win32_fix_title_bar_background
from guidata.configtools import get_icon
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (
StringItem,
FloatItem,
IntItem,
ChoiceItem,
BoolItem,
)
# Local imports
from guiqwt.config import _
from guiqwt.builder import make
from guiqwt.plot import CurveWidgetMixin
class AutoFitParam(DataSet):
    # Settings of the automatic fit: x interval, optimizer and tolerances.
    #
    # NOTE(review): documented with comments instead of a class docstring
    # because guidata appears to reuse a DataSet's docstring as the default
    # title/comment of its edit dialog -- confirm before adding a docstring.

    # Bounds of the x interval used by the automatic fit.
    xmin = FloatItem("xmin")
    xmax = FloatItem("xmax")

    # Optimizer selection. The keys (including the "powel" spelling) are the
    # identifiers compared against in FitWidgetMixin.autofit.
    method = ChoiceItem(
        _("Method"),
        [
            ("simplex", "Simplex"),
            ("powel", "Powel"),
            ("bfgs", "BFGS"),
            ("l_bfgs_b", "L-BFGS-B"),
            ("cg", _("Conjugate Gradient")),
            ("lq", _("Least squares")),
        ],
        default="lq",
    )

    # Order of the norm applied to the residuals. FitWidgetMixin.get_norm_func
    # hands this value to eval(), which only accepts text: the default must be
    # the string "2.0" (a float default makes eval() raise TypeError).
    err_norm = StringItem(
        "enorm",
        default="2.0",
        help=_("for simplex, powel, cg and bfgs norm used by the error function"),
    )

    # Tolerances forwarded to the scipy optimizers named in each help text.
    xtol = FloatItem(
        "xtol", default=0.0001, help=_("for simplex, powel, least squares")
    )
    ftol = FloatItem(
        "ftol", default=0.0001, help=_("for simplex, powel, least squares")
    )
    gtol = FloatItem("gtol", default=0.0001, help=_("for cg, bfgs"))

    # Gradient norm order, evaluated from text by the bfgs/cg fit methods.
    norm = StringItem(
        "norm", default="inf", help=_("for cg, bfgs. inf is max, -inf is min")
    )
class FitParamDataSet(DataSet):
    # Editable properties of one fit parameter.
    #
    # The item names match the FitParam attributes of the same name;
    # FitParam.edit_param passes this dataset and the FitParam instance to
    # guidata's update_dataset/restore_dataset (presumably matching them by
    # attribute name -- verify against guidata.utils).
    #
    # NOTE(review): comments are used instead of a class docstring because
    # guidata appears to reuse DataSet docstrings as dialog title/comment.

    # Label of the parameter
    name = StringItem(_("Name"))
    # Current value and its allowed interval
    value = FloatItem(_("Value"), default=0.0)
    min = FloatItem(_("Min"), default=-1.0)
    # set_pos(col=1): presumably lays the item out in the second column of
    # the edit form -- confirm with the guidata documentation
    max = FloatItem(_("Max"), default=1.0).set_pos(col=1)
    # Number of discrete slider positions
    steps = IntItem(_("Steps"), default=5000)
    # printf-style format used to display the value
    format = StringItem(_("Format"), default="%.3f").set_pos(col=1)
    # Whether the slider maps positions to values logarithmically
    logscale = BoolItem(_("Logarithmic"), _("Scale"))
    # Unit text shown next to the value
    unit = StringItem(_("Unit"), default="").set_pos(col=1)
class FitParam(object):
    """Fit parameter: a bounded value plus the widgets used to edit it.

    The value can be changed through a line edit, through a slider having
    ``steps`` positions (indices ``0`` to ``steps - 1``) mapped linearly or
    logarithmically onto ``[min, max]``, or through a properties dialog.

    :param name: label of the parameter
    :param value: current value (``None`` is displayed as an empty field)
    :param min: lower bound of the slider interval
    :param max: upper bound of the slider interval
    :param logscale: if True, the slider uses a logarithmic mapping
    :param steps: number of slider positions
    :param format: printf-style format used to display the value
    :param size_offset: offset added to the point size of the name label
    :param unit: unit text displayed after the value
    """

    def __init__(
        self,
        name,
        value,
        min,
        max,
        logscale=False,
        steps=5000,
        format="%.3f",
        size_offset=0,
        unit="",
    ):
        self.name = name
        self.value = value
        self.min = min
        self.max = max
        self.logscale = logscale
        self.steps = steps
        self.format = format
        self.unit = unit
        # Widgets are only created on demand by create_widgets()
        self.prefix_label = None
        self.lineedit = None
        self.unit_label = None
        self.slider = None
        self.button = None
        self._widgets = []
        self._size_offset = size_offset
        self._refresh_callback = None
        self.dataset = FitParamDataSet(title=_("Curve fitting parameter"))

    def copy(self):
        """Return a copy of this fitparam"""
        return self.__class__(
            self.name,
            self.value,
            self.min,
            self.max,
            self.logscale,
            self.steps,
            self.format,
            self._size_offset,
            self.unit,
        )

    def create_widgets(self, parent, refresh_callback):
        """Create the editing widgets of this parameter.

        :param parent: widget used as parent of the properties dialog
        :param refresh_callback: callable (no argument) invoked each time the
            parameter value changes
        """
        self._refresh_callback = refresh_callback
        self.prefix_label = QLabel()
        font = self.prefix_label.font()
        font.setPointSize(font.pointSize() + self._size_offset)
        self.prefix_label.setFont(font)
        self.button = QPushButton()
        self.button.setIcon(get_icon("settings.png"))
        self.button.setToolTip(_("Edit '%s' fit parameter properties") % self.name)
        self.button.clicked.connect(lambda: self.edit_param(parent))
        self.lineedit = QLineEdit()
        self.lineedit.editingFinished.connect(self.line_editing_finished)
        self.unit_label = QLabel(self.unit)
        self.slider = QSlider()
        self.slider.setOrientation(Qt.Horizontal)
        self.slider.setRange(0, self.steps - 1)
        self.slider.valueChanged.connect(self.slider_value_changed)
        # Synchronize the widgets without triggering the refresh callback
        self.update(refresh=False)
        self.add_widgets(
            [
                self.prefix_label,
                self.lineedit,
                self.unit_label,
                self.slider,
                self.button,
            ]
        )

    def add_widgets(self, widgets):
        """Append *widgets* to the list returned by :meth:`get_widgets`."""
        self._widgets += widgets

    def get_widgets(self):
        """Return the list of widgets of this parameter (empty if none yet)."""
        return self._widgets

    def set_scale(self, state):
        """Enable the logarithmic scale if *state* > 0, then move the slider."""
        self.logscale = state > 0
        self.update_slider_value()

    def set_text(self, fmt=None):
        """Update the name label and the line edit text.

        :param fmt: printf-style format overriding ``self.format``
        """
        style = "<span style='color: #444444'><b>%s</b></span>"
        self.prefix_label.setText(style % self.name)
        if self.value is None:
            value_str = ""
        else:
            if fmt is None:
                fmt = self.format
            value_str = fmt % self.value
        self.lineedit.setText(value_str)
        # A degenerate interval (value == min == max) cannot be edited
        self.lineedit.setDisabled(self.value == self.min and self.max == self.min)

    def line_editing_finished(self):
        """Slot: read the value typed in the line edit.

        Text that is not a valid float is discarded: the previous value is
        written back to the line edit.
        """
        try:
            self.value = float(self.lineedit.text())
        except ValueError:
            self.set_text()
        self.update_slider_value()
        self._refresh_callback()

    def _value_from_slider(self, int_value):
        """Return the parameter value matching slider index *int_value*."""
        last_index = self.steps - 1
        if last_index <= 0:
            # Single-position slider: only the lower bound is reachable
            return self.min
        if self.logscale:
            total_delta = np.log10(1 + self.max - self.min)
            return self.min + 10 ** (total_delta * int_value / last_index) - 1
        total_delta = self.max - self.min
        return self.min + total_delta * int_value / last_index

    def _slider_index(self):
        """Return the slider index (0 .. steps - 1) matching ``self.value``.

        Inverse of :meth:`_value_from_slider`: both use ``steps - 1`` as the
        last index. Values outside ``[min, max]`` are mapped onto the nearest
        end of the slider; degenerate intervals and NaN map to index 0.
        """
        last_index = max(self.steps - 1, 0)
        # Clamp below at `min` *before* taking the logarithm, otherwise
        # log10 of a negative number yields NaN
        value_offset = max(self.value - self.min, 0.0)
        if self.logscale:
            value_delta = np.log10(1 + value_offset)
            total_delta = np.log10(1 + self.max - self.min)
        else:
            value_delta = value_offset
            total_delta = self.max - self.min
        if not total_delta > 0:
            # Empty or inverted interval (or NaN): avoid dividing by zero
            return 0
        ratio = value_delta / total_delta
        if np.isnan(ratio):
            return 0
        ratio = min(max(ratio, 0.0), 1.0)
        # round() rather than truncation so that index -> value -> index
        # is stable despite floating point errors
        return int(round(last_index * ratio))

    def slider_value_changed(self, int_value):
        """Slot: set the value from slider index *int_value* and refresh."""
        self.value = self._value_from_slider(int_value)
        self.set_text()
        self._refresh_callback()

    def _show_slider_if_parent_visible(self):
        """Show the slider when it has a parent which is itself visible."""
        parent = self.slider.parent()
        if parent and parent.isVisible():
            self.slider.show()

    def update_slider_value(self):
        """Synchronize slider state and position with the current value.

        The slider is disabled when value/min/max is undefined, hidden when
        the interval is degenerate (value == min == max) and otherwise moved,
        with its signals blocked, to the index matching the value.
        """
        if self.value is None or self.min is None or self.max is None:
            self.slider.setEnabled(False)
            self._show_slider_if_parent_visible()
        elif self.value == self.min and self.max == self.min:
            self.slider.hide()
        else:
            self.slider.setEnabled(True)
            self._show_slider_if_parent_visible()
            intval = self._slider_index()
            # Signals are blocked so that moving the slider does not
            # overwrite the value through slider_value_changed
            self.slider.blockSignals(True)
            self.slider.setValue(intval)
            self.slider.blockSignals(False)

    def edit_param(self, parent):
        """Open the properties dialog; apply the changes if it is accepted.

        The interval is widened when the edited value lies outside of it.
        """
        update_dataset(self.dataset, self)
        if self.dataset.edit(parent=parent):
            restore_dataset(self.dataset, self)
            if self.value is not None:
                if self.max is not None and self.value > self.max:
                    self.max = self.value
                if self.min is not None and self.value < self.min:
                    self.min = self.value
            self.update()

    def update(self, refresh=True):
        """Synchronize all widgets with the parameter attributes.

        :param refresh: if True, call the refresh callback afterwards
        """
        self.unit_label.setText(self.unit)
        self.slider.setRange(0, self.steps - 1)
        self.update_slider_value()
        self.set_text()
        if refresh:
            self._refresh_callback()
def add_fitparam_widgets_to(layout, fitparams, refresh_callback, param_cols=1):
    """Lay out the widgets of *fitparams* in the grid *layout*.

    Parameters are placed left to right, ``param_cols`` per row. Each one
    takes one grid column per widget plus one spare column separating it from
    the next parameter. Widgets that do not exist yet are created with
    ``layout.parent()`` as parent and *refresh_callback* as callback.
    """
    placements = []
    grid_row = 0
    group = 0
    for param in fitparams:
        if not param.get_widgets():
            param.create_widgets(layout.parent(), refresh_callback)
        widgets = param.get_widgets()
        # One column per widget + one separator column
        group_width = len(widgets) + 1
        first_col = group * group_width
        for offset, widget in enumerate(widgets):
            placements.append((widget, grid_row, offset + first_col))
        group += 1
        if group == param_cols:
            grid_row += 1
            group = 0

    # Widgets are only inserted once all of them have been created
    for widget, row, col in placements:
        layout.addWidget(widget, row, col)

    if not fitparams:
        return
    for group in range(param_cols):
        # Column 1 of each group gets the largest stretch factor
        layout.setColumnStretch(1 + group * group_width, 5)
        if group > 0:
            # Separator column located before this group
            layout.setColumnStretch(group * group_width - 1, 1)
class FitWidgetMixin(CurveWidgetMixin):
    """Mixin adding interactive curve fitting to a curve plot widget.

    The widget shows the data curve and the curve computed by
    ``fitfunc(x, values, *fitargs, **fitkwargs)`` where *values* is the list
    of the current fit parameter values. Parameters are edited through their
    own widgets and, when ``auto_fit`` is enabled, optimized by the
    :py:mod:`scipy.optimize` routine selected in the automatic fit settings.

    :param wintitle: window title (``None``: "Curve fitting")
    :param icon: window icon
    :param toolbar: toolbar option forwarded to ``CurveWidgetMixin``
    :param options: plot options forwarded to ``CurveWidgetMixin``
    :param panels: panels forwarded to ``CurveWidgetMixin``
    :param param_cols: number of fit parameters per row
    :param legend_anchor: anchor of the plot legend
    :param auto_fit: if True, add the "Automatic fit" group box
    """

    def __init__(
        self,
        wintitle="guiqwt plot",
        icon="guiqwt.svg",
        toolbar=False,
        options=None,
        panels=None,
        param_cols=1,
        legend_anchor="TR",
        auto_fit=True,
    ):
        if wintitle is None:
            wintitle = _("Curve fitting")

        self.x = None
        self.y = None
        self.fitfunc = None
        self.fitargs = None
        self.fitkwargs = None
        self.fitparams = None
        self.autofit_prm = None
        # Slice bounds of the data used by errorfunc; None means "no bound"
        # until compute_imin_imax() has been called
        self.i_min = None
        self.i_max = None

        self.data_curve = None
        self.fit_curve = None
        self.legend = None
        self.legend_anchor = legend_anchor
        self.xrange = None
        self.show_xrange = False

        self.param_cols = param_cols
        self.auto_fit_enabled = auto_fit
        self.button_list = []  # list of buttons to be disabled at startup

        self.fit_layout = None
        self.params_layout = None

        CurveWidgetMixin.__init__(
            self,
            wintitle=wintitle,
            icon=icon,
            toolbar=toolbar,
            options=options,
            panels=panels,
        )

        self.refresh()

    # QWidget API --------------------------------------------------------------
    def resizeEvent(self, event):
        """Replot after the default handling of the resize event."""
        QWidget.resizeEvent(self, event)
        self.get_plot().replot()

    # CurveWidgetMixin API -----------------------------------------------------
    def setup_widget_layout(self):
        """Build the layout: toolbar, plot, then the fit group boxes."""
        self.fit_layout = QHBoxLayout()
        self.params_layout = QGridLayout()
        params_group = create_groupbox(
            self, _("Fit parameters"), layout=self.params_layout
        )
        if self.auto_fit_enabled:
            auto_group = self.create_autofit_group()
            self.fit_layout.addWidget(auto_group)
        self.fit_layout.addWidget(params_group)
        self.plot_layout.addLayout(self.fit_layout, 1, 0)

        vlayout = QVBoxLayout(self)
        vlayout.addWidget(self.toolbar)
        vlayout.addLayout(self.plot_layout)
        self.setLayout(vlayout)

    def create_plot(self, options):
        """Create the plot and listen to the range changes of every plot."""
        CurveWidgetMixin.create_plot(self, options)
        for plot in self.get_plots():
            plot.SIG_RANGE_CHANGED.connect(self.range_changed)

    # Public API ---------------------------------------------------------------
    def set_data(
        self, x, y, fitfunc=None, fitparams=None, fitargs=None, fitkwargs=None
    ):
        """Set the data to fit and, optionally, the fit configuration.

        :param x: abscissa array (its ``min``/``max`` define the fit bounds)
        :param y: ordinate array
        :param fitfunc: fit function; ``None`` keeps the current one
        :param fitparams: list of fit parameters; ``None`` keeps the current
        :param fitargs: extra positional arguments of *fitfunc*
        :param fitkwargs: extra keyword arguments of *fitfunc*

        The automatic fit settings are reset to their defaults.
        """
        replace_params = self.fitparams is not None and fitparams is not None
        if replace_params:
            self.clear_params_layout()
        self.x = x
        self.y = y
        if fitfunc is not None:
            self.fitfunc = fitfunc
        if fitparams is not None:
            self.fitparams = fitparams
        if fitargs is not None:
            self.fitargs = fitargs
        if fitkwargs is not None:
            self.fitkwargs = fitkwargs
        self.autofit_prm = AutoFitParam(title=_("Automatic fitting options"))
        self.autofit_prm.xmin = x.min()
        self.autofit_prm.xmax = x.max()
        self.compute_imin_imax()
        if self.fitparams is not None and fitparams is not None:
            self.populate_params_layout()
        self.refresh()

    def set_fit_data(self, fitfunc, fitparams, fitargs=None, fitkwargs=None):
        """Replace the fit function, its parameters and extra arguments."""
        if self.fitparams is not None:
            self.clear_params_layout()
        self.fitfunc = fitfunc
        self.fitparams = fitparams
        self.fitargs = fitargs
        self.fitkwargs = fitkwargs
        self.populate_params_layout()
        self.refresh()

    def clear_params_layout(self):
        """Remove the widgets of the current fit parameters and hide them."""
        for param in self.fitparams:
            for widget in param.get_widgets():
                if widget is not None:
                    self.params_layout.removeWidget(widget)
                    widget.hide()

    def populate_params_layout(self):
        """Add the widgets of the current fit parameters to the layout."""
        add_fitparam_widgets_to(
            self.params_layout, self.fitparams, self.refresh, param_cols=self.param_cols
        )

    def create_autofit_group(self):
        """Create and return the "Automatic fit" group box.

        Its buttons are registered in ``button_list`` so that
        :meth:`refresh` enables them only once the widget is configured.
        """
        auto_button = QPushButton(get_icon("apply.png"), _("Run"), self)
        auto_button.clicked.connect(self.autofit)
        autoprm_button = QPushButton(get_icon("settings.png"), _("Settings"), self)
        autoprm_button.clicked.connect(self.edit_parameters)
        xrange_button = QPushButton(get_icon("xrange.png"), _("Bounds"), self)
        xrange_button.setCheckable(True)
        xrange_button.toggled.connect(self.toggle_xrange)
        auto_layout = QVBoxLayout()
        auto_layout.addWidget(auto_button)
        auto_layout.addWidget(autoprm_button)
        auto_layout.addWidget(xrange_button)
        self.button_list += [auto_button, autoprm_button, xrange_button]
        return create_groupbox(self, _("Automatic fit"), layout=auto_layout)

    def get_fitfunc_arguments(self):
        """Return fitargs and fitkwargs"""
        fitargs = [] if self.fitargs is None else self.fitargs
        fitkwargs = {} if self.fitkwargs is None else self.fitkwargs
        return fitargs, fitkwargs

    def refresh(self, slider_value=None):
        """Refresh Fit Tool dialog box

        :param slider_value: unused; accepted so that the method can be
            called with one positional argument
        """
        # Update button states
        enable = (
            self.x is not None
            and self.y is not None
            and self.x.size > 0
            and self.y.size > 0
            and self.fitfunc is not None
            and self.fitparams is not None
            and len(self.fitparams) > 0
        )
        for btn in self.button_list:
            btn.setEnabled(enable)

        if not enable:
            # Fit widget is not yet configured
            return

        fitargs, fitkwargs = self.get_fitfunc_arguments()
        yfit = self.fitfunc(
            self.x, [p.value for p in self.fitparams], *fitargs, **fitkwargs
        )

        plot = self.get_plot()

        # Plot items are created on first use, then only updated
        if self.legend is None:
            self.legend = make.legend(anchor=self.legend_anchor)
            plot.add_item(self.legend)

        if self.xrange is None:
            self.xrange = make.range(0.0, 1.0)
            plot.add_item(self.xrange)
        self.xrange.set_range(self.autofit_prm.xmin, self.autofit_prm.xmax)
        self.xrange.setVisible(self.show_xrange)

        if self.data_curve is None:
            self.data_curve = make.curve([], [], _("Data"), color="b", linewidth=2)
            plot.add_item(self.data_curve)
        self.data_curve.set_data(self.x, self.y)

        if self.fit_curve is None:
            self.fit_curve = make.curve([], [], _("Fit"), color="r", linewidth=2)
            plot.add_item(self.fit_curve)
        self.fit_curve.set_data(self.x, yfit)

        plot.replot()
        plot.disable_autoscale()

    def range_changed(self, xrange_obj, xmin, xmax):
        """Slot: store the new fit bounds and update the data slice."""
        self.autofit_prm.xmin, self.autofit_prm.xmax = xmin, xmax
        self.compute_imin_imax()

    def toggle_xrange(self, state):
        """Slot: show or hide the fit bounds selection item."""
        self.xrange.setVisible(state)
        plot = self.get_plot()
        plot.replot()
        if state:
            plot.set_active_item(self.xrange)
        else:
            # If the button is unchecked then set to the complete range
            self.autofit_prm.xmin = self.x.min()
            self.autofit_prm.xmax = self.x.max()
            self.xrange.set_range(self.autofit_prm.xmin, self.autofit_prm.xmax)
        self.show_xrange = state

    def edit_parameters(self):
        """Slot: edit the automatic fit settings in a dialog box."""
        if self.autofit_prm.edit(parent=self):
            self.xrange.set_range(self.autofit_prm.xmin, self.autofit_prm.xmax)
            plot = self.get_plot()
            plot.replot()
            self.compute_imin_imax()

    def compute_imin_imax(self):
        """Compute the slice bounds of the data lying within the fit bounds.

        Relies on ``searchsorted``: ``self.x`` is assumed to be sorted in
        ascending order.
        """
        self.i_min = self.x.searchsorted(self.autofit_prm.xmin)
        self.i_max = self.x.searchsorted(self.autofit_prm.xmax, side="right")

    def errorfunc(self, params):
        """Return the residuals ``y - fitfunc(x, params)`` within the bounds."""
        x = self.x[self.i_min : self.i_max]
        y = self.y[self.i_min : self.i_max]
        fitargs, fitkwargs = self.get_fitfunc_arguments()
        return y - self.fitfunc(x, params, *fitargs, **fitkwargs)

    def autofit(self):
        """Slot: run the selected optimizer and apply the resulting values.

        Nothing happens if the selected method is unknown.
        """
        solvers = {
            "lq": self.autofit_lq,
            "simplex": self.autofit_simplex,
            "powel": self.autofit_powel,
            "bfgs": self.autofit_bfgs,
            "l_bfgs_b": self.autofit_l_bfgs,
            "cg": self.autofit_cg,
        }
        solver = solvers.get(self.autofit_prm.method)
        if solver is None:
            return
        x0 = np.array([p.value for p in self.fitparams])
        x = solver(x0)
        for v, p in zip(x, self.fitparams):
            p.value = v
        self.refresh()
        for prm in self.fitparams:
            prm.update()

    @staticmethod
    def _eval_norm(norm):
        """Return the norm order described by *norm*.

        Numbers are returned unchanged (a settings item may hold a numeric
        default); any other value is evaluated as a Python expression, which
        is how "inf" and "-inf" are resolved.
        """
        if isinstance(norm, (int, float)):
            return norm
        # SECURITY NOTE: eval() executes arbitrary code. The text comes from
        # the settings dialog edited by the local user; this must be
        # replaced by a strict parser if the settings may ever be loaded from
        # an untrusted source.
        return eval(norm)

    def get_norm_func(self):
        """Return the scalar cost function: norm of the residuals.

        The order of the norm is taken from the ``err_norm`` setting.
        """
        err_norm = self._eval_norm(self.autofit_prm.err_norm)

        def func(params):
            return np.linalg.norm(self.errorfunc(params), err_norm)

        return func

    def autofit_simplex(self, x0):
        """Minimize the cost function with :py:func:`scipy.optimize.fmin`."""
        prm = self.autofit_prm
        from scipy.optimize import fmin

        return fmin(self.get_norm_func(), x0, xtol=prm.xtol, ftol=prm.ftol)

    def autofit_powel(self, x0):
        """Minimize the cost function with ``scipy.optimize.fmin_powell``."""
        prm = self.autofit_prm
        from scipy.optimize import fmin_powell

        return fmin_powell(self.get_norm_func(), x0, xtol=prm.xtol, ftol=prm.ftol)

    def autofit_bfgs(self, x0):
        """Minimize the cost function with ``scipy.optimize.fmin_bfgs``."""
        prm = self.autofit_prm
        from scipy.optimize import fmin_bfgs

        return fmin_bfgs(
            self.get_norm_func(), x0, gtol=prm.gtol, norm=self._eval_norm(prm.norm)
        )

    def autofit_l_bfgs(self, x0):
        """Minimize the cost function with ``scipy.optimize.fmin_l_bfgs_b``.

        The (min, max) interval of each fit parameter is used as bound.
        """
        prm = self.autofit_prm
        bounds = [(p.min, p.max) for p in self.fitparams]
        from scipy.optimize import fmin_l_bfgs_b

        x, _f, _d = fmin_l_bfgs_b(
            self.get_norm_func(), x0, pgtol=prm.gtol, approx_grad=1, bounds=bounds
        )
        return x

    def autofit_cg(self, x0):
        """Minimize the cost function with ``scipy.optimize.fmin_cg``."""
        prm = self.autofit_prm
        from scipy.optimize import fmin_cg

        return fmin_cg(
            self.get_norm_func(), x0, gtol=prm.gtol, norm=self._eval_norm(prm.norm)
        )

    def autofit_lq(self, x0):
        """Least squares fit of the residuals (``scipy.optimize.leastsq``)."""
        prm = self.autofit_prm
        from scipy.optimize import leastsq

        x, _ier = leastsq(self.errorfunc, x0, xtol=prm.xtol, ftol=prm.ftol)
        return x

    def get_values(self):
        """Convenience method to get fit parameter values"""
        return [param.value for param in self.fitparams]
class FitWidget(QWidget, FitWidgetMixin):
    """Curve fitting tool as a plain widget (see ``FitWidgetMixin``).

    :param parent: parent widget; every other argument is handed over to
        ``FitWidgetMixin.__init__`` unchanged
    """

    def __init__(
        self,
        wintitle=None,
        icon="guiqwt.svg",
        toolbar=False,
        options=None,
        parent=None,
        panels=None,
        param_cols=1,
        legend_anchor="TR",
        auto_fit=False,
    ):
        # The Qt widget has to be initialized before the mixin builds its
        # child widgets and layouts
        QWidget.__init__(self, parent)
        FitWidgetMixin.__init__(
            self,
            wintitle=wintitle,
            icon=icon,
            toolbar=toolbar,
            options=options,
            panels=panels,
            param_cols=param_cols,
            legend_anchor=legend_anchor,
            auto_fit=auto_fit,
        )
class FitDialog(QDialog, FitWidgetMixin):
    """Curve fitting tool as a dialog box (see ``FitWidgetMixin``).

    :param edit: if True, an Ok/Cancel button box is added below the plot
    :param parent: parent widget

    The other arguments are handed over to ``FitWidgetMixin.__init__``.
    """

    def __init__(
        self,
        wintitle=None,
        icon="guiqwt.svg",
        edit=True,
        toolbar=False,
        options=None,
        parent=None,
        panels=None,
        param_cols=1,
        legend_anchor="TR",
        auto_fit=False,
    ):
        if not PYQT5:
            QDialog.__init__(self, parent)
        # Both attributes must exist before the mixin initializer runs:
        # setup_widget_layout() reads `edit`
        self.edit = edit
        self.button_layout = None
        mixin_kwargs = dict(
            wintitle=wintitle,
            icon=icon,
            toolbar=toolbar,
            options=options,
            panels=panels,
            param_cols=param_cols,
            legend_anchor=legend_anchor,
            auto_fit=auto_fit,
        )
        if PYQT5:
            # Single cooperative call: the keyword arguments are meant for
            # the mixin initializer
            super(FitDialog, self).__init__(parent, **mixin_kwargs)
        else:
            FitWidgetMixin.__init__(self, **mixin_kwargs)
        self.setWindowFlags(Qt.Window)
        win32_fix_title_bar_background(self)

    def setup_widget_layout(self):
        """Build the mixin layout, then add the buttons if ``edit`` is set."""
        FitWidgetMixin.setup_widget_layout(self)
        if not self.edit:
            return
        self.install_button_layout()

    def install_button_layout(self):
        """Append an Ok/Cancel button box to the dialog layout.

        The Ok button is registered in ``button_list``.
        """
        buttons = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel)
        buttons.accepted.connect(self.accept)
        buttons.rejected.connect(self.reject)
        self.button_list.append(buttons.button(QDialogButtonBox.Ok))

        row = QHBoxLayout()
        self.button_layout = row
        row.addStretch()
        row.addWidget(buttons)

        main_layout = self.layout()
        main_layout.addSpacing(10)
        main_layout.addLayout(row)
def guifit(
    x,
    y,
    fitfunc,
    fitparams,
    fitargs=None,
    fitkwargs=None,
    wintitle=None,
    title=None,
    xlabel=None,
    ylabel=None,
    param_cols=1,
    auto_fit=True,
    winsize=None,
    winpos=None,
):
    """GUI-based curve fitting tool

    Open a modal :py:class:`FitDialog` on the given data.

    :param x: abscissa array
    :param y: ordinate array
    :param fitfunc: fit function, called as
        ``fitfunc(x, values, *fitargs, **fitkwargs)``
    :param fitparams: list of fit parameters
    :param fitargs: extra positional arguments of *fitfunc*
    :param fitkwargs: extra keyword arguments of *fitfunc*
    :param wintitle: window title
    :param title: plot title
    :param xlabel: x-axis label
    :param ylabel: y-axis label
    :param param_cols: number of fit parameters per row
    :param auto_fit: if True, show the automatic fit controls
    :param winsize: optional arguments of the dialog ``resize`` method
    :param winpos: optional arguments of the dialog ``move`` method
    :return: list of fit parameter values, or ``None`` when the result of
        the dialog ``exec_()`` is false
    """
    # Not used afterwards: the name only keeps the application object
    # referenced while the dialog is running
    _app = guidata.qapplication()
    plot_options = dict(title=title, xlabel=xlabel, ylabel=ylabel)
    win = FitDialog(
        edit=True,
        wintitle=wintitle,
        toolbar=True,
        param_cols=param_cols,
        auto_fit=auto_fit,
        options=plot_options,
    )
    win.set_data(x, y, fitfunc, fitparams, fitargs, fitkwargs)
    for geometry, apply_geometry in ((winsize, win.resize), (winpos, win.move)):
        if geometry is not None:
            apply_geometry(*geometry)
    if not win.exec_():
        return None
    return win.get_values()
if __name__ == "__main__":
    # Demo: fit a noisy cosine with an offset and a frequency parameter
    xdata = np.linspace(-10, 10, 1000)
    noise = np.random.rand(xdata.shape[0]) * 0.2
    ydata = np.cos(1.5 * xdata) + noise

    def model(x, params):
        """Return cos(frequency * x) + offset, params = (offset, frequency)."""
        offset, frequency = params
        return np.cos(frequency * x) + offset

    demo_params = [
        FitParam("Offset", 1.0, 0.0, 2.0),
        FitParam("Frequency", 1.05, 0.0, 10.0, logscale=True),
    ]
    fitted = guifit(xdata, ydata, model, demo_params, auto_fit=True)

    # Values returned by the dialog, then values stored in the parameters
    print(fitted)
    print([prm.value for prm in demo_params])