# coding: utf8
import base64
from collections import defaultdict
import datetime
import io
import json
import os
from time import sleep
import unittest
from unittest.mock import patch
from flask import session
from flask_testing import TestCase
from sqlalchemy import orm
from werkzeug.security import check_password_hash, generate_password_hash
from ihatemoney import history, models, utils
from ihatemoney.manage import DeleteProject, GenerateConfig, GeneratePasswordHash
from ihatemoney.run import create_app, db, load_configuration
from ihatemoney.versioning import LoggingMode
# Unset the configuration-file environment variable if it was previously set;
# the configuration tests below set it explicitly when they need it.
os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
# Absolute path of the directory holding this module; used below to locate
# fixtures such as "ihatemoney_envvar.cfg" and as the configuration root path.
__HERE__ = os.path.dirname(os.path.abspath(__file__))
class BaseTestCase(TestCase):
    """Shared fixture: builds the app, manages the database schema around
    each test and offers helpers to create projects and to log in.
    """

    SECRET_KEY = "TEST SESSION"

    def create_app(self):
        """Return the application under test."""
        # Pass the test object as a configuration.
        return create_app(self)

    def setUp(self):
        """Create every table before the test runs."""
        db.create_all()

    def tearDown(self):
        """Release the session and drop every table once the test is done."""
        # clean after testing
        db.session.remove()
        db.drop_all()

    def login(self, project, password=None, test_client=None):
        """Post credentials to ``/authenticate`` and return the response.

        :param project: identifier of the project to log into.
        :param password: password to send; falls back to ``project`` when
                         omitted or empty.
        :param test_client: client used to send the request; ``self.client``
                            is used when omitted.
        """
        password = password or project
        # The argument used to be accepted but ignored: every request went
        # through self.client, whatever client the caller handed in.
        client = self.client if test_client is None else test_client
        return client.post(
            "/authenticate",
            data=dict(id=project, password=password),
            follow_redirects=True,
        )

    def post_project(self, name):
        """Create a fake project through the ``/create`` form.

        ``name`` is used as project name, identifier and password, and to
        build the contact e-mail address.
        """
        # create the project
        self.client.post(
            "/create",
            data={
                "name": name,
                "id": name,
                "password": name,
                "contact_email": "%s@notmyidea.org" % name,
            },
        )

    def create_project(self, name):
        """Insert a project directly in the database, bypassing the views.

        The password stored is the hash of ``name``.
        """
        project = models.Project(
            id=name,
            name=str(name),
            password=generate_password_hash(name),
            contact_email="%s@notmyidea.org" % name,
        )
        models.db.session.add(project)
        models.db.session.commit()
class IhatemoneyTestCase(BaseTestCase):
    """Base class for the tests that run against the "sqlite://" database
    with CSRF protection switched off.
    """

    SQLALCHEMY_DATABASE_URI = "sqlite://"
    TESTING = True
    WTF_CSRF_ENABLED = False  # Simplifies the tests.

    def assertStatus(self, expected, resp, url=""):
        """Assert that ``resp`` has the ``expected`` HTTP status code.

        ``url`` is only used as a prefix of the failure message.
        """
        actual = resp.status_code
        message = "{} expected {}, got {}".format(url, expected, actual)
        return self.assertEqual(expected, actual, message)
class ConfigurationTestCase(BaseTestCase):
    """Checks where the application settings are loaded from."""

    def test_default_configuration(self):
        """Test that default settings are loaded when no other configuration file is specified"""
        config = self.app.config
        self.assertFalse(config["DEBUG"])
        self.assertEqual(
            config["SQLALCHEMY_DATABASE_URI"], "sqlite:////tmp/ihatemoney.db"
        )
        self.assertFalse(config["SQLALCHEMY_TRACK_MODIFICATIONS"])
        self.assertEqual(
            config["MAIL_DEFAULT_SENDER"],
            ("Budget manager", "budget@notmyidea.org"),
        )

    def test_env_var_configuration_file(self):
        """Test that settings are loaded from the specified configuration file"""
        settings_path = os.path.join(__HERE__, "ihatemoney_envvar.cfg")
        # patch.dict puts os.environ back in its previous state on exit, even
        # when an assertion fails, so the variable cannot leak into other tests.
        with patch.dict(
            os.environ, {"IHATEMONEY_SETTINGS_FILE_PATH": settings_path}
        ):
            load_configuration(self.app)
            self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")

            # Test that the specified configuration file is loaded
            # even if the default configuration file ihatemoney.cfg exists
            self.app.config.root_path = __HERE__
            load_configuration(self.app)
            self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")

    def test_default_configuration_file(self):
        """Test that settings are loaded from the default configuration file"""
        self.app.config.root_path = __HERE__
        load_configuration(self.app)
        self.assertEqual(self.app.config["SECRET_KEY"], "supersecret")
class BudgetTestCase(IhatemoneyTestCase):
def test_notifications(self):
    """Test that the notifications are sent, and that email addresses
    are checked properly.
    """

    def invite(emails):
        # Submit the invitation form of the "raclette" project.
        return self.client.post("/raclette/invite", data={"emails": emails})

    # sending a message to one person
    with self.app.mail.record_messages() as sent:
        # create a project
        self.login("raclette")
        self.post_project("raclette")
        invite("alexis@notmyidea.org")
        # two messages recorded: one to the project contact, one to the guest
        self.assertEqual(len(sent), 2)
        self.assertEqual(sent[0].recipients, ["raclette@notmyidea.org"])
        self.assertEqual(sent[1].recipients, ["alexis@notmyidea.org"])

    # sending a message to multiple persons
    with self.app.mail.record_messages() as sent:
        invite("alexis@notmyidea.org, toto@notmyidea.org")
        # only one message is sent to multiple persons
        self.assertEqual(len(sent), 1)
        self.assertEqual(
            sent[0].recipients, ["alexis@notmyidea.org", "toto@notmyidea.org"]
        )

    # mail address checking
    with self.app.mail.record_messages() as sent:
        response = invite("toto")
        self.assertEqual(len(sent), 0)  # no message sent
        self.assertIn("The email toto is not valid", response.data.decode("utf-8"))

    # mixing good and wrong addresses shouldn't send any messages
    with self.app.mail.record_messages() as sent:
        invite("alexis@notmyidea.org, alexis")  # second address is not valid
        # nothing is sent at all
        self.assertEqual(len(sent), 0)
