# Source code for madgui.model.errors

"""
This module defines functions and classes to load and represent MAD-X model
errors, such as alignment errors or field errors.

Errors are represented as strings such as::

    'Δax_b3mu1'         # absolute error in parameter
    'δax_b3mu1'         # relative error in parameter
    'Δb3mu1v->kick'     # absolute error in element attribute
    'g3mu1<dx>'         # alignment error
"""

# Public API: entry points for loading/applying errors, followed by every
# error type defined in this module.
__all__ = [
    # functions
    'import_errors', 'apply_errors', 'parse_error',
    # error types
    'Param', 'Ealign', 'Efcomp', 'ElemAttr', 'InitTwiss',
    'ScaleAttr', 'ScaleParam', 'BaseError', 'RelativeError',
]

import re
from contextlib import ExitStack

from cpymad.util import is_identifier


def import_errors(model, spec: dict):
    """
    Apply the errors described by ``spec`` to ``model``.

    ``spec`` maps error specifications to values, i.e. it has the types
    ``{str: float}``. Each key is converted to an error object with
    :func:`parse_error`, and the resulting errors are applied together with
    the corresponding values by :func:`apply_errors`.

    Returns whatever :func:`apply_errors` returns: a context manager that
    undoes the applied errors when closed.
    """
    # Lazily parse the keys, exactly like ``map`` would, so parsing happens
    # one key at a time while apply_errors iterates.
    parsed = (parse_error(key) for key in spec)
    return apply_errors(model, parsed, spec.values())
def apply_errors(model, errors, values):
    """
    Apply each error in ``errors`` with the matching entry of ``values``.

    ``errors`` and ``values`` are consumed in lockstep (as with :func:`zip`,
    surplus items of the longer one are ignored). For every pair, the
    context manager returned by ``error.vary(model, value)`` is entered.

    If applying any error raises, the errors applied so far are undone
    before the exception propagates. Otherwise an
    :class:`~contextlib.ExitStack` is returned which undoes all applied
    errors, in reverse order, when closed or exited.
    """
    undo_all = ExitStack()
    with undo_all:
        for err, val in zip(errors, values):
            undo_all.enter_context(err.vary(model, val))
        # Detach the registered undo actions so they survive the ``with``.
        return undo_all.pop_all()
def parse_error(name):
    """
    Instantiate a subtype of :class:`BaseError`, depending on the format of
    ``name``.

    We currently understand the following formats::

        x, y, px, py        InitTwiss
        Δax_b3mu1           Param
        δax_b3mu1           ScaleParam
        Δg3mu1->angle       ElemAttr
        δg3mu1->angle       ScaleAttr
        g3mu1<dx>           Ealign

    A leading ``δ`` selects the relative (scaling) variant. A leading ``Δ``,
    or no prefix, selects the absolute one. The relative marker only has an
    effect for global parameters and element attributes. It is ignored for
    initial twiss and alignment errors.

    :raises ValueError: if ``name`` does not match any understood format.
    """
    # Skip leading whitespace before testing for the relative marker, so
    # that ' δk1' is recognized as relative just like 'δk1'.
    spec = name.lstrip(' \t')
    mult = spec.startswith('δ')
    name = spec.lstrip('δΔ \t')
    invalid = "{!r} is not a valid error specification!".format(name)

    if name in ('x', 'y', 'px', 'py'):
        return InitTwiss(name)

    if '->' in name:
        parts = name.split('->')
        if len(parts) != 2:
            # e.g. 'a->b->c': ambiguous element/attribute split
            raise ValueError(invalid)
        elem, attr = parts
        return ScaleAttr(elem, attr) if mult else ElemAttr(elem, attr)

    if '<' in name:
        match = re.match(r'(.*)\<(.*)\>', name)
        if match is None:
            # e.g. 'g3mu1<dx' without the closing bracket
            raise ValueError(invalid)
        elem, attr = match.groups()
        return Ealign({'range': elem}, attr)

    if is_identifier(name):
        return ScaleParam(name) if mult else Param(name)

    # TODO: efcomp field errors!
    raise ValueError(invalid)
