aboutsummaryrefslogtreecommitdiff
path: root/ihatemoney/api/common.py
diff options
context:
space:
mode:
Diffstat (limited to 'ihatemoney/api/common.py')
-rw-r--r--ihatemoney/api/common.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/ihatemoney/api/common.py b/ihatemoney/api/common.py
new file mode 100644
index 0000000..728d2a8
--- /dev/null
+++ b/ihatemoney/api/common.py
@@ -0,0 +1,191 @@
+# coding: utf8
+from flask import request, current_app
+from flask_restful import Resource, abort
+from wtforms.fields.core import BooleanField
+
+from ihatemoney.models import db, Project, Person, Bill
+from ihatemoney.forms import ProjectForm, EditProjectForm, MemberForm, get_billform_for
+from werkzeug.security import check_password_hash
+from functools import wraps
+
+
+def need_auth(f):
+ """Check the request for basic authentication for a given project.
+
+ Return the project if the authorization is good, abort the request with a 401 otherwise
+ """
+
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ auth = request.authorization
+ project_id = kwargs.get("project_id")
+
+ # Use Basic Auth
+ if auth and project_id and auth.username == project_id:
+ project = Project.query.get(auth.username)
+ if project and check_password_hash(project.password, auth.password):
+ # The whole project object will be passed instead of project_id
+ kwargs.pop("project_id")
+ return f(*args, project=project, **kwargs)
+ else:
+ # Use Bearer token Auth
+ auth_header = request.headers.get("Authorization", "")
+ auth_token = ""
+ try:
+ auth_token = auth_header.split(" ")[1]
+ except IndexError:
+ abort(401)
+ project_id = Project.verify_token(auth_token, token_type="non_timed_token")
+ if auth_token and project_id:
+ project = Project.query.get(project_id)
+ if project:
+ kwargs.pop("project_id")
+ return f(*args, project=project, **kwargs)
+ abort(401)
+
+ return wrapper
+
+
+class ProjectsHandler(Resource):
+ def post(self):
+ form = ProjectForm(meta={"csrf": False})
+ if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"):
+ project = form.save()
+ db.session.add(project)
+ db.session.commit()
+ return project.id, 201
+ return form.errors, 400
+
+
+class ProjectHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project):
+ return project
+
+ def delete(self, project):
+ db.session.delete(project)
+ db.session.commit()
+ return "DELETED"
+
+ def put(self, project):
+ form = EditProjectForm(meta={"csrf": False})
+ if form.validate() and current_app.config.get("ALLOW_PUBLIC_PROJECT_CREATION"):
+ form.update(project)
+ db.session.commit()
+ return "UPDATED"
+ return form.errors, 400
+
+
+class ProjectStatsHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project):
+ return project.members_stats
+
+
+class APIMemberForm(MemberForm):
+ """ Member is not disablable via a Form.
+
+ But we want Member.enabled to be togglable via the API.
+ """
+
+ activated = BooleanField(false_values=("false", "", "False"))
+
+ def save(self, project, person):
+ person.activated = self.activated.data
+ return super(APIMemberForm, self).save(project, person)
+
+
+class MembersHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project):
+ return project.members
+
+ def post(self, project):
+ form = MemberForm(project, meta={"csrf": False})
+ if form.validate():
+ member = Person()
+ form.save(project, member)
+ db.session.commit()
+ return member.id, 201
+ return form.errors, 400
+
+
+class MemberHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project, member_id):
+ member = Person.query.get(member_id, project)
+ if not member or member.project != project:
+ return "Not Found", 404
+ return member
+
+ def put(self, project, member_id):
+ form = APIMemberForm(project, meta={"csrf": False}, edit=True)
+ if form.validate():
+ member = Person.query.get(member_id, project)
+ form.save(project, member)
+ db.session.commit()
+ return member
+ return form.errors, 400
+
+ def delete(self, project, member_id):
+ if project.remove_member(member_id):
+ return "OK"
+ return "Not Found", 404
+
+
+class BillsHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project):
+ return project.get_bills().all()
+
+ def post(self, project):
+ form = get_billform_for(project, True, meta={"csrf": False})
+ if form.validate():
+ bill = Bill()
+ form.save(bill, project)
+ db.session.add(bill)
+ db.session.commit()
+ return bill.id, 201
+ return form.errors, 400
+
+
+class BillHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project, bill_id):
+ bill = Bill.query.get(project, bill_id)
+ if not bill:
+ return "Not Found", 404
+ return bill, 200
+
+ def put(self, project, bill_id):
+ form = get_billform_for(project, True, meta={"csrf": False})
+ if form.validate():
+ bill = Bill.query.get(project, bill_id)
+ form.save(bill, project)
+ db.session.commit()
+ return bill.id, 200
+ return form.errors, 400
+
+ def delete(self, project, bill_id):
+ bill = Bill.query.delete(project, bill_id)
+ db.session.commit()
+ if not bill:
+ return "Not Found", 404
+ return "OK", 200
+
+
+class TokenHandler(Resource):
+ method_decorators = [need_auth]
+
+ def get(self, project):
+ if not project:
+ return "Not Found", 404
+
+ token = project.generate_token()
+ return {"token": token}, 200