def test_invite(self):
    """Test that invitation e-mails are sent properly"""
    self.login("raclette")
    self.post_project("raclette")
    link_marker = "You can log in using this link: "
    with self.app.mail.record_messages() as outbox:
        self.client.post("/raclette/invite", data={"emails": "toto@notmyidea.org"})
        self.assertEqual(len(outbox), 1)
        body = outbox[0].body
        # Fail loudly when the marker is missing instead of slicing garbage.
        self.assertIn(link_marker, body)
        # The link starts right after the marker and runs up to the
        # sentence-ending ".\n".
        url_start = body.find(link_marker) + len(link_marker)
        url_end = body.find(".\n", url_start)
        url = body[url_start:url_end]
    self.client.get("/exit")

    # Test that we got a valid token
    resp = self.client.get(url, follow_redirects=True)
    # NOTE(review): the original literal was unterminated (it opened with a
    # single quote and closed with a double quote), presumably because
    # trailing markup was lost; only the plain text is asserted here.
    # Confirm the intended fragment against the template.
    self.assertIn("You probably want to ", resp.data.decode("utf-8"))

    # Test that password can be changed
    self.client.post(
        url, data={"password": "pass", "password_confirmation": "pass"}
    )
    resp = self.login("raclette", password="pass")
    self.assertIn("Account manager - raclette", resp.data.decode("utf-8"))

    # Test empty and null tokens
    resp = self.client.get("/reset-password")
    self.assertIn("No token provided", resp.data.decode("utf-8"))
    resp = self.client.get("/reset-password?token=token")
    self.assertIn("Invalid token", resp.data.decode("utf-8"))
def test_project_creation(self):
    """Creating a project logs the creator in; reusing an id adds nothing."""
    first_form = {
        "name": "The fabulous raclette party",
        "id": "raclette",
        "password": "party",
        "contact_email": "raclette@notmyidea.org",
    }
    duplicate_form = {
        "name": "Another raclette party",
        "id": "raclette",  # already used !
        "password": "party",
        "contact_email": "raclette@notmyidea.org",
    }
    with self.app.test_client() as browser:
        # add a valid project
        browser.post("/create", data=first_form)
        # session is updated
        self.assertTrue(session["raclette"])
        # project is created
        self.assertEqual(len(models.Project.query.all()), 1)

        # Add a second project with the same id
        models.Project.query.get("raclette")
        browser.post("/create", data=duplicate_form)
        # no new project added
        self.assertEqual(len(models.Project.query.all()), 1)
def test_project_creation_without_public_permissions(self):
    """With public creation disabled, posting to /create must neither
    update the session nor store a project.
    """
    self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = False
    form = {
        "name": "The fabulous raclette party",
        "id": "raclette",
        "password": "party",
        "contact_email": "raclette@notmyidea.org",
    }
    with self.app.test_client() as browser:
        # try to add an otherwise valid project
        browser.post("/create", data=form)
        # session is not updated
        self.assertNotIn("raclette", session)
        # project is not created
        self.assertEqual(len(models.Project.query.all()), 0)
def test_project_creation_with_public_permissions(self):
    """With public creation enabled, posting to /create updates the session
    and stores the project.
    """
    self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = True
    form = {
        "name": "The fabulous raclette party",
        "id": "raclette",
        "password": "party",
        "contact_email": "raclette@notmyidea.org",
    }
    with self.app.test_client() as browser:
        # add a valid project
        browser.post("/create", data=form)
        # session is updated
        self.assertTrue(session["raclette"])
        # project is created
        self.assertEqual(len(models.Project.query.all()), 1)
def test_project_deletion(self):
    """A project created through /create disappears after /<id>/delete."""
    form = {
        "name": "raclette party",
        "id": "raclette",
        "password": "party",
        "contact_email": "raclette@notmyidea.org",
    }
    with self.app.test_client() as browser:
        browser.post("/create", data=form)
        # project added
        self.assertEqual(len(models.Project.query.all()), 1)

        browser.get("/raclette/delete")
        # project removed
        self.assertEqual(len(models.Project.query.all()), 0)
def test_bill_placeholder(self):
    """The page of a project without members shows the placeholder hint."""
    self.post_project("raclette")
    self.login("raclette")
    result = self.client.get("/raclette/")
    # Empty bill list and no members, should now propose to add members first
    # (the assertion used to read an undefined ``resp`` and raised NameError;
    # the response fetched above is bound to ``result``).
    self.assertIn("You probably want to ", result.data.decode("utf-8"))
def test_authentication(self):
    """Walk through anonymous access, failed login, login/logout and
    admin login.
    """

    def page(response):
        # Decoded body of a response.
        return response.data.decode("utf-8")

    # try to authenticate without credentials should redirect
    # to the authentication page
    resp = self.client.post("/authenticate")
    self.assertIn("Authentication", page(resp))

    # test that the login / logout process works
    self.create_project("raclette")

    # try to see the project while not being authenticated should redirect
    # to the authentication page
    resp = self.client.get("/raclette", follow_redirects=True)
    self.assertIn("Authentication", page(resp))

    # try to connect with wrong credentials should not work
    with self.app.test_client() as browser:
        resp = browser.post(
            "/authenticate", data={"id": "raclette", "password": "nope"}
        )
        self.assertIn("Authentication", page(resp))
        self.assertNotIn("raclette", session)

    # try to connect with the right credentials should work
    with self.app.test_client() as browser:
        resp = browser.post(
            "/authenticate", data={"id": "raclette", "password": "raclette"}
        )
        self.assertNotIn("Authentication", page(resp))
        self.assertIn("raclette", session)
        self.assertTrue(session["raclette"])
        # logout should wipe the session out
        browser.get("/exit")
        self.assertNotIn("raclette", session)

    # test that with admin credentials, one can access every project
    self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass")
    with self.app.test_client() as browser:
        resp = browser.post(
            "/admin?goto=%2Fraclette", data={"admin_password": "pass"}
        )
        self.assertNotIn("Authentication", page(resp))
        self.assertTrue(session["is_admin"])
def test_admin_authentication(self):
    """Check the admin login form with a right, a wrong and an empty
    password.
    """
    admin_url = "/admin?goto=%2Fcreate"
    self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass")
    # Disable public project creation so we have an admin endpoint to test
    self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = False

    # test the redirection to the authentication page when trying to access admin endpoints
    resp = self.client.get("/create")
    # The previous check, assertIn('', ...), was true for any response.
    # NOTE(review): this assumes the response body mentions the admin URL
    # used below - confirm against the view.
    self.assertIn(admin_url, resp.data.decode("utf-8"))

    # test right password
    resp = self.client.post(admin_url, data={"admin_password": "pass"})
    self.assertIn('/create', resp.data.decode("utf-8"))

    # test wrong password
    resp = self.client.post(admin_url, data={"admin_password": "wrong"})
    self.assertNotIn('/create', resp.data.decode("utf-8"))

    # test empty password
    resp = self.client.post(admin_url, data={"admin_password": ""})
    self.assertNotIn('/create', resp.data.decode("utf-8"))
