"""
Contains classes to work with an automated measurement procedure that changes
element parameters and records values whenever a new set of monitor readouts
is received.
"""
__all__ = [
'OrbitRecord',
'Target',
'Corrector',
'ProcBot',
]
from itertools import accumulate, product
import logging
import textwrap
from datetime import datetime, timezone
import numpy as np
import madgui.util.yaml as yaml
from madgui.util.collections import List, Boxed
from madgui.util.history import History
from madgui.util.misc import invalidate
from madgui.util.signal import Signal
from madgui.model.match import Matcher
from .orbit import fit_particle_orbit, add_offsets
[docs]class OrbitRecord:
def __init__(self, monitor, readout, optics, tm):
self.monitor = monitor
self.readout = readout
self.optics = optics
self.tm = tm
[docs]class Target:
def __init__(self, elem, x, y, px=0., py=0.):
self.elem = elem
self.x = x
self.y = y
self.px = px
self.py = py
[docs]class Corrector(Matcher):
"""
Class for orbit correction procedure.
"""
mode = 'xy'
setup_changed = Signal()
def __init__(self, session, direct=True):
super().__init__(session.model())
self.session = session
self.control = control = session.control
self.direct = direct
self._knobs = control.get_knobs()
self.file = None
self.use_backtracking = Boxed(True)
# save elements
self.monitors = List()
self.targets = List()
self.readouts = List()
control.sampler.updated.connect(self._update_readouts)
self.records = List()
self.fit_range = None
self.objective_values = {}
self._offsets = session.config['online_control']['offsets']
self.optics = List()
self.strategy = Boxed('orm')
self.saved_optics = History()
self.online_optic = {}
# for ORM
kick_elements = ('hkicker', 'vkicker', 'kicker', 'sbend')
self.all_kickers = [
elem for elem in self.model.elements
if elem.base_name.lower() in kick_elements]
self.all_monitors = [
elem.name for elem in self.model.elements
if elem.base_name.lower().endswith('monitor')]
def _update_readouts(self, *_):
self.readouts[:] = self.control.sampler.fetch(self.monitors)
[docs] def setup(self, config, dirs=None):
dirs = dirs or self.mode
self.saved_optics.clear()
elements = self.model.elements
self.selected = config
monitors = sorted(config['monitors'], key=elements.index)
last_mon = max(map(elements.index, monitors), default=0)
knob_elems = {}
for elem in elements:
for knob in self.model.get_elem_knobs(elem):
knob_elems.setdefault(knob.lower(), []).append(elem)
# steerer optics -> good default for ORM analysis
optic_knobs = config.setdefault('optics', [
knob
for name in self.all_kickers
for elem in [elements[name]]
if elem.index < last_mon
for knob in self.model.get_elem_knobs(elem)
])
optic_knobs = [k.lower() for k in optic_knobs]
self.optic_elems = [
elem.name.lower()
for knob in optic_knobs
for elem in knob_elems[knob]
]
self.optic_params = [self._knobs[k] for k in optic_knobs
if k in self._knobs]
# again, steerer optics only useful for ORM
config.setdefault('steerers', {
'x': [knob for knob in optic_knobs
if any(elem.base_name != 'vkicker'
for elem in knob_elems[knob])],
'y': [knob for knob in optic_knobs
if any(elem.base_name == 'vkicker'
for elem in knob_elems[knob])],
})
targets = config.setdefault('targets', {})
steerers = sum([config['steerers'][d] for d in dirs], [])
self.method = config.get('method', ('jacobian', {}))
self.mode = dirs
self.match_names = [s for s in steerers if isinstance(s, str)]
self.assign = {k: v for s in steerers if isinstance(s, dict)
for k, v in s.items()}
targets = sorted(targets, key=elements.index)
self.objective_values.update({
t.elem: (t.x, t.y)
for t in self.targets
})
self.targets[:] = [
Target(elem, x, y)
for elem in targets
for x, y in [self.objective_values.get(elem, (0, 0))]
]
self.monitors[:] = sorted(monitors, key=elements.index)
fit_elements = targets + list(self.monitors) + list(self.optic_elems)
self.fit_range = (min(fit_elements, key=elements.index, default=0),
max(fit_elements, key=elements.index, default=0))
self.update_vars()
self.variables[:] = [
knob
for knob in self.match_names + list(self.assign)
if knob.lower() in self._knobs
]
self._update_readouts()
self.setup_changed.emit()
[docs] def set_optics_delta(self, deltas, default):
self.update_vars()
self.optics = [{}] + [
{knob: self.base_optics[knob] + delta}
for knob in self.match_names
if knob.lower() in self._knobs
for delta in [deltas.get(knob.lower(), default)]
if delta
]
def _read_vars(self):
model = self.model
return {
knob.lower(): model.read_param(knob)
for knob in self.match_names + list(self.assign)
if knob.lower() in self._knobs
}
[docs] def update_vars(self):
self.control.read_all()
self.base_optics = {
knob: self.model.read_param(knob)
for knob in self.control.get_knobs()
}
self.online_optic = self.saved_optics.push(self._read_vars())
[docs] def update_records(self):
if self.direct:
self.records[:] = self.current_orbit_records()
[docs] def can_fit(self):
return (len(self.records) >= 2 and
len(self.variables) >= 1 and
all(r.readout.valid for r in self.records))
[docs] def update_fit(self):
if not self.can_fit():
return
if self.use_backtracking():
init_orbit, chi_squared, singular = \
self.fit_particle_orbit(self.records)
if singular or not init_orbit:
return
self.model.update_twiss_args(init_orbit)
self.compute_steerer_corrections()
[docs] def apply(self):
optic = self.saved_optics()
self.model.write_params(optic.items())
self.control.write_params(optic.items())
super().apply()
active_optic = None
[docs] def set_optic(self, i):
optic = {}
if self.active_optic is not None:
optic.update({
k: self.base_optics[k] for k in self.optics[self.active_optic]
})
if i is not None:
optic.update(self.optics[i])
# only for optic variation method
self.model.write_params(optic.items())
self.control.write_params(optic.items())
self.active_optic = i
if i is not None:
self.write_data([{
'optics': self.optics[i],
'time': format_datetime(),
}])
# computations
[docs] def fit_particle_orbit(self, records):
readouts = [r.readout for r in records]
secmaps = [r.tm for r in records]
return fit_particle_orbit(
self.model, add_offsets(readouts, self._offsets),
secmaps, self.fit_range[0])[0]
[docs] def current_orbit_records(self):
model = self.model
start = self.fit_range[0]
secmaps = model.get_transfer_maps([start] + list(self.monitors))
secmaps = list(accumulate(secmaps, lambda a, b: np.dot(b, a)))
optics = {k: model.globals[k] for k in self._knobs}
readouts = {r.name.lower(): r for r in self.readouts}
return [
OrbitRecord(monitor, readouts[monitor.lower()], optics, secmap)
for monitor, secmap in zip(self.monitors, secmaps)
]
[docs] def compute_steerer_corrections(self):
strats = {
'match': self._compute_steerer_corrections_match,
'orm': self._compute_steerer_corrections_orm_ndiff1,
'tm': self._compute_steerer_corrections_orm_sectormap,
}
return self.saved_optics.push(strats[self.strategy()]())
def _compute_steerer_corrections_match(self):
"""
Compute corrections for the x_steerers, y_steerers.
"""
model = self.model
constraints = self._get_constraints()
with model.undo_stack.rollback("Orbit correction", transient=True):
model.update_globals(self.assign)
model.match(
vary=self.match_names,
limits=self.selected.get('limits'),
method=self.method,
weight={'x': 1e3, 'y': 1e3, 'px': 1e2, 'py': 1e2},
constraints=constraints)
return self._read_vars()
def _compute_steerer_corrections_orm_sectormap(self):
return self._compute_steerer_corrections_orm(
self.compute_sectormap())
def _compute_steerer_corrections_orm_ndiff1(self):
knowsReadouts = self.knows_targets_readouts()
return self._compute_steerer_corrections_orm(
self.compute_orbit_response_matrix(knowsReadouts))
def _get_objective_deltas(self):
if self.knows_targets_readouts():
measured = {
(r.name.lower(), ax): val
for r in self.readouts
for ax, val in zip("xy", (r.posx, r.posy))
}
else:
logging.warning(
"Matching absolute orbit (more sensitive to inaccurate "
"backtracking)!")
logging.warning(
"Make sure to use as many shots as possible")
offsets = self._offsets
elem_twiss = self.model.get_elem_twiss
measured = {
(el, ax): elem_twiss(t.elem)[ax] - offset
for t in self.targets
for el in [t.elem.lower()]
for ax, offset in zip("xy", offsets.get(el, (0, 0)))
}
return [
(el, ax, objective_value - measured_value)
for el, ax, objective_value in self._get_objectives()
for measured_value in [measured.get(((el, ax)))]
]
def _compute_steerer_corrections_orm(self, orm):
mons, axs, deltas = zip(*self._get_objective_deltas())
targets = set(zip(mons, axs))
S = [
i for i, (elem, axis) in enumerate(product(self.monitors, 'xy'))
if (elem.lower(), axis) in targets
]
globals_ = self.model.globals
try:
if not self.knows_targets_readouts():
# Sidenote: I am not sure if the ordering is kept
# in a few tests it appears to be the case
dvar = np.linalg.lstsq(orm, deltas, rcond=1e-10)[0]
else:
dvar = np.linalg.lstsq(
orm[S, :], deltas, rcond=1e-10)[0]
return {
var.lower(): globals_[var] + delta
for var, delta in zip(self.variables, dvar)
}
except np.linalg.LinAlgError:
logging.error('Unable to correct the orbit with this method')
logging.warning('Please try another configuration or another method')
return {}
def _get_constraints(self):
model = self.model
elements = model.elements
elem_twiss = model.get_elem_twiss
return [
(elements[mon], None, ax, elem_twiss(mon)[ax] + delta)
for mon, ax, delta in self._get_objective_deltas()
]
[docs] def knows_targets_readouts(self):
targets = {t.elem.lower() for t in self.targets}
monitors = {m.lower() for m in self.monitors}
return targets.issubset(monitors)
def _get_objectives(self):
return [
(t.elem.lower(), ax, val)
for t in self.targets
for ax, val in zip("xy", (t.x, t.y))
if ax in self.mode
]
[docs] def compute_sectormap(self):
model = self.model
elems = model.elements
with model.undo_stack.rollback("Orbit correction", transient=True):
invalidate(model, 'sector')
elem_by_knob = {}
for elem in elems:
for knob in model.get_elem_knobs(elem):
elem_by_knob.setdefault(knob.lower(), elem.index)
return np.vstack([
np.hstack([
model.sectormap(c, m)[[0, 2], 1+2*is_vkicker].flatten()
for m in self.monitors
])
for v in self.variables
for c in [elem_by_knob[v.lower()]]
for is_vkicker in [elems[c].base_name == 'vkicker']
]).T
# TODO: share implementation with `madgui.model.orm.NumericalORM`!!
[docs] def compute_orbit_response_matrix(self, knowsReadouts=True):
if not knowsReadouts:
# In case we have targets that are not monitors or input data
# cannot be reached
targets = [t.elem for t in self.targets]
return self.model.get_orbit_response_matrix(
targets, self.variables).reshape((-1, len(self.variables)))
return self.model.get_orbit_response_matrix(
self.monitors, self.variables).reshape((-1, len(self.variables)))
[docs] def add_record(self, step, shot, time=None):
# update_vars breaks ORM procedures because it re-reads base_optics!
# self.update_vars()
self.control.read_all()
records = self.current_orbit_records()
self.records.extend(records)
self.write_shot(step, shot, {
r.monitor: [r.readout.posx, r.readout.posy,
r.readout.envx, r.readout.envy]
for r in records
}, time=time)
[docs] def write_shot(self, step, shot, records, time=None):
if self.file:
if shot == 0:
self.file.write(' shots:\n')
records = {'time': format_datetime(time), **records}
self.write_data([records], " ")
[docs] def open_export(self, fname):
self.file = open(fname, 'wt', encoding='utf-8')
self.write_data({
'sequence': self.model.seq_name,
'monitors': list(self.selected['monitors']),
'steerers': self.optic_elems,
'knobs': list(self.selected['optics']),
'twiss_args': self.model._get_twiss_args(),
})
self.write_data({
'model': self.base_optics,
'extra': self.control.backend.read_params(),
}, default_flow_style=False)
self.file.write(
'# posx[m] posy[m] envx[m] envy[m]\n'
'records:\n')
[docs] def close_export(self):
if self.file:
self.file.close()
self.file = None
[docs] def write_data(self, data, indent="", **kwd):
if self.file:
self.file.write(textwrap.indent(yaml.safe_dump(data, **kwd), indent))
self.file.flush()
[docs]class ProcBot:
def __init__(self, widget, corrector):
self.widget = widget
self.corrector = corrector
self.running = False
self.model = corrector.model
self.control = corrector.control
self.totalops = 100
self.progress = 0
[docs] def start(self, num_ignore, num_average, gui=True):
if self.running:
return
self.corrector.records.clear()
self.numsteps = len(self.corrector.optics)
self.numshots = num_average + num_ignore
self.num_ignore = num_ignore
self.totalops = self.numsteps * self.numshots
self.progress = -1
self.running = True
self.widget.update_ui()
self.widget.log("Started")
self.corrector.control.sampler.updated.connect(self._feed)
self._advance()
[docs] def finish(self):
self.stop()
self.widget.update_fit()
self.widget.log(
"Set initial optic: {}", self.corrector.optics[0])
self.corrector.set_optic(0)
self.widget.log("Finished\n")
[docs] def cancel(self):
if self.running:
self.stop()
self.widget.update_ui()
self.widget.log("Cancelled by user.\n")
[docs] def stop(self):
if self.running:
self.corrector.close_export()
self.corrector.set_optic(None)
self.running = False
self.corrector.control.sampler.updated.disconnect(self._feed)
self.widget.update_ui()
def _feed(self, time, activity):
step = self.progress // self.numshots
shot = self.progress % self.numshots
if shot < self.num_ignore:
self.widget.log(' -> shot {} (ignored)', shot)
else:
self.widget.log(' -> shot {}', shot)
self.corrector.add_record(step, shot-self.num_ignore, time)
self._advance()
def _advance(self):
self.progress += 1
step = self.progress // self.numshots
shot = self.progress % self.numshots
self.widget.set_progress(self.progress)
if self.progress == self.totalops:
self.finish()
elif shot == 0:
self.widget.log(
"optic {} of {}: {}", step, self.numsteps,
self.corrector.optics[step])
self.corrector.set_optic(step)
def format_datetime(datime=None):
if datime is None:
datime = datetime.now(timezone.utc)
elif isinstance(datime, (int, float)):
datime = datetime.fromtimestamp(datime)
return datime.astimezone().strftime('%Y-%m-%d %H:%M:%S.%f %z')