# -*- coding: utf-8 -*-
#
##########################################################################
# eoleauth - Redis session manager
# Copyright © 2013 Pôle de compétences EOLE <eole@ac-dijon.fr>
#
# License CeCILL:
#  * in French: http://www.cecill.info/licences/Licence_CeCILL_V2-fr.html
#  * in English: http://www.cecill.info/licences/Licence_CeCILL_V2-en.html
#
# based on http://flask.pocoo.org/snippets/75/
#
# server side session with data stored in Redis
#
##########################################################################
import pickle
from datetime import timedelta
from uuid import uuid4
from redis import Redis
from werkzeug.datastructures import CallbackDict
from flask import request, session, current_app
from flask.sessions import SecureCookieSessionInterface, SecureCookie, SessionMixin
from eoleauthlib.i18n import _

class RedisSession(SecureCookie, SessionMixin, CallbackDict):
    """Session dictionary whose content is kept server side.

    On top of the dictionary content, an instance carries three attributes:
    ``sid`` (the session identifier), ``new`` (flag given by the caller at
    construction time) and ``modified`` (initially False, switched to True
    by the update callback registered on the underlying ``CallbackDict``).
    """

    def __init__(self, initial=None, sid=None, new=False):
        def _flag_as_modified(changed):
            # update callback handed to CallbackDict; it receives the
            # dictionary itself
            changed.modified = True

        CallbackDict.__init__(self, initial, _flag_as_modified)
        # a freshly built session has not been touched yet
        self.modified = False
        self.new = new
        self.sid = sid

class RedisSessionInterface(SecureCookieSessionInterface):
    """Session interface for storing sessions in Redis

    Only the session id is sent to the client (in the session cookie); the
    session content is serialized with ``serializer`` and stored in Redis
    under the key ``prefix + sid``.
    """
    # NOTE(security): pickle executes arbitrary code when loading crafted
    # data. It is only fed with values read back from Redis here, so the
    # Redis instance must not be writable by untrusted parties.
    serializer = pickle
    session_class = RedisSession

    def __init__(self, redis=None, prefix='session:'):
        """
        :param redis: Redis client to use; when None, a client connected to
                      the unix socket /var/run/redis/redis.sock is created
        :param prefix: string prepended to the session id to build Redis keys
        """
        if redis is None:
            redis = Redis(unix_socket_path='/var/run/redis/redis.sock')
        self.redis = redis
        self.prefix = prefix

    def generate_sid(self):
        """Return a new random session id (uuid4 as a string)
        """
        return str(uuid4())

    def get_redis_expiration_time(self, app, session):
        """Return the lifetime (timedelta) of the session data in Redis:
        the application's permanent session lifetime for permanent sessions,
        one day otherwise
        """
        if session.permanent:
            return app.permanent_session_lifetime
        return timedelta(days=1)

    def open_session(self, app, request):
        """Build the session object for the current request

        The session id is read from the session cookie. If data is stored in
        Redis for this id, the session is restored from it; in every other
        case a new, empty session with a freshly generated id is returned.
        """
        sid = request.cookies.get(app.session_cookie_name)
        if sid:
            val = self.redis.get(self.prefix + sid)
            if val is not None:
                data = self.serializer.loads(val)
                return self.session_class(data, sid=sid)
        # No cookie, or a cookie naming a session unknown to Redis (expired,
        # removed or forged): never adopt an id chosen by the client, as it
        # would allow session fixation.
        return self.session_class(sid=self.generate_sid(), new=True)

    def save_session(self, app, session, response):
        """Store the session in Redis and set the session cookie

        An empty session is removed from Redis (if still present) and its
        cookie is deleted when the session has been modified. Otherwise the
        serialized content is stored with an expiration time and the cookie
        carrying the session id is (re)sent.
        """
        domain = self.get_cookie_domain(app)
        if not session:
            if hasattr(session, 'sid') and self.redis.exists(self.prefix + session.sid):
                # should not happen, session is deleted
                # with its mapping in invalidate_session
                self.redis.delete(self.prefix + session.sid)
            if session.modified:
                response.delete_cookie(app.session_cookie_name,
                                       domain=domain)
            return
        redis_exp = self.get_redis_expiration_time(app, session)
        cookie_exp = self.get_expiration_time(app, session)
        val = self.serializer.dumps(dict(session))
        # same time to live (in seconds) for the session and for its mapping
        ttl = int(redis_exp.total_seconds())
        # NOTE(review): setex is called as (name, value, time), which is the
        # argument order of the legacy redis-py ``Redis`` client; redis-py
        # >= 3.0 expects (name, time, value) - confirm against the installed
        # version before upgrading.
        # store attribute/session mapping if defined in auth plugin
        if getattr(app, 'eoleauth_map_attr', None):
            map_key = session.get(app.eoleauth_map_attr, None)
            if map_key and hasattr(session, 'sid'):
                app.logger.debug(_('storing session_id mapping ({0})').format(app.eoleauth_map_attr))
                self.redis.setex('map_' + self.prefix + map_key, session.sid,
                                 ttl)
        self.redis.setex(self.prefix + session.sid, val, ttl)
        response.set_cookie(app.session_cookie_name, session.sid,
                            expires=cookie_exp, httponly=True,
                            domain=domain)

    def remove_session(self, map_key):
        """eoleauth specific: use mapping to a specific session attribute to
        invalidate server side session (attribute defined in auth plugin and
        stored in app.eoleauth_map_attr)

        :param map_key: value of the mapped attribute for the session to remove
        :return: True if a session was found and deleted (along with its
                 mapping), False if no mapping exists for ``map_key``
        """
        current_app.logger.debug(_('session removal requested (mapping:{0})').format(map_key))
        map_key = 'map_' + self.prefix + map_key
        session_id = self.redis.get(map_key)
        if session_id:
            if not isinstance(session_id, str):
                # Redis clients return bytes on Python 3 (unless configured to
                # decode responses); concatenating them to the prefix would
                # raise TypeError
                session_id = session_id.decode('utf-8')
            current_app.logger.debug(_('deleteting session from redis : {0}').format(self.prefix + session_id))
            self.redis.delete(self.prefix + session_id)
            self.redis.delete(map_key)
            return True
        current_app.logger.debug(_('no session found for this mapping'))
        return False