def test_login_throttler(self):
    """Repeated wrong admin passwords trigger throttling, which is lifted
    once the (shortened) delay has expired.
    """
    import gc

    throttled_message = "Too many failed login attempts, please retry later."
    self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass")

    def failed_login():
        # One admin login attempt with a wrong password.
        return self.client.post(
            "/admin?goto=%2Fcreate", data={"admin_password": "wrong"}
        )

    # Activate admin login throttling by authenticating 4 times with a wrong password
    for _ in range(3):
        failed_login()
    resp = failed_login()
    self.assertIn(throttled_message, resp.data.decode("utf-8"))

    # Change throttling delay: patch the first LoginThrottler instance that
    # the garbage collector knows about.
    throttler = next(
        (
            candidate
            for candidate in gc.get_objects()
            if isinstance(candidate, utils.LoginThrottler)
        ),
        None,
    )
    if throttler is not None:
        throttler._delay = 0.005

    # Wait for delay to expire and retry logging in
    sleep(1)
    resp = failed_login()
    self.assertNotIn(throttled_message, resp.data.decode("utf-8"))
def test_manage_bills(self):
    """Add, edit and delete a bill, then check balances, a negative amount
    and an amount written with a comma.
    """
    self.post_project("raclette")

    def bill_form(date, payer, payed_for, amount):
        # Every bill of this test shares the same description.
        return {
            "date": date,
            "what": "fromage à raclette",
            "payer": payer,
            "payed_for": payed_for,
            "amount": amount,
        }

    # add two persons
    for person in ("alexis", "fred"):
        self.client.post("/raclette/members/add", data={"name": person})
    members_ids = [m.id for m in models.Project.query.get("raclette").members]

    # create a bill
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-10", members_ids[0], members_ids, "25"),
    )
    models.Project.query.get("raclette")
    bill = models.Bill.query.one()
    self.assertEqual(bill.amount, 25)

    # edit the bill
    self.client.post(
        "/raclette/edit/%s" % bill.id,
        data=bill_form("2011-08-10", members_ids[0], members_ids, "10"),
    )
    bill = models.Bill.query.one()
    self.assertEqual(bill.amount, 10, "bill edition")

    # delete the bill
    self.client.get("/raclette/delete/%s" % bill.id)
    self.assertEqual(0, len(models.Bill.query.all()), "bill deletion")

    # test balance
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-10", members_ids[0], members_ids, "19"),
    )
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-10", members_ids[1], members_ids[0], "20"),
    )
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-10", members_ids[1], members_ids, "17"),
    )
    balance = models.Project.query.get("raclette").balance
    self.assertEqual(set(balance.values()), {19.0, -19.0})

    # Bill with negative amount
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-12", members_ids[0], members_ids, "-25"),
    )
    bill = models.Bill.query.filter(models.Bill.date == "2011-08-12")[0]
    self.assertEqual(bill.amount, -25)

    # add a bill with a comma
    self.client.post(
        "/raclette/add",
        data=bill_form("2011-08-01", members_ids[0], members_ids, "25,02"),
    )
    bill = models.Bill.query.filter(models.Bill.date == "2011-08-01")[0]
    self.assertEqual(bill.amount, 25.02)
def test_weighted_balance(self):
    """Balances take the members' weights into account."""
    self.post_project("raclette")

    # add two persons
    self.client.post("/raclette/members/add", data={"name": "alexis"})
    self.client.post(
        "/raclette/members/add", data={"name": "freddy familly", "weight": 4}
    )
    members_ids = [m.id for m in models.Project.query.get("raclette").members]

    # test balance: each member pays one bill of 10 shared by both
    bills = (
        ("fromage à raclette", members_ids[0]),
        ("pommes de terre", members_ids[1]),
    )
    for what, payer in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": "2011-08-10",
                "what": what,
                "payer": payer,
                "payed_for": members_ids,
                "amount": "10",
            },
        )

    balance = models.Project.query.get("raclette").balance
    self.assertEqual(set(balance.values()), {6, -6})
def test_trimmed_members(self):
    """A name differing only by a trailing space is not added twice."""
    self.post_project("raclette")

    # Add two times the same person (with a space at the end).
    for submitted_name in ("alexis", "alexis "):
        self.client.post("/raclette/members/add", data={"name": submitted_name})

    project = models.Project.query.get("raclette")
    self.assertEqual(len(project.members), 1)
def test_weighted_members_list(self):
    """Check the presence of "extra-info" on the project page before and
    after a member with a weight of 4 is added.
    """
    self.post_project("raclette")

    # add two persons
    self.client.post("/raclette/members/add", data={"name": "alexis"})
    self.client.post("/raclette/members/add", data={"name": "tata", "weight": 1})
    page = self.client.get("/raclette/").data.decode("utf-8")
    self.assertIn("extra-info", page)

    self.client.post(
        "/raclette/members/add", data={"name": "freddy familly", "weight": 4}
    )
    page = self.client.get("/raclette/").data.decode("utf-8")
    self.assertNotIn("extra-info", page)
def test_negative_weight(self):
    """Editing a member with a negative weight is refused."""
    self.post_project("raclette")

    # Add one user and edit it to have a negative share
    self.client.post("/raclette/members/add", data={"name": "alexis"})
    resp = self.client.post(
        "/raclette/members/1/edit", data={"name": "alexis", "weight": -1}
    )

    # An error should be generated, and its weight should still be 1.
    # NOTE(review): the original literal was split over two lines (a syntax
    # error), apparently because its markup was lost; it is restored here as
    # the error-alert paragraph - confirm against the form template.
    self.assertIn('<p class="alert alert-danger">', resp.data.decode("utf-8"))

    members = models.Project.query.get("raclette").members
    self.assertEqual(len(members), 1)
    self.assertEqual(members[0].weight, 1)
def test_rounding(self):
    """Balances rounded to two decimals match the theoretical values."""
    self.post_project("raclette")

    # add members
    for person in ("alexis", "fred", "tata"):
        self.client.post("/raclette/members/add", data={"name": person})

    # create bills
    bills = (
        ("fromage à raclette", 1, [1, 2, 3], "24.36"),
        ("red wine", 2, [1], "19.12"),
        ("delicatessen", 1, [1, 2], "22"),
    )
    for what, payer, payed_for, amount in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": "2011-08-10",
                "what": what,
                "payer": payer,
                "payed_for": payed_for,
                "amount": amount,
            },
        )

    project = models.Project.query.get("raclette")
    balance = project.balance
    members = project.members
    expected = {
        members[0].id: 8.12,
        members[1].id: 0.0,
        members[2].id: -8.12,
    }

    # Since we're using floating point to store currency, we can have some
    # rounding issues that prevent test from working.
    # However, we should obtain the same values as the theoretical ones if we
    # round to 2 decimals, like in the UI.
    for member_id, amount in balance.items():
        self.assertEqual(round(amount, 2), expected[member_id])
