aboutsummaryrefslogtreecommitdiff
path: root/ihatemoney/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'ihatemoney/utils.py')
-rw-r--r--ihatemoney/utils.py46
1 files changed, 43 insertions, 3 deletions
diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py
index 6af0112..5dd1e7b 100644
--- a/ihatemoney/utils.py
+++ b/ihatemoney/utils.py
@@ -2,8 +2,8 @@ import base64
import re
from io import BytesIO, StringIO
-from jinja2 import filters
-from json import dumps
+import jinja2
+from json import dumps, JSONEncoder
from flask import redirect
from werkzeug.routing import HTTPException, RoutingException
import six
@@ -83,7 +83,7 @@ def minimal_round(*args, **kw):
from http://stackoverflow.com/questions/28458524/
"""
# Use the original round filter, to deal with the extra arguments
- res = filters.do_round(*args, **kw)
+ res = jinja2.filters.do_round(*args, **kw)
# Test if the result is equivalent to an integer and
# return depending on it
ires = int(res)
@@ -170,3 +170,43 @@ class LoginThrottler():
def reset(self, ip):
self._attempts.pop(ip, None)
+
+
+def create_jinja_env(folder, strict_rendering=False):
+ """Creates and return a Jinja2 Environment object, used, to load the
+ templates.
+
+ :param strict_rendering:
+ if set to `True`, all templates which use an undefined variable will
+ throw an exception (default to `False`).
+ """
+ loader = jinja2.PackageLoader('ihatemoney', folder)
+ kwargs = {'loader': loader}
+ if strict_rendering:
+ kwargs['undefined'] = jinja2.StrictUndefined
+ return jinja2.Environment(**kwargs)
+
+
+class IhmJSONEncoder(JSONEncoder):
+ """Subclass of the default encoder to support custom objects.
+ Taken from the deprecated flask-rest package."""
+ def default(self, o):
+ if hasattr(o, "_to_serialize"):
+ # build up the object
+ data = {}
+ for attr in o._to_serialize:
+ data[attr] = getattr(o, attr)
+ return data
+ elif hasattr(o, "isoformat"):
+ return o.isoformat()
+ else:
+ try:
+ from flask_babel import speaklater
+ if isinstance(o, speaklater.LazyString):
+ try:
+ return unicode(o) # For python 2.
+ except NameError:
+ return str(o) # For python 3.
+ except ImportError:
+ pass
+ return JSONEncoder.default(self, o)