aboutsummaryrefslogtreecommitdiff
path: root/ihatemoney
diff options
context:
space:
mode:
author0livd <github@destras.fr>2018-01-25 17:41:28 +0100
committerAlexis Metaireau <alexis@notmyidea.org>2018-01-25 17:41:28 +0100
commitb93ea4830d5290def99d597f17292a8aa5d4c090 (patch)
treebeb34a47d9eb85643672cc3b743c9b6d37932092 /ihatemoney
parent830718e1fe5f18959f455a696ebc2172a2d5f253 (diff)
downloadihatemoney-mirror-b93ea4830d5290def99d597f17292a8aa5d4c090.zip
ihatemoney-mirror-b93ea4830d5290def99d597f17292a8aa5d4c090.tar.gz
ihatemoney-mirror-b93ea4830d5290def99d597f17292a8aa5d4c090.tar.bz2
API: Migrate from flask-rest to flask-restful (#315)
The flask-rest custom json encoder is still needed and thus was added to ihatemoney's utils. Closes #298
Diffstat (limited to 'ihatemoney')
-rw-r--r--ihatemoney/api.py168
-rw-r--r--ihatemoney/run.py4
-rw-r--r--ihatemoney/tests/tests.py14
-rw-r--r--ihatemoney/utils.py27
4 files changed, 120 insertions, 93 deletions
diff --git a/ihatemoney/api.py b/ihatemoney/api.py
index 827202c..31ed06c 100644
--- a/ihatemoney/api.py
+++ b/ihatemoney/api.py
@@ -1,62 +1,68 @@
# -*- coding: utf-8 -*-
from flask import Blueprint, request
-from flask_rest import RESTResource, need_auth
+from flask_restful import Resource, Api, abort
from wtforms.fields.core import BooleanField
from ihatemoney.models import db, Project, Person, Bill
from ihatemoney.forms import (ProjectForm, EditProjectForm, MemberForm,
get_billform_for)
from werkzeug.security import check_password_hash
+from functools import wraps
api = Blueprint("api", __name__, url_prefix="/api")
+restful_api = Api(api)
-def check_project(*args, **kwargs):
+def need_auth(f):
"""Check the request for basic authentication for a given project.
- Return the project if the authorization is good, False otherwise
+ Return the project if the authorization is good, abort the request with a 401 otherwise
"""
- auth = request.authorization
-
- # project_id should be contained in kwargs and equal to the username
- if auth and "project_id" in kwargs and \
- auth.username == kwargs["project_id"]:
- project = Project.query.get(auth.username)
- if project and check_password_hash(project.password, auth.password):
- return project
- return False
-
-
-class ProjectHandler(object):
-
- def add(self):
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ auth = request.authorization
+ project_id = kwargs.get("project_id")
+
+ if auth and project_id and auth.username == project_id:
+ project = Project.query.get(auth.username)
+ if project and check_password_hash(project.password, auth.password):
+ # The whole project object will be passed instead of project_id
+ kwargs.pop("project_id")
+ return f(*args, project=project, **kwargs)
+ abort(401)
+ return wrapper
+
+
+class ProjectsHandler(Resource):
+ def post(self):
form = ProjectForm(meta={'csrf': False})
if form.validate():
project = form.save()
db.session.add(project)
db.session.commit()
- return 201, project.id
- return 400, form.errors
+ return project.id, 201
+ return form.errors, 400
+
+
+class ProjectHandler(Resource):
+ method_decorators = [need_auth]
- @need_auth(check_project, "project")
def get(self, project):
- return 200, project
+ return project
- @need_auth(check_project, "project")
def delete(self, project):
db.session.delete(project)
db.session.commit()
- return 200, "DELETED"
+ return "DELETED"
- @need_auth(check_project, "project")
- def update(self, project):
+ def put(self, project):
form = EditProjectForm(meta={'csrf': False})
if form.validate():
form.update(project)
db.session.commit()
- return 200, "UPDATED"
- return 400, form.errors
+ return "UPDATED"
+ return form.errors, 400
class APIMemberForm(MemberForm):
@@ -71,98 +77,92 @@ class APIMemberForm(MemberForm):
return super(APIMemberForm, self).save(project, person)
-class MemberHandler(object):
+class MembersHandler(Resource):
+ method_decorators = [need_auth]
- def get(self, project, member_id):
- member = Person.query.get(member_id, project)
- if not member or member.project != project:
- return 404, "Not Found"
- return 200, member
-
- def list(self, project):
- return 200, project.members
+ def get(self, project):
+ return project.members
- def add(self, project):
+ def post(self, project):
form = MemberForm(project, meta={'csrf': False})
if form.validate():
member = Person()
form.save(project, member)
db.session.commit()
- return 201, member.id
- return 400, form.errors
+ return member.id, 201
+ return form.errors, 400
+
+
+class MemberHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project, member_id):
+ member = Person.query.get(member_id, project)
+ if not member or member.project != project:
+ return "Not Found", 404
+ return member
- def update(self, project, member_id):
+ def put(self, project, member_id):
form = APIMemberForm(project, meta={'csrf': False}, edit=True)
if form.validate():
member = Person.query.get(member_id, project)
form.save(project, member)
db.session.commit()
- return 200, member
- return 400, form.errors
+ return member
+ return form.errors, 400
def delete(self, project, member_id):
if project.remove_member(member_id):
- return 200, "OK"
- return 404, "Not Found"
+ return "OK"
+ return "Not Found", 404
-class BillHandler(object):
-
- def get(self, project, bill_id):
- bill = Bill.query.get(project, bill_id)
- if not bill:
- return 404, "Not Found"
- return 200, bill
+class BillsHandler(Resource):
+ method_decorators = [need_auth]
- def list(self, project):
+ def get(self, project):
return project.get_bills().all()
- def add(self, project):
+ def post(self, project):
form = get_billform_for(project, True, meta={'csrf': False})
if form.validate():
bill = Bill()
form.save(bill, project)
db.session.add(bill)
db.session.commit()
- return 201, bill.id
- return 400, form.errors
+ return bill.id, 201
+ return form.errors, 400
+
- def update(self, project, bill_id):
+class BillHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project, bill_id):
+ bill = Bill.query.get(project, bill_id)
+ if not bill:
+ return "Not Found", 404
+ return bill, 200
+
+ def put(self, project, bill_id):
form = get_billform_for(project, True, meta={'csrf': False})
if form.validate():
bill = Bill.query.get(project, bill_id)
form.save(bill, project)
db.session.commit()
- return 200, bill.id
- return 400, form.errors
+ return bill.id, 200
+ return form.errors, 400
def delete(self, project, bill_id):
bill = Bill.query.delete(project, bill_id)
db.session.commit()
if not bill:
- return 404, "Not Found"
- return 200, "OK"
-
-
-project_resource = RESTResource(
- name="project",
- route="/projects",
- app=api,
- actions=["add", "update", "delete", "get"],
- handler=ProjectHandler())
-
-member_resource = RESTResource(
- name="member",
- inject_name="project",
- route="/projects/<project_id>/members",
- app=api,
- handler=MemberHandler(),
- authentifier=check_project)
-
-bill_resource = RESTResource(
- name="bill",
- inject_name="project",
- route="/projects/<project_id>/bills",
- app=api,
- handler=BillHandler(),
- authentifier=check_project)
+ return "Not Found", 404
+ return "OK", 200
+
+
+restful_api.add_resource(ProjectsHandler, '/projects')
+restful_api.add_resource(ProjectHandler, '/projects/<string:project_id>')
+restful_api.add_resource(MembersHandler, "/projects/<string:project_id>/members")
+restful_api.add_resource(MemberHandler, "/projects/<string:project_id>/members/<int:member_id>")
+restful_api.add_resource(BillsHandler, "/projects/<string:project_id>/bills")
+restful_api.add_resource(BillHandler, "/projects/<string:project_id>/bills/<int:bill_id>")
diff --git a/ihatemoney/run.py b/ihatemoney/run.py
index e3a7c1e..a8de26f 100644
--- a/ihatemoney/run.py
+++ b/ihatemoney/run.py
@@ -11,7 +11,7 @@ from werkzeug.contrib.fixers import ProxyFix
from ihatemoney.api import api
from ihatemoney.models import db
-from ihatemoney.utils import PrefixedWSGI, minimal_round
+from ihatemoney.utils import PrefixedWSGI, minimal_round, IhmJSONEncoder
from ihatemoney.web import main as web_interface
from ihatemoney import default_settings
@@ -68,6 +68,8 @@ def load_configuration(app, configuration=None):
app.config.from_pyfile(env_var_config)
else:
app.config.from_pyfile('ihatemoney.cfg', silent=True)
+ # Configure custom JSONEncoder used by the API
+ app.config['RESTFUL_JSON'] = {'cls': IhmJSONEncoder}
def validate_configuration(app):
diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py
index d4b6d7a..c13131c 100644
--- a/ihatemoney/tests/tests.py
+++ b/ihatemoney/tests/tests.py
@@ -1053,7 +1053,7 @@ class APITestCase(IhatemoneyTestCase):
})
self.assertTrue(400, resp.status_code)
- self.assertEqual('{"contact_email": ["Invalid email address."]}',
+ self.assertEqual('{"contact_email": ["Invalid email address."]}\n',
resp.data.decode('utf-8'))
# create it
@@ -1139,7 +1139,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
- self.assertEqual('[]', req.data.decode('utf-8'))
+ self.assertEqual('[]\n', req.data.decode('utf-8'))
# add a member
req = self.client.post("/api/projects/raclette/members", data={
@@ -1148,7 +1148,7 @@ class APITestCase(IhatemoneyTestCase):
# the id of the new member should be returned
self.assertStatus(201, req)
- self.assertEqual("1", req.data.decode('utf-8'))
+ self.assertEqual("1\n", req.data.decode('utf-8'))
# the list of members should contain one member
req = self.client.get("/api/projects/raclette/members",
@@ -1223,7 +1223,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
- self.assertEqual('[]', req.data.decode('utf-8'))
+ self.assertEqual('[]\n', req.data.decode('utf-8'))
def test_bills(self):
# create a project
@@ -1239,7 +1239,7 @@ class APITestCase(IhatemoneyTestCase):
headers=self.get_auth("raclette"))
self.assertStatus(200, req)
- self.assertEqual("[]", req.data.decode('utf-8'))
+ self.assertEqual("[]\n", req.data.decode('utf-8'))
# add a bill
req = self.client.post("/api/projects/raclette/bills", data={
@@ -1252,7 +1252,7 @@ class APITestCase(IhatemoneyTestCase):
# should return the id
self.assertStatus(201, req)
- self.assertEqual(req.data.decode('utf-8'), "1")
+ self.assertEqual(req.data.decode('utf-8'), "1\n")
# get this bill details
req = self.client.get("/api/projects/raclette/bills/1",
@@ -1288,7 +1288,7 @@ class APITestCase(IhatemoneyTestCase):
}, headers=self.get_auth("raclette"))
self.assertStatus(400, req)
- self.assertEqual('{"date": ["This field is required."]}', req.data.decode('utf-8'))
+ self.assertEqual('{"date": ["This field is required."]}\n', req.data.decode('utf-8'))
# edit a bill
req = self.client.put("/api/projects/raclette/bills/1", data={
diff --git a/ihatemoney/utils.py b/ihatemoney/utils.py
index 6af0112..aaae2a0 100644
--- a/ihatemoney/utils.py
+++ b/ihatemoney/utils.py
@@ -3,7 +3,7 @@ import re
from io import BytesIO, StringIO
from jinja2 import filters
-from json import dumps
+from json import dumps, JSONEncoder
from flask import redirect
from werkzeug.routing import HTTPException, RoutingException
import six
@@ -170,3 +170,28 @@ class LoginThrottler():
def reset(self, ip):
self._attempts.pop(ip, None)
+
+
+class IhmJSONEncoder(JSONEncoder):
+ """Subclass of the default encoder to support custom objects.
+ Taken from the deprecated flask-rest package."""
+ def default(self, o):
+ if hasattr(o, "_to_serialize"):
+ # build up the object
+ data = {}
+ for attr in o._to_serialize:
+ data[attr] = getattr(o, attr)
+ return data
+ elif hasattr(o, "isoformat"):
+ return o.isoformat()
+ else:
+ try:
+ from flask_babel import speaklater
+ if isinstance(o, speaklater.LazyString):
+ try:
+ return unicode(o) # For python 2.
+ except NameError:
+ return str(o) # For python 3.
+ except ImportError:
+ pass
+ return JSONEncoder.default(self, o)