def test_edit_project(self):
    """A project can be edited, but not with an invalid contact e-mail."""
    # A project should be editable
    self.post_project("raclette")
    form = {
        "name": "Super raclette party!",
        "contact_email": "alexis@notmyidea.org",
        "password": "didoudida",
        "logging_preference": LoggingMode.ENABLED.value,
    }
    resp = self.client.post("/raclette/edit", data=form, follow_redirects=True)
    self.assertEqual(resp.status_code, 200)

    project = models.Project.query.get("raclette")
    self.assertEqual(project.name, form["name"])
    self.assertEqual(project.contact_email, form["contact_email"])
    # the password is stored hashed, so check it against the hash
    self.assertTrue(check_password_hash(project.password, form["password"]))

    # Editing a project with a wrong email address should fail
    form["contact_email"] = "wrong_email"
    resp = self.client.post("/raclette/edit", data=form, follow_redirects=True)
    self.assertIn("Invalid email address", resp.data.decode("utf-8"))
def test_dashboard(self):
    """The admin dashboard is refused by default and served once enabled."""

    def open_dashboard():
        # Log in as admin and follow the redirection to the dashboard.
        return self.client.post(
            "/admin?goto=%2Fdashboard",
            data={"admin_password": "adminpass"},
            follow_redirects=True,
        )

    # test that the dashboard is deactivated by default
    resp = open_dashboard()
    # NOTE(review): the original literal was split over two lines (a syntax
    # error), apparently because its markup was lost; it is restored here as
    # the error-alert container - confirm against the template.
    self.assertIn('<div class="alert alert-danger">', resp.data.decode("utf-8"))

    # test access to the dashboard when it is activated
    self.app.config["ACTIVATE_ADMIN_DASHBOARD"] = True
    self.app.config["ADMIN_PASSWORD"] = generate_password_hash("adminpass")
    resp = open_dashboard()
    # NOTE(review): this fragment looks like a table header whose tags were
    # replaced by pipes; it is left as found - confirm against the template.
    self.assertIn(
        "| Project | Number of members",
        resp.data.decode("utf-8"),
    )
def test_statistics_page(self):
    """The statistics page of a fresh project answers with a 200."""
    self.post_project("raclette")
    status = self.client.get("/raclette/statistics").status_code
    self.assertEqual(status, 200)
def test_statistics(self):
    """The statistics page shows one row of figures per member."""
    self.post_project("raclette")

    # add members
    self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2})
    self.client.post("/raclette/members/add", data={"name": "fred"})
    self.client.post("/raclette/members/add", data={"name": "tata"})
    # Add a member with a balance=0 :
    self.client.post("/raclette/members/add", data={"name": "toto"})

    # create bills
    bills = (
        ("fromage à raclette", 1, [1, 2, 3], "10.0"),
        ("red wine", 2, [1], "20"),
        ("delicatessen", 1, [1, 2], "10"),
    )
    for what, payer, payed_for, amount in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": "2011-08-10",
                "what": what,
                "payer": payer,
                "payed_for": payed_for,
                "amount": amount,
            },
        )

    page = self.client.get("/raclette/statistics").data.decode("utf-8")
    first_cell = ' | '
    indent = "\n "
    # NOTE(review): these fragments look like table cells whose markup was
    # replaced by pipes; they are kept exactly as found - confirm against
    # the statistics template.
    expected_rows = (
        ("alexis | ", "20.00 | ", "31.67 | \n"),
        ("fred", "20.00 | ", "5.83 | \n"),
        ("tata", "0.00 | ", "2.50 | \n"),
        ("toto", "0.00 | ", "0.00 | \n"),
    )
    for name_cell, second_cell, third_cell in expected_rows:
        self.assertIn(
            first_cell + name_cell + indent + second_cell + indent + third_cell,
            page,
        )
def test_settle_page(self):
    """The settle page of a fresh project answers with a 200."""
    self.post_project("raclette")
    status = self.client.get("/raclette/settle_bills").status_code
    self.assertEqual(status, 200)
def test_settle(self):
    """Settlement transactions add up to the project balances."""
    self.post_project("raclette")

    # add members
    for person in ("alexis", "fred", "tata"):
        self.client.post("/raclette/members/add", data={"name": person})
    # Add a member with a balance=0 :
    self.client.post("/raclette/members/add", data={"name": "toto"})

    # create bills
    bills = (
        ("fromage à raclette", 1, [1, 2, 3], "10.0"),
        ("red wine", 2, [1], "20"),
        ("delicatessen", 1, [1, 2], "10"),
    )
    for what, payer, payed_for, amount in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": "2011-08-10",
                "what": what,
                "payer": payer,
                "payed_for": payed_for,
                "amount": amount,
            },
        )

    project = models.Project.query.get("raclette")
    transactions = project.get_transactions_to_settle_bill()

    # Net amount per member implied by the transactions: what a member owes
    # is subtracted, what a member receives is added.
    members = defaultdict(int)
    for transaction in transactions:
        members[transaction["ower"]] -= transaction["amount"]
        members[transaction["receiver"]] += transaction["amount"]

    # We should have the same values between transactions and project balances
    balance = models.Project.query.get("raclette").balance
    for member, amount in members.items():
        # A unittest assertion instead of a bare ``assert``: it still runs
        # under ``python -O`` and reports both values on failure.
        self.assertLess(abs(amount - balance[member.id]), 0.01)
def test_settle_zero(self):
    """No settlement transaction rounds down to a zero amount."""
    self.post_project("raclette")

    # add members
    for person in ("alexis", "fred", "tata"):
        self.client.post("/raclette/members/add", data={"name": person})

    # create bills
    bills = (
        ("2016-12-31", "fromage à raclette", 1, [1, 2, 3], "10.0"),
        ("2016-12-31", "red wine", 2, [1, 3], "20"),
        ("2017-01-01", "refund", 3, [2], "13.33"),
    )
    for date, what, payer, payed_for, amount in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": date,
                "what": what,
                "payer": payer,
                "payed_for": payed_for,
                "amount": amount,
            },
        )

    project = models.Project.query.get("raclette")

    # There should not be any zero-amount transfer after rounding
    for transfer in project.get_transactions_to_settle_bill():
        raw_amount = transfer["amount"]
        self.assertNotEqual(
            0.0,
            round(raw_amount, 2),
            msg="%f is equal to zero after rounding" % raw_amount,
        )
