aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlexis Metaireau <alexis@notmyidea.org>2011-10-08 13:22:18 +0200
committerAlexis Metaireau <alexis@notmyidea.org>2011-10-08 13:27:30 +0200
commit48bc551853b8b0067cdaeef98b3e454c3249f98f (patch)
tree75332a6548aba9d95e0771de2af0e9e19d53efbc
parent402dbce153639668d47db00fdc7a0479d9ebc3f6 (diff)
downloadihatemoney-mirror-48bc551853b8b0067cdaeef98b3e454c3249f98f.zip
ihatemoney-mirror-48bc551853b8b0067cdaeef98b3e454c3249f98f.tar.gz
ihatemoney-mirror-48bc551853b8b0067cdaeef98b3e454c3249f98f.tar.bz2
Complete the REST API + Tests. Fix #27
-rw-r--r--budget/api.py46
-rw-r--r--budget/forms.py5
-rw-r--r--budget/models.py15
-rw-r--r--budget/rest.py6
-rw-r--r--budget/tests.py312
-rw-r--r--budget/utils.py17
-rw-r--r--budget/web.py4
7 files changed, 353 insertions, 52 deletions
diff --git a/budget/api.py b/budget/api.py
index 3df8ab2..c5ae76b 100644
--- a/budget/api.py
+++ b/budget/api.py
@@ -2,8 +2,8 @@
from flask import *
from models import db, Project, Person, Bill
-from forms import ProjectForm
-from utils import for_all_methods
+from forms import ProjectForm, EditProjectForm, MemberForm, BillForm
+from utils import for_all_methods, get_billform_for
from rest import RESTResource, need_auth# FIXME make it an ext
from werkzeug import Response
@@ -32,7 +32,7 @@ class ProjectHandler(object):
def add(self):
form = ProjectForm(csrf_enabled=False)
if form.validate():
- project = form.save(Project())
+ project = form.save()
db.session.add(project)
db.session.commit()
return 201, project.id
@@ -40,7 +40,7 @@ class ProjectHandler(object):
@need_auth(check_project, "project")
def get(self, project):
- return project
+ return 200, project
@need_auth(check_project, "project")
def delete(self, project):
@@ -50,9 +50,9 @@ class ProjectHandler(object):
@need_auth(check_project, "project")
def update(self, project):
- form = ProjectForm(csrf_enabled=False)
+ form = EditProjectForm(csrf_enabled=False)
if form.validate():
- form.save(project)
+ form.update(project)
db.session.commit()
return 200, "UPDATED"
return 400, form.errors
@@ -61,25 +61,25 @@ class ProjectHandler(object):
class MemberHandler(object):
def get(self, project, member_id):
- member = Person.query.get(member_id)
+ member = Person.query.get(member_id, project)
if not member or member.project != project:
return 404, "Not Found"
- return member
+ return 200, member
def list(self, project):
- return project.members
+ return 200, project.members
def add(self, project):
- form = MemberForm(csrf_enabled=False)
+ form = MemberForm(project, csrf_enabled=False)
if form.validate():
member = Person()
form.save(project, member)
db.session.commit()
- return 200, member.id
+ return 201, member.id
return 400, form.errors
def update(self, project, member_id):
- form = MemberForm(csrf_enabled=False)
+ form = MemberForm(project, csrf_enabled=False)
if form.validate():
member = Person.query.get(member_id, project)
form.save(project, member)
@@ -99,39 +99,41 @@ class BillHandler(object):
bill = Bill.query.get(project, bill_id)
if not bill:
return 404, "Not Found"
- return bill
+ return 200, bill
def list(self, project):
return project.get_bills().all()
def add(self, project):
- form = BillForm(csrf_enabled=False)
+ form = get_billform_for(project, True, csrf_enabled=False)
if form.validate():
bill = Bill()
- form.save(bill)
+ form.save(bill, project)
db.session.add(bill)
db.session.commit()
- return 200, bill.id
+ return 201, bill.id
return 400, form.errors
def update(self, project, bill_id):
- form = BillForm(csrf_enabled=False)
+ form = get_billform_for(project, True, csrf_enabled=False)
if form.validate():
- form.save(bill)
+ bill = Bill.query.get(project, bill_id)
+ form.save(bill, project)
db.session.commit()
return 200, bill.id
return 400, form.errors
def delete(self, project, bill_id):
bill = Bill.query.delete(project, bill_id)
+ db.session.commit()
if not bill:
return 404, "Not Found"
- return bill
+ return 200, "OK"
project_resource = RESTResource(
name="project",
- route="/project",
+ route="/projects",
app=api,
actions=["add", "update", "delete", "get"],
handler=ProjectHandler())
@@ -139,7 +141,7 @@ project_resource = RESTResource(
member_resource = RESTResource(
name="member",
inject_name="project",
- route="/project/<project_id>/members",
+ route="/projects/<project_id>/members",
app=api,
handler=MemberHandler(),
authentifier=check_project)
@@ -147,7 +149,7 @@ member_resource = RESTResource(
bill_resource = RESTResource(
name="bill",
inject_name="project",
- route="/project/<project_id>/bills",
+ route="/projects/<project_id>/bills",
app=api,
handler=BillHandler(),
authentifier=check_project)
diff --git a/budget/forms.py b/budget/forms.py
index 16fa0d6..25731bc 100644
--- a/budget/forms.py
+++ b/budget/forms.py
@@ -95,12 +95,13 @@ class BillForm(Form):
validators=[Required()], widget=select_multi_checkbox)
submit = SubmitField("Send the bill")
- def save(self, bill):
+ def save(self, bill, project):
bill.payer_id=self.payer.data
bill.amount=self.amount.data
bill.what=self.what.data
bill.date=self.date.data
- bill.owers = [Person.query.get(ower) for ower in self.payed_for.data]
+ bill.owers = [Person.query.get(ower, project)
+ for ower in self.payed_for.data]
return bill
diff --git a/budget/models.py b/budget/models.py
index 5ee7b07..c938e97 100644
--- a/budget/models.py
+++ b/budget/models.py
@@ -60,14 +60,13 @@ class Project(db.Model):
This method returns the status DELETED or DEACTIVATED regarding the
changes made.
"""
- person = Person.query.get_or_404(member_id)
- if person.project == self:
- if not person.has_bills():
- db.session.delete(person)
- db.session.commit()
- else:
- person.activated = False
- db.session.commit()
+ person = Person.query.get(member_id, self)
+ if not person.has_bills():
+ db.session.delete(person)
+ db.session.commit()
+ else:
+ person.activated = False
+ db.session.commit()
return person
def __repr__(self):
diff --git a/budget/rest.py b/budget/rest.py
index f237217..992a61e 100644
--- a/budget/rest.py
+++ b/budget/rest.py
@@ -90,7 +90,7 @@ def need_auth(authentifier, name=None, remove_attr=True):
If the request is authorized, the object returned by the authentifier
is added to the kwargs of the method.
- If not, issue a 403 Forbidden error
+ If not, issue a 401 Unauthorized error
:authentifier:
The callable to check the context onto.
@@ -114,7 +114,7 @@ def need_auth(authentifier, name=None, remove_attr=True):
del kwargs["%s_id" % name]
return func(*args, **kwargs)
else:
- return 403, "Forbidden"
+ return 401, "Unauthorized"
return wrapped
return wrapper
@@ -126,7 +126,7 @@ def serialize(func):
"""
def wrapped(*args, **kwargs):
# get the mimetype
- mime = request.accept_mimetypes.best_match(SERIALIZERS.keys())
+ mime = request.accept_mimetypes.best_match(SERIALIZERS.keys()) or "text/json"
data = func(*args, **kwargs)
serializer = SERIALIZERS[mime]
diff --git a/budget/tests.py b/budget/tests.py
index 4bb8e60..1541cbb 100644
--- a/budget/tests.py
+++ b/budget/tests.py
@@ -2,6 +2,8 @@
import os
import tempfile
import unittest
+import base64
+import json
from flask import session
@@ -13,7 +15,7 @@ class TestCase(unittest.TestCase):
def setUp(self):
run.app.config['TESTING'] = True
- run.app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite:///memory"
+ run.app.config['SQLALCHEMY_DATABASE_URI'] = "sqlite:///memory"
run.app.config['CSRF_ENABLED'] = False # simplify the tests
self.app = run.app.test_client()
@@ -45,7 +47,7 @@ class TestCase(unittest.TestCase):
})
def create_project(self, name):
- models.db.session.add(models.Project(id=name, name=unicode(name),
+ models.db.session.add(models.Project(id=name, name=unicode(name),
password=name, contact_email="%s@notmyidea.org" % name))
models.db.session.commit()
@@ -76,7 +78,7 @@ class BudgetTestCase(TestCase):
# only one message is sent to multiple persons
self.assertEqual(len(outbox), 1)
- self.assertEqual(outbox[0].recipients,
+ self.assertEqual(outbox[0].recipients,
["alexis@notmyidea.org", "toto@notmyidea.org"])
# mail address checking
@@ -107,7 +109,7 @@ class BudgetTestCase(TestCase):
# session is updated
self.assertEqual(session['raclette'], 'party')
-
+
# project is created
self.assertEqual(len(models.Project.query.all()), 1)
@@ -144,9 +146,9 @@ class BudgetTestCase(TestCase):
# check fred is present in the bills page
result = self.app.get("/raclette/")
self.assertIn("fred", result.data)
-
+
# remove fred
- self.app.post("/raclette/members/%s/delete" %
+ self.app.post("/raclette/members/%s/delete" %
models.Project.query.get("raclette").members[-1].id)
# as fred is not bound to any bill, he is removed
@@ -186,7 +188,7 @@ class BudgetTestCase(TestCase):
self.assertEqual(
len(models.Project.query.get("raclette").active_members), 2)
- # adding an user with the same name as another user from a different
+ # adding an user with the same name as another user from a different
# project should not cause any troubles
self.post_project("randomid")
self.login("randomid")
@@ -198,7 +200,7 @@ class BudgetTestCase(TestCase):
def test_demo(self):
# Test that it is possible to connect automatically by going onto /demo
with run.app.test_client() as c:
- models.db.session.add(models.Project(id="demo", name=u"demonstration",
+ models.db.session.add(models.Project(id="demo", name=u"demonstration",
password="demo", contact_email="demo@notmyidea.org"))
models.db.session.commit()
c.get("/demo")
@@ -216,14 +218,14 @@ class BudgetTestCase(TestCase):
# raclette that the login / logout process works
self.create_project("raclette")
- # try to see the project while not being authenticated should redirect
+ # try to see the project while not being authenticated should redirect
# to the authentication page
resp = self.app.post("/raclette", follow_redirects=True)
self.assertIn("Authentication", resp.data)
-
+
# try to connect with wrong credentials should not work
with run.app.test_client() as c:
- resp = c.post("/authenticate",
+ resp = c.post("/authenticate",
data={'id': 'raclette', 'password': 'nope'})
self.assertIn("Authentication", resp.data)
@@ -231,7 +233,7 @@ class BudgetTestCase(TestCase):
# try to connect with the right credentials should work
with run.app.test_client() as c:
- resp = c.post("/authenticate",
+ resp = c.post("/authenticate",
data={'id': 'raclette', 'password': 'raclette'})
self.assertNotIn("Authentication", resp.data)
@@ -250,7 +252,7 @@ class BudgetTestCase(TestCase):
self.app.post("/raclette/members/add", data={'name': 'fred' })
members_ids = [m.id for m in models.Project.query.get("raclette").members]
-
+
# create a bill
self.app.post("/raclette/add", data={
'date': '2011-08-10',
@@ -317,7 +319,7 @@ class BudgetTestCase(TestCase):
'password': 'didoudida'
}
- resp = self.app.post("/raclette/edit", data=new_data,
+ resp = self.app.post("/raclette/edit", data=new_data,
follow_redirects=True)
self.assertEqual(resp.status_code, 200)
project = models.Project.query.get("raclette")
@@ -333,5 +335,287 @@ class BudgetTestCase(TestCase):
self.assertIn("Invalid email address", resp.data)
+class APITestCase(TestCase):
+ """Tests the API"""
+
+ def api_create(self, name, id=None, password=None, contact=None):
+ id = id or name
+ password = password or name
+ contact = contact or "%s@notmyidea.org" % name
+
+ return self.app.post("/api/projects", data={
+ 'name': name,
+ 'id': id,
+ 'password': password,
+ 'contact_email': contact
+ })
+
+ def api_add_member(self, project, name):
+ self.app.post("/api/projects/%s/members" % project,
+ data={"name": name}, headers=self.get_auth(project))
+
+ def get_auth(self, username, password=None):
+ password = password or username
+ base64string = base64.encodestring(
+ '%s:%s' % (username, password)).replace('\n', '')
+ return {"Authorization": "Basic %s" % base64string}
+
+ def assertStatus(self, expected, resp, url=""):
+
+ return self.assertEqual(expected, resp.status_code,
+ "%s expected %s, got %s" % (url, expected, resp.status_code))
+
+ def test_basic_auth(self):
+ # create a project
+ resp = self.api_create("raclette")
+ self.assertStatus(201, resp)
+
+ # try to do something on it being unauth should return a 401
+ resp = self.app.get("/api/projects/raclette")
+ self.assertStatus(401, resp)
+
+ # PUT / POST / DELETE / GET on the different resources
+ # should also return a 401
+ for verb in ('post',):
+ for resource in ("/raclette/members", "/raclette/bills"):
+ url = "/api/projects" + resource
+ self.assertStatus(401, getattr(self.app, verb)(url),
+ verb + resource)
+
+ for verb in ('get', 'delete', 'put'):
+ for resource in ("/raclette", "/raclette/members/1",
+ "/raclette/bills/1"):
+ url = "/api/projects" + resource
+
+ self.assertStatus(401, getattr(self.app, verb)(url),
+ verb + resource)
+
+ def test_project(self):
+ # wrong email should return an error
+ resp = self.app.post("/api/projects", data={
+ 'name': "raclette",
+ 'id': "raclette",
+ 'password': "raclette",
+ 'contact_email': "not-an-email"
+ })
+
+ self.assertTrue(400, resp.status_code)
+ self.assertEqual('{"contact_email": ["Invalid email address."]}', resp.data)
+
+ # create it
+ resp = self.api_create("raclette")
+ self.assertTrue(201, resp.status_code)
+
+ # create it twice should return a 400
+ resp = self.api_create("raclette")
+
+ self.assertTrue(400, resp.status_code)
+ self.assertEqual('{"id": ["This project id is already used"]}', resp.data)
+ # get information about it
+ resp = self.app.get("/api/projects/raclette",
+ headers=self.get_auth("raclette"))
+
+ self.assertTrue(200, resp.status_code)
+ expected = {
+ "active_members": [],
+ "name": "raclette",
+ "contact_email": "raclette@notmyidea.org",
+ "members": [],
+ "password": "raclette",
+ "id": "raclette"
+ }
+ self.assertDictEqual(json.loads(resp.data), expected)
+
+ # edit should work
+ resp = self.app.put("/api/projects/raclette", data = {
+ "contact_email": "yeah@notmyidea.org",
+ "password": "raclette",
+ "name": "The raclette party",
+ }, headers=self.get_auth("raclette"))
+
+ self.assertEqual(200, resp.status_code)
+
+ resp = self.app.get("/api/projects/raclette",
+ headers=self.get_auth("raclette"))
+
+ self.assertEqual(200, resp.status_code)
+ expected = {
+ "active_members": [],
+ "name": "The raclette party",
+ "contact_email": "yeah@notmyidea.org",
+ "members": [],
+ "password": "raclette",
+ "id": "raclette"
+ }
+ self.assertDictEqual(json.loads(resp.data), expected)
+
+ # delete should work
+ resp = self.app.delete("/api/projects/raclette",
+ headers=self.get_auth("raclette"))
+
+ self.assertEqual(200, resp.status_code)
+
+ # get should return a 401 on an unknown resource
+ resp = self.app.get("/api/projects/raclette",
+ headers=self.get_auth("raclette"))
+ self.assertEqual(401, resp.status_code)
+
+ def test_member(self):
+ # create a project
+ self.api_create("raclette")
+
+ # get the list of members (should be empty)
+ req = self.app.get("/api/projects/raclette/members",
+ headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+ self.assertEqual('[]', req.data)
+
+ # add a member
+ req = self.app.post("/api/projects/raclette/members", data={
+ "name": "Alexis"
+ }, headers=self.get_auth("raclette"))
+
+ # the id of the new member should be returned
+ self.assertStatus(201, req)
+ self.assertEqual("1", req.data)
+
+ # the list of members should contain one member
+ req = self.app.get("/api/projects/raclette/members",
+ headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+ self.assertEqual(len(json.loads(req.data)), 1)
+
+ # edit this member
+ req = self.app.put("/api/projects/raclette/members/1", data={
+ "name": "Fred"
+ }, headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+
+ # get should return the new name
+ req = self.app.get("/api/projects/raclette/members/1",
+ headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+ self.assertEqual("Fred", json.loads(req.data)["name"])
+
+ # delete a member
+
+ req = self.app.delete("/api/projects/raclette/members/1",
+ headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+
+ # the list of members should be empty
+ # get the list of members (should be empty)
+ req = self.app.get("/api/projects/raclette/members",
+ headers=self.get_auth("raclette"))
+
+ self.assertStatus(200, req)
+ self.assertEqual('[]', req.data)
+
+ def test_bills(self):
+ # create a project
+ self.api_create("raclette")
+
+ # add members
+ self.api_add_member("raclette", "alexis")
+ self.api_add_member("raclette", "fred")
+ self.api_add_member("raclette", "arnaud")
+
+ # get the list of bills (should be empty)
+ req = self.app.get("/api/projects/raclette/bills",
+ headers=self.get_auth("raclette"))
+ self.assertStatus(200, req)
+
+ self.assertEqual("[]", req.data)
+
+ # add a bill
+ req = self.app.post("/api/projects/raclette/bills", data={
+ 'date': '2011-08-10',
+ 'what': u'fromage',
+ 'payer': "1",
+ 'payed_for': ["1", "2"],
+ 'amount': '25',
+ }, headers=self.get_auth("raclette"))
+
+ # should return the id
+ self.assertStatus(201, req)
+ self.assertEqual(req.data, "1")
+
+ # get this bill details
+ req = self.app.get("/api/projects/raclette/bills/1",
+ headers=self.get_auth("raclette"))
+
+ # compare with the added info
+ self.assertStatus(200, req)
+ expected = {
+ "what": "fromage",
+ "payer_id": 1,
+ "owers": [
+ {"activated": True, "id": 1, "name": "alexis"},
+ {"activated": True, "id": 2, "name": "fred"}],
+ "amount": 25.0,
+ "date": "2011-08-10",
+ "id": 1}
+
+ self.assertDictEqual(expected, json.loads(req.data))
+
+ # the list of bills should lenght 1
+ req = self.app.get("/api/projects/raclette/bills",
+ headers=self.get_auth("raclette"))
+ self.assertStatus(200, req)
+ self.assertEqual(1, len(json.loads(req.data)))
+
+ # edit with errors should return an error
+ req = self.app.put("/api/projects/raclette/bills/1", data={
+ 'date': '201111111-08-10', # not a date
+ 'what': u'fromage',
+ 'payer': "1",
+ 'payed_for': ["1", "2"],
+ 'amount': '25',
+ }, headers=self.get_auth("raclette"))
+
+ self.assertStatus(400, req)
+ self.assertEqual('{"date": ["This field is required."]}', req.data)
+
+ # edit a bill
+ req = self.app.put("/api/projects/raclette/bills/1", data={
+ 'date': '2011-09-10',
+ 'what': u'beer',
+ 'payer': "2",
+ 'payed_for': ["1", "2"],
+ 'amount': '25',
+ }, headers=self.get_auth("raclette"))
+
+ # check its fields
+ req = self.app.get("/api/projects/raclette/bills/1",
+ headers=self.get_auth("raclette"))
+
+ expected = {
+ "what": "beer",
+ "payer_id": 2,
+ "owers": [
+ {"activated": True, "id": 1, "name": "alexis"},
+ {"activated": True, "id": 2, "name": "fred"}],
+ "amount": 25.0,
+ "date": "2011-09-10",
+ "id": 1}
+
+ self.assertDictEqual(expected, json.loads(req.data))
+
+ # delete a bill
+ req = self.app.delete("/api/projects/raclette/bills/1",
+ headers=self.get_auth("raclette"))
+ self.assertStatus(200, req)
+
+ # getting it should return a 404
+ req = self.app.get("/api/projects/raclette/bills/1",
+ headers=self.get_auth("raclette"))
+ self.assertStatus(404, req)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/budget/utils.py b/budget/utils.py
index 88b8580..df165b5 100644
--- a/budget/utils.py
+++ b/budget/utils.py
@@ -17,6 +17,22 @@ def slugify(value):
value = unicode(re.sub('[^\w\s-]', '', value).strip().lower())
return re.sub('[-\s]+', '-', value)
+
+def get_billform_for(project, set_default=True, **kwargs):
+ """Return an instance of BillForm configured for a particular project.
+
+ :set_default: if set to True, on GET methods (usually when we want to
+ display the default form, it will call set_default on it.
+
+ """
+ form = BillForm(**kwargs)
+ form.payed_for.choices = form.payer.choices = [(str(m.id), m.name) for m in project.active_members]
+ form.payed_for.default = [str(m.id) for m in project.active_members]
+
+ if set_default and request.method == "GET":
+ form.set_default()
+ return form
+
class Redirect303(HTTPException, RoutingException):
"""Raise if the map requests a redirect. This is for example the case if
`strict_slashes` are activated and an url that requires a trailing slash.
@@ -39,4 +55,3 @@ def for_all_methods(decorator):
setattr(cls, name, decorator(method))
return cls
return decorate
-
diff --git a/budget/web.py b/budget/web.py
index 37c6415..94c42d3 100644
--- a/budget/web.py
+++ b/budget/web.py
@@ -262,7 +262,7 @@ def add_bill():
if request.method == 'POST':
if form.validate():
bill = Bill()
- db.session.add(form.save(bill))
+ db.session.add(form.save(bill, g.project))
db.session.commit()
flash("The bill has been added")
@@ -295,7 +295,7 @@ def edit_bill(bill_id):
form = get_billform_for(request, g.project, set_default=False)
if request.method == 'POST' and form.validate():
- form.save(bill)
+ form.save(bill, g.project)
db.session.commit()
flash("The bill has been modified")