class BaseError:
    """
    Common base class of all model errors.

    An error is applied to a model through :meth:`vary`, which relies on
    three hooks:

    - :meth:`get` returns a backup value. Passing it to :meth:`set` later
      must restore the state that was current before the error was applied.
    - :meth:`tinker` derives the value to apply from that backup value and
      the requested step.
    - :meth:`set` writes a value to the model. It is called first with the
      result of :meth:`tinker`, and later with the backup value to restore
      the original state.

    Only :meth:`set` is abstract here (it raises
    :class:`NotImplementedError`). The default :meth:`get` returns ``0.0``
    and the default :meth:`tinker` adds the step to the backup value.
    Subclasses override whichever hooks do not fit their error type.
    """

    # Prefix used by __repr__; 'Δ' marks an absolute error.
    leader = 'Δ'

    def __init__(self, name):
        self.name = name

    def vary(self, model, step):
        """
        Apply this error with the given ``step`` to ``model``.

        Returns an :class:`~contextlib.ExitStack`. When it is closed or
        exited, it calls :meth:`set` with the backup value obtained from
        :meth:`get`. If :meth:`tinker` yields a value equal to the backup,
        nothing is written and the returned stack is empty.
        """
        backup = self.get(model, step)
        target = self.tinker(backup, step)
        with ExitStack() as restore:
            if target != backup:
                self.set(model, target)
                restore.callback(self.set, model, backup)
            # Hand the pending restore action over to the caller.
            return restore.pop_all()

    def get(self, model, step):
        """
        Return the value that :meth:`set` must later receive to restore the
        current state. The default is ``0.0``.
        """
        return 0.0

    def set(self, model, value):
        """Write ``value`` to the model. Must be overridden."""
        raise NotImplementedError

    def __repr__(self):
        return '{leader}{name}'.format(leader=self.leader, name=self.name)

    def tinker(self, value, step):
        """
        Compute the value to pass to :meth:`set` in order to increment the
        error by ``step``. ``value`` is the backup returned by :meth:`get`.

        A ``None`` backup yields the step itself. A string backup (an
        expression) yields the expression string ``"(value) + (step)"``.
        Anything else is added to ``step`` with ``+``.
        """
        if value is None:
            return step
        if isinstance(value, str):
            return "({}) + ({})".format(value, step)
        return value + step

    def is_defined_for(self, model):
        """Whether this error applies to ``model``; always ``True`` here."""
        return True
class Param(BaseError):
    """
    Absolute error on a global variable (knob), addressed by ``self.name``.
    """

    def get(self, model, step):
        # Back up the variable's definition rather than its evaluated value.
        # This is presumably an expression string when the knob is deferred;
        # BaseError.tinker handles both strings and numbers.
        knob = model.globals.cmdpar[self.name]
        return knob.definition

    def set(self, model, value):
        model.globals[self.name] = value

    def is_defined_for(self, model):
        known = model.globals
        return self.name in known
class Ealign(BaseError):
    """
    Alignment error, applied with the MAD-X ``EALIGN`` command.

    :param select: keyword arguments for ``SELECT, FLAG=ERROR``. They are
        expected to contain ``'range'`` (the element name).
    :param attr: name of the EALIGN attribute to vary (e.g. ``'dx'``).
    """

    def __init__(self, select, attr):
        self.select = select
        self.attr = attr
        self.name = '{}<{}>'.format(select.get('range'), attr)

    def set(self, model, value):
        cmd = model.madx.command
        # Reset the error selection so that only the elements described by
        # ``self.select`` receive the EALIGN below.
        cmd.select(flag='error', clear=True)
        cmd.select(flag='error', **self.select)
        cmd.ealign(**{self.attr: value})

    # get/tinker: vary() applies ``+step`` and later "restores" by applying
    # ``-step``. NOTE(review): this only returns to the original alignment
    # if MAD-X accumulates EALIGN values (EOPTION, ADD=TRUE) instead of
    # replacing them. Confirm that the model enables this.
    def get(self, model, step):
        return -step

    def tinker(self, value, step):
        return -value

    def is_defined_for(self, model):
        """Whether the selected element exists in ``model.elements``."""
        elem = self.select.get('range')
        # Return a real bool, not the falsy ``elem`` itself (None / '').
        return bool(elem) and elem in model.elements