def test_export(self):
    """Bills and settlement transactions can be exported as JSON and CSV;
    an unknown export format gives a 404.
    """
    self.post_project("raclette")

    def assert_csv_lines(response, expected_lines):
        # Compare each expected line with the received one as a set of
        # comma-separated fields, ignoring a trailing "\r".
        received_lines = response.data.decode("utf-8").split("\n")
        for index, line in enumerate(expected_lines):
            self.assertEqual(
                set(line.split(",")),
                set(received_lines[index].strip("\r").split(",")),
            )

    # add members
    self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2})
    for person in ("fred", "tata", "pépé"):
        self.client.post("/raclette/members/add", data={"name": person})

    # create bills
    bills = (
        ("2016-12-31", "fromage à raclette", 1, [1, 2, 3, 4], "10.0"),
        ("2016-12-31", "red wine", 2, [1, 3], "200"),
        ("2017-01-01", "refund", 3, [2], "13.33"),
    )
    for date, what, payer, payed_for, amount in bills:
        self.client.post(
            "/raclette/add",
            data={
                "date": date,
                "what": what,
                "payer": payer,
                "payed_for": payed_for,
                "amount": amount,
            },
        )

    # generate json export of bills
    resp = self.client.get("/raclette/export/bills.json")
    expected_bills = [
        {
            "date": "2017-01-01",
            "what": "refund",
            "amount": 13.33,
            "payer_name": "tata",
            "payer_weight": 1.0,
            "owers": ["fred"],
        },
        {
            "date": "2016-12-31",
            "what": "red wine",
            "amount": 200.0,
            "payer_name": "fred",
            "payer_weight": 1.0,
            "owers": ["alexis", "tata"],
        },
        {
            "date": "2016-12-31",
            "what": "fromage \xe0 raclette",
            "amount": 10.0,
            "payer_name": "alexis",
            "payer_weight": 2.0,
            "owers": ["alexis", "fred", "tata", "p\xe9p\xe9"],
        },
    ]
    self.assertEqual(json.loads(resp.data.decode("utf-8")), expected_bills)

    # generate csv export of bills
    resp = self.client.get("/raclette/export/bills.csv")
    assert_csv_lines(
        resp,
        [
            "date,what,amount,payer_name,payer_weight,owers",
            "2017-01-01,refund,13.33,tata,1.0,fred",
            '2016-12-31,red wine,200.0,fred,1.0,"alexis, tata"',
            '2016-12-31,fromage à raclette,10.0,alexis,2.0,"alexis, fred, tata, pépé"',
        ],
    )

    # generate json export of transactions
    resp = self.client.get("/raclette/export/transactions.json")
    expected_transactions = [
        {"amount": 2.00, "receiver": "fred", "ower": "p\xe9p\xe9"},
        {"amount": 55.34, "receiver": "fred", "ower": "tata"},
        {"amount": 127.33, "receiver": "fred", "ower": "alexis"},
    ]
    self.assertEqual(
        json.loads(resp.data.decode("utf-8")), expected_transactions
    )

    # generate csv export of transactions
    resp = self.client.get("/raclette/export/transactions.csv")
    assert_csv_lines(
        resp,
        [
            "amount,receiver,ower",
            "2.0,fred,pépé",
            "55.34,fred,tata",
            "127.33,fred,alexis",
        ],
    )

    # wrong export_format should return a 404
    resp = self.client.get("/raclette/export/transactions.wrong")
    self.assertEqual(resp.status_code, 404)
def test_import_new_project(self):
    """Importing JSON bills into an empty project creates all of them."""
    from ihatemoney.web import import_project

    # Import JSON in an empty project
    self.post_project("raclette")
    self.login("raclette")
    project = models.Project.query.get("raclette")

    json_to_import = [
        {
            "date": "2017-01-01",
            "what": "refund",
            "amount": 13.33,
            "payer_name": "tata",
            "payer_weight": 1.0,
            "owers": ["fred"],
        },
        {
            "date": "2016-12-31",
            "what": "red wine",
            "amount": 200.0,
            "payer_name": "fred",
            "payer_weight": 1.0,
            "owers": ["alexis", "tata"],
        },
        {
            "date": "2016-12-31",
            "what": "fromage a raclette",
            "amount": 10.0,
            "payer_name": "alexis",
            "payer_weight": 2.0,
            "owers": ["alexis", "fred", "tata", "pepe"],
        },
    ]

    # Serialise the bills into a file-like object rewound to its start
    stream = io.StringIO()
    json.dump(json_to_import, stream)
    stream.seek(0)
    import_project(stream, project)

    bills = project.get_pretty_bills()

    # Check that all the bills have been added
    self.assertEqual(len(bills), len(json_to_import))

    # Check that the names of the bills are right
    self.assertEqual(
        sorted(bill["what"] for bill in bills),
        sorted(entry["what"] for entry in json_to_import),
    )

    # Check the other fields, matching bills on their name
    for entry in json_to_import:
        for bill in bills:
            if bill["what"] != entry["what"]:
                continue
            self.assertEqual(bill["payer_name"], entry["payer_name"])
            self.assertEqual(bill["amount"], entry["amount"])
            self.assertEqual(bill["payer_weight"], entry["payer_weight"])
            self.assertEqual(bill["date"], entry["date"])
            self.assertEqual(sorted(bill["owers"]), sorted(entry["owers"]))
