"""
This module defines the toplevel context that is used in madgui to keep track
of current model, config, online control, and mainwindow.
"""
__all__ = [
    'Session',
]
import glob
import os
import sys
from types import SimpleNamespace
import numpy as np
from madgui.util.collections import Boxed, Selection
from madgui.util.misc import relpath, userpath
from madgui.online.control import Control
from madgui.core.config import load as load_config
from madgui.model.match import Matcher
import madgui.util.yaml as yaml
class Session:
    """
    Context variables and top-level application logic for a madgui session,
    i.e. the interaction between user and different parts of the computer
    program. This object keeps track and coordinates the use of the currently
    opened model, GUI window, user variables, control system connection, and
    configuration data.
    """

    # File extensions probed by :meth:`find_model`, in order of preference.
    known_extensions = ['.cpymad.yml', '.madx']

    def __init__(self, config=None):
        """
        Set up the session state and the user namespace.

        :param config: configuration object; if ``None``, the result of
                       ``load_config()`` is used.
        """
        if config is None:
            config = load_config()
        self.config = config
        self.window = Boxed(None)
        self.model = Boxed(None)
        self.control = Control(self)
        self.matcher = None
        self.user_ns = user_ns = SimpleNamespace()
        self.session_file = userpath(config.session_file)
        self.folder = userpath(config.model_path)
        self.selected_elements = Selection()
        self.model.changed2.connect(self.on_model_changed)
        # Mirror the current model/window into the user namespace and keep
        # them updated whenever the boxed values change.
        subscribe(user_ns, 'model', self.model)
        subscribe(user_ns, 'window', self.window)
        user_ns.config = config
        user_ns.context = self
        user_ns.control = self.control

    def on_model_changed(self, old, new):
        """
        React to a replacement of the current model.

        Resets the element selection, destroys the ``old`` model (if any) and
        creates a new ``Matcher`` for the ``new`` model (if any).

        :param old: previously active model or a falsy value
        :param new: newly active model or a falsy value
        """
        self.selected_elements.clear()
        # NOTE(review): rebinding the attribute means anything still holding
        # a reference to the previous ``Selection`` object no longer sees
        # this session's selection -- confirm that this is intended.
        self.selected_elements = Selection()
        if old:
            self.matcher = None
            old.destroy()
        if new:
            self.matcher = Matcher(new, self.config.get('matching'))

    def set_interpolate(self, points_per_meter):
        """
        Store ``points_per_meter`` as the ``interpolate`` setting in the
        config and, if a model is loaded, apply it there and invalidate the
        model.
        """
        self.config.interpolate = points_per_meter
        model = self.model()
        if model:
            model.interpolate = points_per_meter
            model.invalidate()

    def terminate(self):
        """
        Shut down the session: save the session file (if one is set),
        disconnect the online control (if connected) and unset the current
        model and window.
        """
        if self.session_file:
            self.save(self.session_file)
            # Reset so that a repeated call does not save a second time.
            self.session_file = None
        if self.control.is_connected():
            self.control.disconnect()
        self.model.set(None)
        self.window.set(None)

    def load_default(self, model=None):
        """
        Load the default model and optionally connect the online control.

        :param model: model name to load; falls back to
                      ``config.load_default`` if falsy.
        """
        # NOTE(review): ``configure`` is not defined in the class body shown
        # here; it has to be provided elsewhere (e.g. by a subclass) --
        # confirm.
        self.configure()
        config = self.config
        filename = model or config.load_default
        if filename:
            self.load_model(filename)
        if self.control.can_connect() and config.online_control.connect:
            self.control.connect()

    def load_model(self, name, **madx_args):
        """
        Locate the model called ``name`` and make it the current model.

        :param name: file or folder name, resolved via :meth:`find_model`
        :param madx_args: keyword arguments for ``Model.load_file``; these
                          take precedence over those from :meth:`model_args`
        :raises OSError: if no matching file is found
        :raises NotImplementedError: if the file extension is not supported
        """
        filename = self.find_model(name)
        exts = ('.cpymad.yml', '.madx', '.str', '.seq')
        if not filename.endswith(exts):
            raise NotImplementedError("Unsupported file format: {}"
                                      .format(filename))
        # Imported at call time rather than at module level.
        from madgui.model.madx import Model
        self.model.set(Model.load_file(
            filename, **dict(self.model_args(filename), **madx_args)))

    def find_model(self, name):
        """
        Resolve ``name`` to the path of a model file.

        ``name`` is tried as given and then relative to ``self.folder`` (or
        the current directory if no folder is set). Each candidate may be a
        file, a file name lacking one of the ``known_extensions``, or a
        directory containing at least one file with a known extension.

        :param name: file or directory name
        :returns: path of the model file
        :raises OSError: if nothing suitable is found
        """
        for path in [name, os.path.join(self.folder or '.', name)]:
            path = expand_ext(path, '', *self.known_extensions)
            if os.path.isfile(path):
                return path
            if os.path.isdir(path):
                # Escape the directory so that special characters in its
                # name are not interpreted as glob patterns.
                folder = glob.escape(path)
                for ext in self.known_extensions:
                    # glob returns matches in arbitrary order; sort them to
                    # make the choice reproducible.
                    models = sorted(glob.glob(os.path.join(folder, '*' + ext)))
                    if models:
                        return models[0]
        raise OSError("File not found: {!r}".format(name))

    def model_args(self, filename):
        """Please OVERRIDE to provide custom model arguments."""
        return {'interpolate': self.config.interpolate}

    def save(self, filename):
        """Save session state to file."""
        yaml.save_file(filename, self.session_data())

    def session_data(self):
        """
        Collect the session state to be saved.

        :returns: dict with the online control settings, model path, default
                  model and the ``number`` config entry, updated with the
                  window's own ``session_data()`` if a window is set.
        """
        folder = self.config.model_path or self.folder
        model = self.model()
        # Stays falsy (e.g. ``None``) if no model is currently loaded.
        default = model and relpath(model.filename, folder)
        data = {
            'online_control': {
                'backend': self.control.backend_spec,
                'connect': self.control.is_connected(),
                'monitors': self.config.online_control['monitors'],
                'offsets': self.config.online_control['offsets'],
                'settings': self.control.export_settings() or {},
            },
            'model_path': folder,
            'load_default': default,
            'number': self.config['number'],
        }
        window = self.window()
        if window:
            data.update(window.session_data())
        return data
def subscribe(ns, key, boxed):
    """
    Keep attribute ``key`` of ``ns`` in sync with a ``Boxed`` value.

    The attribute is assigned the current value immediately, and again each
    time the ``changed`` signal of ``boxed`` calls the connected handler.
    """
    def assign(value):
        setattr(ns, key, value)

    # Seed the attribute first, then follow subsequent updates.
    assign(boxed())
    boxed.changed.connect(assign)
def expand_ext(path, *exts):
    """
    Return ``path`` extended by the first suffix in ``exts`` for which the
    resulting name is an existing file.

    If none of the candidates exists as a file, ``path`` is returned
    unchanged.
    """
    candidates = (path + suffix for suffix in exts)
    return next(
        (candidate for candidate in candidates if os.path.isfile(candidate)),
        path,
    )