class Efcomp(BaseError):
    """
    Field error, applied with the MAD-X ``EFCOMP`` command.

    :param select: keyword arguments for ``SELECT, FLAG=ERROR``. They are
        expected to contain ``'range'`` (the element name).
    :param attr: name of the EFCOMP keyword that receives the coefficient
        list.
    :param value: sequence of base coefficients. Each one is multiplied by
        the error value passed to :meth:`set`.
    :param order: forwarded to EFCOMP as ``order``.
    :param radius: forwarded to EFCOMP as ``radius``.
    """

    def __init__(self, select, attr, value, order=0, radius=1):
        self.select = select
        self.attr = attr
        self.value = value
        self.order = order
        self.radius = radius
        # Use ``.get`` (as Ealign and is_defined_for do) so that a selection
        # without 'range' does not raise KeyError at construction time.
        self.name = '{}+{}'.format(select.get('range'), attr)

    def set(self, model, value):
        cmd = model.madx.command
        # Reset the error selection so that only the elements described by
        # ``self.select`` receive the EFCOMP below.
        cmd.select(flag='error', clear=True)
        cmd.select(flag='error', **self.select)
        cmd.efcomp(**{
            'order': self.order,
            'radius': self.radius,
            # Every base coefficient is scaled by the applied error value.
            self.attr: [v * value for v in self.value],
        })

    # get/tinker: vary() applies ``+step`` and later "restores" by applying
    # ``-step``. NOTE(review): this only returns to the original field
    # errors if MAD-X accumulates EFCOMP values (EOPTION, ADD=TRUE). Confirm
    # that the model enables this.
    def get(self, model, step):
        return -step

    def tinker(self, value, step):
        return -value

    def is_defined_for(self, model):
        """Whether the selected element exists in ``model.elements``."""
        elem = self.select.get('range')
        # Return a real bool, not the falsy ``elem`` itself (None / '').
        return bool(elem) and elem in model.elements
class ElemAttr(BaseError):
    """
    Absolute error on one attribute of a named element.

    :param elem: element name, looked up in ``model.elements``.
    :param attr: attribute of that element to vary.
    """

    def __init__(self, elem, attr):
        self.elem, self.attr = elem, attr
        self.name = '{}->{}'.format(elem, attr)

    def get(self, model, step):
        # Back up the attribute's definition, not its evaluated value.
        element = model.elements[self.elem]
        return element.cmdpar[self.attr].definition

    def set(self, model, value):
        element = model.elements[self.elem]
        element[self.attr] = value

    def is_defined_for(self, model):
        return self.elem in model.elements
class InitTwiss(BaseError):
    """Error in a twiss initial condition (``x``, ``px``, ``y`` or ``py``)."""

    def get(self, model, step):
        # None when model.twiss_args has no entry for this name;
        # BaseError.tinker then uses the step itself as the new value.
        twiss_args = model.twiss_args
        return twiss_args.get(self.name)

    def set(self, model, value):
        update = {self.name: value}
        model.update_twiss_args(update)
class RelativeError(BaseError):
    """
    Base class for relative errors.

    Here ``step`` is a fractional change: the backup value is scaled by
    ``1 + step`` instead of being incremented.
    """

    # Prefix used by __repr__; 'δ' marks a relative error.
    leader = 'δ'

    def tinker(self, value, step):
        if value is None:
            # Nothing to scale. Returning the unchanged None means vary()
            # sees no difference and writes nothing.
            return None
        factor = 1 + step
        if isinstance(value, str):
            return "({}) * ({})".format(value, factor)
        return value * factor
class ScaleAttr(RelativeError, ElemAttr):
    """
    Relative error on an element attribute.

    It combines ElemAttr's get/set with RelativeError's scaling ``tinker``
    (value * (1 + step)) and its 'δ' repr prefix.
    """
class ScaleParam(RelativeError, Param):
    """
    Relative error on a global variable (knob).

    It combines Param's get/set/is_defined_for with RelativeError's scaling
    ``tinker`` (value * (1 + step)) and its 'δ' repr prefix.
    """