def test_import_partial_project(self):
    """Importing JSON into a project that already holds one of the bills
    does not duplicate that bill.
    """
    from ihatemoney.web import import_project

    # Import a JSON in a project with already existing data
    self.post_project("raclette")
    self.login("raclette")
    project = models.Project.query.get("raclette")

    self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2})
    for person in ("fred", "tata"):
        self.client.post("/raclette/members/add", data={"name": person})
    self.client.post(
        "/raclette/add",
        data={
            "date": "2016-12-31",
            "what": "red wine",
            "payer": 2,
            "payed_for": [1, 3],
            "amount": "200",
        },
    )

    json_to_import = [
        {
            "date": "2017-01-01",
            "what": "refund",
            "amount": 13.33,
            "payer_name": "tata",
            "payer_weight": 1.0,
            "owers": ["fred"],
        },
        {  # This expense does not have to be present twice.
            "date": "2016-12-31",
            "what": "red wine",
            "amount": 200.0,
            "payer_name": "fred",
            "payer_weight": 1.0,
            "owers": ["alexis", "tata"],
        },
        {
            "date": "2016-12-31",
            "what": "fromage a raclette",
            "amount": 10.0,
            "payer_name": "alexis",
            "payer_weight": 2.0,
            "owers": ["alexis", "fred", "tata", "pepe"],
        },
    ]

    # Serialise the bills into a file-like object rewound to its start
    stream = io.StringIO()
    json.dump(json_to_import, stream)
    stream.seek(0)
    import_project(stream, project)

    bills = project.get_pretty_bills()

    # Check that all the bills are there, the existing one only once
    self.assertEqual(len(bills), len(json_to_import))

    # Check that the names of the bills are right
    self.assertEqual(
        sorted(bill["what"] for bill in bills),
        sorted(entry["what"] for entry in json_to_import),
    )

    # Check the other fields, matching bills on their name
    for entry in json_to_import:
        for bill in bills:
            if bill["what"] != entry["what"]:
                continue
            self.assertEqual(bill["payer_name"], entry["payer_name"])
            self.assertEqual(bill["amount"], entry["amount"])
            self.assertEqual(bill["payer_weight"], entry["payer_weight"])
            self.assertEqual(bill["date"], entry["date"])
            self.assertEqual(sorted(bill["owers"]), sorted(entry["owers"]))
def test_import_wrong_json(self):
    """Importing JSON with unknown keys or a missing amount raises
    ValueError.
    """
    from ihatemoney.web import import_project

    self.post_project("raclette")
    self.login("raclette")
    project = models.Project.query.get("raclette")

    json_1 = [
        {  # wrong keys
            "checked": False,
            "dimensions": {"width": 5, "height": 10},
            "id": 1,
            "name": "A green door",
            "price": 12.5,
            "tags": ["home", "green"],
        }
    ]
    json_2 = [
        {  # amount missing
            "date": "2017-01-01",
            "what": "refund",
            "payer_name": "tata",
            "payer_weight": 1.0,
            "owers": ["fred"],
        }
    ]

    for payload in (json_1, json_2):
        stream = io.StringIO(json.dumps(payload))
        # assertRaises fails the test when nothing is raised and lets any
        # other exception surface with its own traceback.
        with self.assertRaises(ValueError):
            import_project(stream, project)
class APITestCase(IhatemoneyTestCase):
"""Tests the API"""
def api_create(self, name, id=None, password=None, contact=None):
    """Create a project through ``POST /api/projects`` and return the
    response.

    ``id`` and ``password`` fall back to ``name``; ``contact`` falls back
    to ``<name>@notmyidea.org``.
    """
    payload = {
        "name": name,
        "id": id or name,
        "password": password or name,
        "contact_email": contact or "%s@notmyidea.org" % name,
    }
    return self.client.post("/api/projects", data=payload)
def api_add_member(self, project, name, weight=1):
    """Add a member to ``project`` through the API, authenticating with
    the project's default credentials. The response is not returned.
    """
    members_url = "/api/projects/%s/members" % project
    form = {"name": name, "weight": weight}
    self.client.post(members_url, data=form, headers=self.get_auth(project))
def get_auth(self, username, password=None):
    """Return the HTTP Basic ``Authorization`` header for the credentials.

    ``password`` falls back to ``username`` when omitted or empty.
    """
    credentials = "%s:%s" % (username, password or username)
    # b64encode yields a single line, so no newline has to be stripped.
    token = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
    return {"Authorization": "Basic %s" % token}
def test_cors_requests(self):
    """An OPTIONS request on a project answers with a wildcard CORS header."""
    # Create a project and test that CORS headers are present if requested.
    self.assertStatus(201, self.api_create("raclette"))

    # Try to do an OPTIONS requests and see if the headers are correct.
    credentials = self.get_auth("raclette")
    resp = self.client.options("/api/projects/raclette", headers=credentials)
    self.assertEqual(resp.headers["Access-Control-Allow-Origin"], "*")
def test_basic_auth(self):
    """Every API resource answers 401 to unauthenticated requests."""
    # create a project
    self.assertStatus(201, self.api_create("raclette"))

    # try to do something on it being unauth should return a 401
    self.assertStatus(401, self.client.get("/api/projects/raclette"))

    # PUT / POST / DELETE / GET on the different resources
    # should also return a 401
    cases = (
        (("post",), ("/raclette/members", "/raclette/bills")),
        (
            ("get", "delete", "put"),
            ("/raclette", "/raclette/members/1", "/raclette/bills/1"),
        ),
    )
    for verbs, resources in cases:
        for verb in verbs:
            send = getattr(self.client, verb)
            for resource in resources:
                resp = send("/api/projects" + resource)
                self.assertStatus(401, resp, verb + resource)
def test_project(self):
    """Exercise the project API: create, read, update, re-password, delete.

    Status codes are compared with ``assertEqual``: ``assertTrue(400, x)``
    treats ``x`` as the failure message and can never fail.
    """
    # wrong email should return an error
    resp = self.client.post(
        "/api/projects",
        data={
            "name": "raclette",
            "id": "raclette",
            "password": "raclette",
            "contact_email": "not-an-email",
        },
    )
    self.assertEqual(400, resp.status_code)
    self.assertEqual(
        '{"contact_email": ["Invalid email address."]}\n', resp.data.decode("utf-8")
    )

    # create it
    resp = self.api_create("raclette")
    self.assertEqual(201, resp.status_code)

    # create it twice should return a 400
    resp = self.api_create("raclette")
    self.assertEqual(400, resp.status_code)
    self.assertIn("id", json.loads(resp.data.decode("utf-8")))

    # get information about it
    resp = self.client.get(
        "/api/projects/raclette", headers=self.get_auth("raclette")
    )
    self.assertEqual(200, resp.status_code)
    expected = {
        "members": [],
        "name": "raclette",
        "contact_email": "raclette@notmyidea.org",
        "id": "raclette",
        "logging_preference": 1,
    }
    decoded_resp = json.loads(resp.data.decode("utf-8"))
    self.assertDictEqual(decoded_resp, expected)

    # edit should work
    resp = self.client.put(
        "/api/projects/raclette",
        data={
            "contact_email": "yeah@notmyidea.org",
            "password": "raclette",
            "name": "The raclette party",
            "project_history": "y",
        },
        headers=self.get_auth("raclette"),
    )
    self.assertEqual(200, resp.status_code)

    resp = self.client.get(
        "/api/projects/raclette", headers=self.get_auth("raclette")
    )
    self.assertEqual(200, resp.status_code)
    expected = {
        "name": "The raclette party",
        "contact_email": "yeah@notmyidea.org",
        "members": [],
        "id": "raclette",
        "logging_preference": 1,
    }
    decoded_resp = json.loads(resp.data.decode("utf-8"))
    self.assertDictEqual(decoded_resp, expected)

    # password change is possible via API
    resp = self.client.put(
        "/api/projects/raclette",
        data={
            "contact_email": "yeah@notmyidea.org",
            "password": "tartiflette",
            "name": "The raclette party",
        },
        headers=self.get_auth("raclette"),
    )
    self.assertEqual(200, resp.status_code)

    # the new password must now be accepted
    resp = self.client.get(
        "/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette")
    )
    self.assertEqual(200, resp.status_code)

    # delete should work
    resp = self.client.delete(
        "/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette")
    )
    self.assertEqual(200, resp.status_code)

    # get should return a 401 on an unknown resource
    resp = self.client.get(
        "/api/projects/raclette", headers=self.get_auth("raclette")
    )
    self.assertEqual(401, resp.status_code)
def test_token_creation(self):
    """Test that a project token is generated and accepted by the API."""
    # Create project
    resp = self.api_create("raclette")
    # assertTrue(201, ...) would always pass: its second argument is only
    # the failure message, so compare the status code explicitly.
    self.assertEqual(201, resp.status_code)

    # Get token
    resp = self.client.get(
        "/api/projects/raclette/token", headers=self.get_auth("raclette")
    )
    self.assertEqual(200, resp.status_code)
    decoded_resp = json.loads(resp.data.decode("utf-8"))

    # Access with token
    resp = self.client.get(
        "/api/projects/raclette/token",
        headers={"Authorization": "Basic %s" % decoded_resp["token"]},
    )
    self.assertEqual(200, resp.status_code)
def test_token_login(self):
    """A token fetched from the API can be passed to ``/authenticate``."""
    self.api_create("raclette")

    # Fetch the project's token through the API.
    token_resp = self.client.get(
        "/api/projects/raclette/token", headers=self.get_auth("raclette")
    )
    token = json.loads(token_resp.data.decode("utf-8"))["token"]

    # Authenticating with the token answers with a redirect.
    login_resp = self.client.get("/authenticate?token={}".format(token))
    self.assertEqual(302, login_resp.status_code)
def test_member(self):
    """Walk a member through its API lifecycle.

    Covers listing, creation, duplicate-name rejection, renaming and
    re-weighting, deactivation, reactivation and deletion.
    """
    members_url = "/api/projects/raclette/members"
    member_url = "/api/projects/raclette/members/1"

    def text(response):
        # Response payload decoded as UTF-8 text.
        return response.data.decode("utf-8")

    # a project is needed to attach members to
    self.api_create("raclette")
    auth = self.get_auth("raclette")

    # the member list starts out empty
    resp = self.client.get(members_url, headers=auth)
    self.assertStatus(200, resp)
    self.assertEqual("[]\n", text(resp))

    # adding a member returns the id of the new member
    resp = self.client.post(members_url, data={"name": "Alexis"}, headers=auth)
    self.assertStatus(201, resp)
    self.assertEqual("1\n", text(resp))

    # the list now holds exactly one member
    resp = self.client.get(members_url, headers=auth)
    self.assertStatus(200, resp)
    self.assertEqual(len(json.loads(text(resp))), 1)

    # a second member with the same name is rejected
    resp = self.client.post(members_url, data={"name": "Alexis"}, headers=auth)
    self.assertStatus(400, resp)

    # rename the member and change its weight
    resp = self.client.put(
        member_url, data={"name": "Fred", "weight": 2}, headers=auth
    )
    self.assertStatus(200, resp)

    # reading it back shows the new name and weight
    resp = self.client.get(member_url, headers=auth)
    self.assertStatus(200, resp)
    member = json.loads(text(resp))
    self.assertEqual("Fred", member["name"])
    self.assertEqual(2, member["weight"])

    # sending the same name again must also succeed (PUT idempotence)
    resp = self.client.put(member_url, data={"name": "Fred"}, headers=auth)
    self.assertStatus(200, resp)

    # de-activate the member
    resp = self.client.put(
        member_url, data={"name": "Fred", "activated": False}, headers=auth
    )
    self.assertStatus(200, resp)

    resp = self.client.get(member_url, headers=auth)
    self.assertStatus(200, resp)
    self.assertEqual(False, json.loads(text(resp))["activated"])

    # re-activate the member
    self.client.put(
        member_url, data={"name": "Fred", "activated": True}, headers=auth
    )

    resp = self.client.get(member_url, headers=auth)
    self.assertStatus(200, resp)
    self.assertEqual(True, json.loads(text(resp))["activated"])

    # delete the member
    resp = self.client.delete(member_url, headers=auth)
    self.assertStatus(200, resp)

    # the member list is empty again
    resp = self.client.get(members_url, headers=auth)
    self.assertStatus(200, resp)
    self.assertEqual("[]\n", text(resp))
def test_bills(self):
    """Exercise the bill API: list, create, read, failed edit, edit, delete.

    The creation date is captured from the first read, before the bill is
    edited, so the check after the edit verifies that editing does not alter
    it (rather than comparing a response with itself).
    """
    # create a project
    self.api_create("raclette")

    # add members
    self.api_add_member("raclette", "alexis")
    self.api_add_member("raclette", "fred")
    self.api_add_member("raclette", "arnaud")

    # get the list of bills (should be empty)
    req = self.client.get(
        "/api/projects/raclette/bills", headers=self.get_auth("raclette")
    )
    self.assertStatus(200, req)
    self.assertEqual("[]\n", req.data.decode("utf-8"))

    # add a bill
    req = self.client.post(
        "/api/projects/raclette/bills",
        data={
            "date": "2011-08-10",
            "what": "fromage",
            "payer": "1",
            "payed_for": ["1", "2"],
            "amount": "25",
            "external_link": "https://raclette.fr",
        },
        headers=self.get_auth("raclette"),
    )

    # should return the id
    self.assertStatus(201, req)
    self.assertEqual(req.data.decode("utf-8"), "1\n")

    # get this bill details
    req = self.client.get(
        "/api/projects/raclette/bills/1", headers=self.get_auth("raclette")
    )

    # compare with the added info
    self.assertStatus(200, req)
    expected = {
        "what": "fromage",
        "payer_id": 1,
        "owers": [
            {"activated": True, "id": 1, "name": "alexis", "weight": 1},
            {"activated": True, "id": 2, "name": "fred", "weight": 1},
        ],
        "amount": 25.0,
        "date": "2011-08-10",
        "id": 1,
        "external_link": "https://raclette.fr",
    }
    got = json.loads(req.data.decode("utf-8"))
    # Remember the creation date as reported BEFORE any edit.
    creation_date = datetime.datetime.strptime(
        got["creation_date"], "%Y-%m-%d"
    ).date()
    self.assertEqual(datetime.date.today(), creation_date)
    del got["creation_date"]
    self.assertDictEqual(expected, got)

    # the list of bills should length 1
    req = self.client.get(
        "/api/projects/raclette/bills", headers=self.get_auth("raclette")
    )
    self.assertStatus(200, req)
    self.assertEqual(1, len(json.loads(req.data.decode("utf-8"))))

    # edit with errors should return an error
    req = self.client.put(
        "/api/projects/raclette/bills/1",
        data={
            "date": "201111111-08-10",  # not a date
            "what": "fromage",
            "payer": "1",
            "payed_for": ["1", "2"],
            "amount": "25",
            "external_link": "https://raclette.fr",
        },
        headers=self.get_auth("raclette"),
    )
    self.assertStatus(400, req)
    self.assertEqual(
        '{"date": ["This field is required."]}\n', req.data.decode("utf-8")
    )

    # edit a bill
    req = self.client.put(
        "/api/projects/raclette/bills/1",
        data={
            "date": "2011-09-10",
            "what": "beer",
            "payer": "2",
            "payed_for": ["1", "2"],
            "amount": "25",
            "external_link": "https://raclette.fr",
        },
        headers=self.get_auth("raclette"),
    )
    self.assertStatus(200, req)

    # check its fields
    req = self.client.get(
        "/api/projects/raclette/bills/1", headers=self.get_auth("raclette")
    )
    self.assertStatus(200, req)
    expected = {
        "what": "beer",
        "payer_id": 2,
        "owers": [
            {"activated": True, "id": 1, "name": "alexis", "weight": 1},
            {"activated": True, "id": 2, "name": "fred", "weight": 1},
        ],
        "amount": 25.0,
        "date": "2011-09-10",
        "external_link": "https://raclette.fr",
        "id": 1,
    }
    got = json.loads(req.data.decode("utf-8"))
    # editing a bill must leave its creation date untouched
    self.assertEqual(
        creation_date,
        datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(),
    )
    del got["creation_date"]
    self.assertDictEqual(expected, got)

    # delete a bill
    req = self.client.delete(
        "/api/projects/raclette/bills/1", headers=self.get_auth("raclette")
    )
    self.assertStatus(200, req)

    # getting it should return a 404
    req = self.client.get(
        "/api/projects/raclette/bills/1", headers=self.get_auth("raclette")
    )
    self.assertStatus(404, req)
def test_bills_with_calculation(self):
    """Bill amounts may be arithmetic expressions; invalid ones give a 400."""
    bills_url = "/api/projects/raclette/bills"

    # project with two members
    self.api_create("raclette")
    self.api_add_member("raclette", "alexis")
    self.api_add_member("raclette", "fred")

    # (expression sent as the amount, value expected back from the API)
    input_expected = [
        ("((100 + 200.25) * 2 - 100) / 2", 250.25),
        ("3/2", 1.5),
        ("2 + 1 * 5 - 2 / 1", 5),
    ]

    # Bill ids are compared against the 1-based position in the list.
    for bill_id, (input_amount, expected_amount) in enumerate(
        input_expected, start=1
    ):
        resp = self.client.post(
            bills_url,
            data={
                "date": "2011-08-10",
                "what": "fromage",
                "payer": "1",
                "payed_for": ["1", "2"],
                "amount": input_amount,
            },
            headers=self.get_auth("raclette"),
        )

        # the id of the new bill is returned
        self.assertStatus(201, resp)
        self.assertEqual(resp.data.decode("utf-8"), "{}\n".format(bill_id))

        # read the bill back
        resp = self.client.get(
            "/api/projects/raclette/bills/{}".format(bill_id),
            headers=self.get_auth("raclette"),
        )
        self.assertStatus(200, resp)

        # it must match what was sent, with the amount evaluated
        expected = {
            "what": "fromage",
            "payer_id": 1,
            "owers": [
                {"activated": True, "id": 1, "name": "alexis", "weight": 1},
                {"activated": True, "id": 2, "name": "fred", "weight": 1},
            ],
            "amount": expected_amount,
            "date": "2011-08-10",
            "id": bill_id,
            "external_link": "",
        }
        got = json.loads(resp.data.decode("utf-8"))
        created_on = datetime.datetime.strptime(
            got["creation_date"], "%Y-%m-%d"
        ).date()
        self.assertEqual(datetime.date.today(), created_on)
        del got["creation_date"]
        self.assertDictEqual(expected, got)

    # each of these amounts must be rejected
    erroneous_amounts = [
        "lambda ",  # letters
        "(20 + 2",  # invalid expression
        "20/0",  # invalid calc
        "9999**99999999999999999",  # exponents
        "2" * 201,  # greater than 200 chars,
    ]

    for amount in erroneous_amounts:
        resp = self.client.post(
            bills_url,
            data={
                "date": "2011-08-10",
                "what": "fromage",
                "payer": "1",
                "payed_for": ["1", "2"],
                "amount": amount,
            },
            headers=self.get_auth("raclette"),
        )
        self.assertStatus(400, resp)
def test_statistics(self):
    """The statistics endpoint reports balance, paid and spent per member."""
    # project with two members
    self.api_create("raclette")
    self.api_add_member("raclette", "alexis")
    self.api_add_member("raclette", "fred")

    # one bill of 25, paid by member 1 for members 1 and 2
    self.client.post(
        "/api/projects/raclette/bills",
        data={
            "date": "2011-08-10",
            "what": "fromage",
            "payer": "1",
            "payed_for": ["1", "2"],
            "amount": "25",
        },
        headers=self.get_auth("raclette"),
    )

    # fetch the project statistics
    stats_resp = self.client.get(
        "/api/projects/raclette/statistics", headers=self.get_auth("raclette")
    )
    self.assertStatus(200, stats_resp)

    alexis_stats = {
        "balance": 12.5,
        "member": {
            "activated": True,
            "id": 1,
            "name": "alexis",
            "weight": 1.0,
        },
        "paid": 25.0,
        "spent": 12.5,
    }
    fred_stats = {
        "balance": -12.5,
        "member": {
            "activated": True,
            "id": 2,
            "name": "fred",
            "weight": 1.0,
        },
        "paid": 0,
        "spent": 12.5,
    }
    self.assertEqual(
        [alexis_stats, fred_stats],
        json.loads(stats_resp.data.decode("utf-8")),
    )
def test_username_xss(self):
# create a project
# self.api_create("raclette")
self.post_project("raclette")
self.login("raclette")
# add members
self.api_add_member("raclette", "
|---|