# coding: utf8
import unittest
from unittest.mock import patch
import datetime
import os
import io
import json
import base64
from collections import defaultdict
from time import sleep

from werkzeug.security import generate_password_hash, check_password_hash
from flask import session
from flask_testing import TestCase

from ihatemoney.run import create_app, db, load_configuration
from ihatemoney.manage import GenerateConfig, GeneratePasswordHash, DeleteProject
from ihatemoney import models
from ihatemoney import utils
from sqlalchemy import orm

# Unset configuration file env var if previously set
os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)

__HERE__ = os.path.dirname(os.path.abspath(__file__))


class BaseTestCase(TestCase):
    """Common fixture: builds the app, creates and drops the schema per test.

    Also provides small helpers to log in and to create projects, either
    through the web form or directly in the database.
    """

    SECRET_KEY = "TEST SESSION"

    def create_app(self):
        """Return the application under test."""
        # Pass the test object as a configuration. The module-level
        # ``create_app`` imported from ``ihatemoney.run`` is the one called
        # here, not this method.
        return create_app(self)

    def setUp(self):
        """Create all database tables before each test."""
        db.create_all()

    def tearDown(self):
        """Release the session and drop all tables after each test."""
        # clean after testing
        db.session.remove()
        db.drop_all()

    def login(self, project, password=None, test_client=None):
        """Post credentials to ``/authenticate`` and follow redirects.

        :param project: project identifier to authenticate against.
        :param password: project password; falls back to ``project`` when
            omitted or falsy.
        :param test_client: optional client to send the request with, so
            that the session of that client is the one being logged in.
            Defaults to ``self.client``.
        :return: the response of the POST request.
        """
        password = password or project
        # Previously ``test_client`` was accepted but silently ignored and
        # ``self.client`` was always used.
        client = self.client if test_client is None else test_client

        return client.post(
            "/authenticate",
            data=dict(id=project, password=password),
            follow_redirects=True,
        )

    def post_project(self, name):
        """Create a fake project by submitting the ``/create`` form.

        ``name`` is reused as the project name, id and password, and as the
        local part of the contact e-mail address.
        """
        # create the project
        self.client.post(
            "/create",
            data={
                "name": name,
                "id": name,
                "password": name,
                "contact_email": "%s@notmyidea.org" % name,
            },
        )

    def create_project(self, name):
        """Insert a project straight into the database, bypassing the views.

        ``name`` is reused as id, name and (hashed) password.
        """
        project = models.Project(
            id=name,
            name=str(name),
            password=generate_password_hash(name),
            contact_email="%s@notmyidea.org" % name,
        )
        models.db.session.add(project)
        models.db.session.commit()


class IhatemoneyTestCase(BaseTestCase):

    SQLALCHEMY_DATABASE_URI = "sqlite://"
    TESTING = True
    WTF_CSRF_ENABLED = False  # Simplifies the tests.
def assertStatus(self, expected, resp, url=""): return self.assertEqual( expected, resp.status_code, "%s expected %s, got %s" % (url, expected, resp.status_code), ) class ConfigurationTestCase(BaseTestCase): def test_default_configuration(self): """Test that default settings are loaded when no other configuration file is specified""" self.assertFalse(self.app.config["DEBUG"]) self.assertEqual( self.app.config["SQLALCHEMY_DATABASE_URI"], "sqlite:////tmp/ihatemoney.db" ) self.assertFalse(self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]) self.assertEqual( self.app.config["MAIL_DEFAULT_SENDER"], ("Budget manager", "budget@notmyidea.org"), ) def test_env_var_configuration_file(self): """Test that settings are loaded from the specified configuration file""" os.environ["IHATEMONEY_SETTINGS_FILE_PATH"] = os.path.join( __HERE__, "ihatemoney_envvar.cfg" ) load_configuration(self.app) self.assertEqual(self.app.config["SECRET_KEY"], "lalatra") # Test that the specified configuration file is loaded # even if the default configuration file ihatemoney.cfg exists os.environ["IHATEMONEY_SETTINGS_FILE_PATH"] = os.path.join( __HERE__, "ihatemoney_envvar.cfg" ) self.app.config.root_path = __HERE__ load_configuration(self.app) self.assertEqual(self.app.config["SECRET_KEY"], "lalatra") os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None) def test_default_configuration_file(self): """Test that settings are loaded from the default configuration file""" self.app.config.root_path = __HERE__ load_configuration(self.app) self.assertEqual(self.app.config["SECRET_KEY"], "supersecret") class BudgetTestCase(IhatemoneyTestCase): def test_notifications(self): """Test that the notifications are sent, and that email addresses are checked properly. 
""" # sending a message to one person with self.app.mail.record_messages() as outbox: # create a project self.login("raclette") self.post_project("raclette") self.client.post( "/raclette/invite", data={"emails": "alexis@notmyidea.org"} ) self.assertEqual(len(outbox), 2) self.assertEqual(outbox[0].recipients, ["raclette@notmyidea.org"]) self.assertEqual(outbox[1].recipients, ["alexis@notmyidea.org"]) # sending a message to multiple persons with self.app.mail.record_messages() as outbox: self.client.post( "/raclette/invite", data={"emails": "alexis@notmyidea.org, toto@notmyidea.org"}, ) # only one message is sent to multiple persons self.assertEqual(len(outbox), 1) self.assertEqual( outbox[0].recipients, ["alexis@notmyidea.org", "toto@notmyidea.org"] ) # mail address checking with self.app.mail.record_messages() as outbox: response = self.client.post("/raclette/invite", data={"emails": "toto"}) self.assertEqual(len(outbox), 0) # no message sent self.assertIn("The email toto is not valid", response.data.decode("utf-8")) # mixing good and wrong addresses shouldn't send any messages with self.app.mail.record_messages() as outbox: self.client.post( "/raclette/invite", data={"emails": "alexis@notmyidea.org, alexis"} ) # not valid # only one message is sent to multiple persons self.assertEqual(len(outbox), 0) def test_invite(self): """Test that invitation e-mails are sent properly """ self.login("raclette") self.post_project("raclette") with self.app.mail.record_messages() as outbox: self.client.post("/raclette/invite", data={"emails": "toto@notmyidea.org"}) self.assertEqual(len(outbox), 1) url_start = outbox[0].body.find("You can log in using this link: ") + 32 url_end = outbox[0].body.find(".\n", url_start) url = outbox[0].body[url_start:url_end] self.client.get("/exit") # Test that we got a valid token resp = self.client.get(url, follow_redirects=True) self.assertIn( 'You probably want to ", resp.data.decode("utf-8")) # Test that password can be changed 
self.client.post( url, data={"password": "pass", "password_confirmation": "pass"} ) resp = self.login("raclette", password="pass") self.assertIn( "Account manager - raclette", resp.data.decode("utf-8") ) # Test empty and null tokens resp = self.client.get("/reset-password") self.assertIn("No token provided", resp.data.decode("utf-8")) resp = self.client.get("/reset-password?token=token") self.assertIn("Invalid token", resp.data.decode("utf-8")) def test_project_creation(self): with self.app.test_client() as c: # add a valid project c.post( "/create", data={ "name": "The fabulous raclette party", "id": "raclette", "password": "party", "contact_email": "raclette@notmyidea.org", }, ) # session is updated self.assertTrue(session["raclette"]) # project is created self.assertEqual(len(models.Project.query.all()), 1) # Add a second project with the same id models.Project.query.get("raclette") c.post( "/create", data={ "name": "Another raclette party", "id": "raclette", # already used ! "password": "party", "contact_email": "raclette@notmyidea.org", }, ) # no new project added self.assertEqual(len(models.Project.query.all()), 1) def test_project_creation_without_public_permissions(self): self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = False with self.app.test_client() as c: # add a valid project c.post( "/create", data={ "name": "The fabulous raclette party", "id": "raclette", "password": "party", "contact_email": "raclette@notmyidea.org", }, ) # session is not updated self.assertNotIn("raclette", session) # project is not created self.assertEqual(len(models.Project.query.all()), 0) def test_project_creation_with_public_permissions(self): self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = True with self.app.test_client() as c: # add a valid project c.post( "/create", data={ "name": "The fabulous raclette party", "id": "raclette", "password": "party", "contact_email": "raclette@notmyidea.org", }, ) # session is updated self.assertTrue(session["raclette"]) # project is 
created self.assertEqual(len(models.Project.query.all()), 1) def test_project_deletion(self): with self.app.test_client() as c: c.post( "/create", data={ "name": "raclette party", "id": "raclette", "password": "party", "contact_email": "raclette@notmyidea.org", }, ) # project added self.assertEqual(len(models.Project.query.all()), 1) c.get("/raclette/delete") # project removed self.assertEqual(len(models.Project.query.all()), 0) def test_bill_placeholder(self): self.post_project("raclette") self.login("raclette") result = self.client.get("/raclette/") # Empty bill list and no members, should now propose to add members first self.assertIn( 'You probably want to ', resp.data.decode("utf-8")) def test_authentication(self): # try to authenticate without credentials should redirect # to the authentication page resp = self.client.post("/authenticate") self.assertIn("Authentication", resp.data.decode("utf-8")) # raclette that the login / logout process works self.create_project("raclette") # try to see the project while not being authenticated should redirect # to the authentication page resp = self.client.get("/raclette", follow_redirects=True) self.assertIn("Authentication", resp.data.decode("utf-8")) # try to connect with wrong credentials should not work with self.app.test_client() as c: resp = c.post("/authenticate", data={"id": "raclette", "password": "nope"}) self.assertIn("Authentication", resp.data.decode("utf-8")) self.assertNotIn("raclette", session) # try to connect with the right credentials should work with self.app.test_client() as c: resp = c.post( "/authenticate", data={"id": "raclette", "password": "raclette"} ) self.assertNotIn("Authentication", resp.data.decode("utf-8")) self.assertIn("raclette", session) self.assertTrue(session["raclette"]) # logout should wipe the session out c.get("/exit") self.assertNotIn("raclette", session) # test that with admin credentials, one can access every project self.app.config["ADMIN_PASSWORD"] = 
generate_password_hash("pass") with self.app.test_client() as c: resp = c.post("/admin?goto=%2Fraclette", data={"admin_password": "pass"}) self.assertNotIn("Authentication", resp.data.decode("utf-8")) self.assertTrue(session["is_admin"]) def test_admin_authentication(self): self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass") # Disable public project creation so we have an admin endpoint to test self.app.config["ALLOW_PUBLIC_PROJECT_CREATION"] = False # test the redirection to the authentication page when trying to access admin endpoints resp = self.client.get("/create") self.assertIn('', resp.data.decode("utf-8")) # test right password resp = self.client.post( "/admin?goto=%2Fcreate", data={"admin_password": "pass"} ) self.assertIn('/create', resp.data.decode("utf-8")) # test wrong password resp = self.client.post( "/admin?goto=%2Fcreate", data={"admin_password": "wrong"} ) self.assertNotIn('/create', resp.data.decode("utf-8")) # test empty password resp = self.client.post("/admin?goto=%2Fcreate", data={"admin_password": ""}) self.assertNotIn('/create', resp.data.decode("utf-8")) def test_login_throttler(self): self.app.config["ADMIN_PASSWORD"] = generate_password_hash("pass") # Activate admin login throttling by authenticating 4 times with a wrong passsword self.client.post("/admin?goto=%2Fcreate", data={"admin_password": "wrong"}) self.client.post("/admin?goto=%2Fcreate", data={"admin_password": "wrong"}) self.client.post("/admin?goto=%2Fcreate", data={"admin_password": "wrong"}) resp = self.client.post( "/admin?goto=%2Fcreate", data={"admin_password": "wrong"} ) self.assertIn( "Too many failed login attempts, please retry later.", resp.data.decode("utf-8"), ) # Change throttling delay import gc for obj in gc.get_objects(): if isinstance(obj, utils.LoginThrottler): obj._delay = 0.005 break # Wait for delay to expire and retry logging in sleep(1) resp = self.client.post( "/admin?goto=%2Fcreate", data={"admin_password": "wrong"} ) self.assertNotIn( 
"Too many failed login attempts, please retry later.", resp.data.decode("utf-8"), ) def test_manage_bills(self): self.post_project("raclette") # add two persons self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "fred"}) members_ids = [m.id for m in models.Project.query.get("raclette").members] # create a bill self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "25", }, ) models.Project.query.get("raclette") bill = models.Bill.query.one() self.assertEqual(bill.amount, 25) # edit the bill self.client.post( "/raclette/edit/%s" % bill.id, data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "10", }, ) bill = models.Bill.query.one() self.assertEqual(bill.amount, 10, "bill edition") # delete the bill self.client.get("/raclette/delete/%s" % bill.id) self.assertEqual(0, len(models.Bill.query.all()), "bill deletion") # test balance self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "19", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[1], "payed_for": members_ids[0], "amount": "20", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[1], "payed_for": members_ids, "amount": "17", }, ) balance = models.Project.query.get("raclette").balance self.assertEqual(set(balance.values()), set([19.0, -19.0])) # Bill with negative amount self.client.post( "/raclette/add", data={ "date": "2011-08-12", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "-25", }, ) bill = models.Bill.query.filter(models.Bill.date == "2011-08-12")[0] self.assertEqual(bill.amount, -25) # add 
a bill with a comma self.client.post( "/raclette/add", data={ "date": "2011-08-01", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "25,02", }, ) bill = models.Bill.query.filter(models.Bill.date == "2011-08-01")[0] self.assertEqual(bill.amount, 25.02) def test_weighted_balance(self): self.post_project("raclette") # add two persons self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post( "/raclette/members/add", data={"name": "freddy familly", "weight": 4} ) members_ids = [m.id for m in models.Project.query.get("raclette").members] # test balance self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": members_ids[0], "payed_for": members_ids, "amount": "10", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "pommes de terre", "payer": members_ids[1], "payed_for": members_ids, "amount": "10", }, ) balance = models.Project.query.get("raclette").balance self.assertEqual(set(balance.values()), set([6, -6])) def test_trimmed_members(self): self.post_project("raclette") # Add two times the same person (with a space at the end). 
self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "alexis "}) members = models.Project.query.get("raclette").members self.assertEqual(len(members), 1) def test_weighted_members_list(self): self.post_project("raclette") # add two persons self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "tata", "weight": 1}) resp = self.client.get("/raclette/") self.assertIn("extra-info", resp.data.decode("utf-8")) self.client.post( "/raclette/members/add", data={"name": "freddy familly", "weight": 4} ) resp = self.client.get("/raclette/") self.assertNotIn("extra-info", resp.data.decode("utf-8")) def test_negative_weight(self): self.post_project("raclette") # Add one user and edit it to have a negative share self.client.post("/raclette/members/add", data={"name": "alexis"}) resp = self.client.post( "/raclette/members/1/edit", data={"name": "alexis", "weight": -1} ) # An error should be generated, and its weight should still be 1. self.assertIn('

', resp.data.decode("utf-8")) self.assertEqual(len(models.Project.query.get("raclette").members), 1) self.assertEqual(models.Project.query.get("raclette").members[0].weight, 1) def test_rounding(self): self.post_project("raclette") # add members self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) # create bills self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": 1, "payed_for": [1, 2, 3], "amount": "24.36", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "red wine", "payer": 2, "payed_for": [1], "amount": "19.12", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "delicatessen", "payer": 1, "payed_for": [1, 2], "amount": "22", }, ) balance = models.Project.query.get("raclette").balance result = {} result[models.Project.query.get("raclette").members[0].id] = 8.12 result[models.Project.query.get("raclette").members[1].id] = 0.0 result[models.Project.query.get("raclette").members[2].id] = -8.12 # Since we're using floating point to store currency, we can have some # rounding issues that prevent test from working. # However, we should obtain the same values as the theoretical ones if we # round to 2 decimals, like in the UI. 
for key, value in balance.items(): self.assertEqual(round(value, 2), result[key]) def test_edit_project(self): # A project should be editable self.post_project("raclette") new_data = { "name": "Super raclette party!", "contact_email": "alexis@notmyidea.org", "password": "didoudida", } resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) self.assertEqual(resp.status_code, 200) project = models.Project.query.get("raclette") self.assertEqual(project.name, new_data["name"]) self.assertEqual(project.contact_email, new_data["contact_email"]) self.assertTrue(check_password_hash(project.password, new_data["password"])) # Editing a project with a wrong email address should fail new_data["contact_email"] = "wrong_email" resp = self.client.post("/raclette/edit", data=new_data, follow_redirects=True) self.assertIn("Invalid email address", resp.data.decode("utf-8")) def test_dashboard(self): # test that the dashboard is deactivated by default resp = self.client.post( "/admin?goto=%2Fdashboard", data={"admin_password": "adminpass"}, follow_redirects=True, ) self.assertIn('

', resp.data.decode("utf-8")) # test access to the dashboard when it is activated self.app.config["ACTIVATE_ADMIN_DASHBOARD"] = True self.app.config["ADMIN_PASSWORD"] = generate_password_hash("adminpass") resp = self.client.post( "/admin?goto=%2Fdashboard", data={"admin_password": "adminpass"}, follow_redirects=True, ) self.assertIn( "ProjectNumber of members", resp.data.decode("utf-8"), ) def test_statistics_page(self): self.post_project("raclette") response = self.client.get("/raclette/statistics") self.assertEqual(response.status_code, 200) def test_statistics(self): self.post_project("raclette") # add members self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) # Add a member with a balance=0 : self.client.post("/raclette/members/add", data={"name": "toto"}) # create bills self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": 1, "payed_for": [1, 2, 3], "amount": "10.0", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "red wine", "payer": 2, "payed_for": [1], "amount": "20", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "delicatessen", "payer": 1, "payed_for": [1, 2], "amount": "10", }, ) response = self.client.get("/raclette/statistics") first_cell = '' indent = "\n " self.assertIn( first_cell + "alexis" + indent + "20.00" + indent + "31.67\n", response.data.decode("utf-8"), ) self.assertIn( first_cell + "fred" + indent + "20.00" + indent + "5.83\n", response.data.decode("utf-8"), ) self.assertIn( first_cell + "tata" + indent + "0.00" + indent + "2.50\n", response.data.decode("utf-8"), ) self.assertIn( first_cell + "toto" + indent + "0.00" + indent + "0.00\n", response.data.decode("utf-8"), ) def test_settle_page(self): self.post_project("raclette") response = 
self.client.get("/raclette/settle_bills") self.assertEqual(response.status_code, 200) def test_settle(self): self.post_project("raclette") # add members self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) # Add a member with a balance=0 : self.client.post("/raclette/members/add", data={"name": "toto"}) # create bills self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "fromage à raclette", "payer": 1, "payed_for": [1, 2, 3], "amount": "10.0", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "red wine", "payer": 2, "payed_for": [1], "amount": "20", }, ) self.client.post( "/raclette/add", data={ "date": "2011-08-10", "what": "delicatessen", "payer": 1, "payed_for": [1, 2], "amount": "10", }, ) project = models.Project.query.get("raclette") transactions = project.get_transactions_to_settle_bill() members = defaultdict(int) # We should have the same values between transactions and project balances for t in transactions: members[t["ower"]] -= t["amount"] members[t["receiver"]] += t["amount"] balance = models.Project.query.get("raclette").balance for m, a in members.items(): assert abs(a - balance[m.id]) < 0.01 return def test_settle_zero(self): self.post_project("raclette") # add members self.client.post("/raclette/members/add", data={"name": "alexis"}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) # create bills self.client.post( "/raclette/add", data={ "date": "2016-12-31", "what": "fromage à raclette", "payer": 1, "payed_for": [1, 2, 3], "amount": "10.0", }, ) self.client.post( "/raclette/add", data={ "date": "2016-12-31", "what": "red wine", "payer": 2, "payed_for": [1, 3], "amount": "20", }, ) self.client.post( "/raclette/add", data={ "date": "2017-01-01", "what": "refund", "payer": 3, 
"payed_for": [2], "amount": "13.33", }, ) project = models.Project.query.get("raclette") transactions = project.get_transactions_to_settle_bill() # There should not be any zero-amount transfer after rounding for t in transactions: rounded_amount = round(t["amount"], 2) self.assertNotEqual( 0.0, rounded_amount, msg="%f is equal to zero after rounding" % t["amount"], ) def test_export(self): self.post_project("raclette") # add members self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) self.client.post("/raclette/members/add", data={"name": "pépé"}) # create bills self.client.post( "/raclette/add", data={ "date": "2016-12-31", "what": "fromage à raclette", "payer": 1, "payed_for": [1, 2, 3, 4], "amount": "10.0", }, ) self.client.post( "/raclette/add", data={ "date": "2016-12-31", "what": "red wine", "payer": 2, "payed_for": [1, 3], "amount": "200", }, ) self.client.post( "/raclette/add", data={ "date": "2017-01-01", "what": "refund", "payer": 3, "payed_for": [2], "amount": "13.33", }, ) # generate json export of bills resp = self.client.get("/raclette/export/bills.json") expected = [ { "date": "2017-01-01", "what": "refund", "amount": 13.33, "payer_name": "tata", "payer_weight": 1.0, "owers": ["fred"], }, { "date": "2016-12-31", "what": "red wine", "amount": 200.0, "payer_name": "fred", "payer_weight": 1.0, "owers": ["alexis", "tata"], }, { "date": "2016-12-31", "what": "fromage \xe0 raclette", "amount": 10.0, "payer_name": "alexis", "payer_weight": 2.0, "owers": ["alexis", "fred", "tata", "p\xe9p\xe9"], }, ] self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) # generate csv export of bills resp = self.client.get("/raclette/export/bills.csv") expected = [ "date,what,amount,payer_name,payer_weight,owers", "2017-01-01,refund,13.33,tata,1.0,fred", '2016-12-31,red wine,200.0,fred,1.0,"alexis, tata"', 
'2016-12-31,fromage à raclette,10.0,alexis,2.0,"alexis, fred, tata, pépé"', ] received_lines = resp.data.decode("utf-8").split("\n") for i, line in enumerate(expected): self.assertEqual( set(line.split(",")), set(received_lines[i].strip("\r").split(",")) ) # generate json export of transactions resp = self.client.get("/raclette/export/transactions.json") expected = [ {"amount": 2.00, "receiver": "fred", "ower": "p\xe9p\xe9"}, {"amount": 55.34, "receiver": "fred", "ower": "tata"}, {"amount": 127.33, "receiver": "fred", "ower": "alexis"}, ] self.assertEqual(json.loads(resp.data.decode("utf-8")), expected) # generate csv export of transactions resp = self.client.get("/raclette/export/transactions.csv") expected = [ "amount,receiver,ower", "2.0,fred,pépé", "55.34,fred,tata", "127.33,fred,alexis", ] received_lines = resp.data.decode("utf-8").split("\n") for i, line in enumerate(expected): self.assertEqual( set(line.split(",")), set(received_lines[i].strip("\r").split(",")) ) # wrong export_format should return a 404 resp = self.client.get("/raclette/export/transactions.wrong") self.assertEqual(resp.status_code, 404) def test_import_new_project(self): # Import JSON in an empty project self.post_project("raclette") self.login("raclette") project = models.Project.query.get("raclette") json_to_import = [ { "date": "2017-01-01", "what": "refund", "amount": 13.33, "payer_name": "tata", "payer_weight": 1.0, "owers": ["fred"], }, { "date": "2016-12-31", "what": "red wine", "amount": 200.0, "payer_name": "fred", "payer_weight": 1.0, "owers": ["alexis", "tata"], }, { "date": "2016-12-31", "what": "fromage a raclette", "amount": 10.0, "payer_name": "alexis", "payer_weight": 2.0, "owers": ["alexis", "fred", "tata", "pepe"], }, ] from ihatemoney.web import import_project file = io.StringIO() json.dump(json_to_import, file) file.seek(0) import_project(file, project) bills = project.get_pretty_bills() # Check if all bills has been add self.assertEqual(len(bills), len(json_to_import)) 
# Check if name of bills are ok b = [e["what"] for e in bills] b.sort() ref = [e["what"] for e in json_to_import] ref.sort() self.assertEqual(b, ref) # Check if other informations in bill are ok for i in json_to_import: for j in bills: if j["what"] == i["what"]: self.assertEqual(j["payer_name"], i["payer_name"]) self.assertEqual(j["amount"], i["amount"]) self.assertEqual(j["payer_weight"], i["payer_weight"]) self.assertEqual(j["date"], i["date"]) list_project = [ower for ower in j["owers"]] list_project.sort() list_json = [ower for ower in i["owers"]] list_json.sort() self.assertEqual(list_project, list_json) def test_import_partial_project(self): # Import a JSON in a project with already existing data self.post_project("raclette") self.login("raclette") project = models.Project.query.get("raclette") self.client.post("/raclette/members/add", data={"name": "alexis", "weight": 2}) self.client.post("/raclette/members/add", data={"name": "fred"}) self.client.post("/raclette/members/add", data={"name": "tata"}) self.client.post( "/raclette/add", data={ "date": "2016-12-31", "what": "red wine", "payer": 2, "payed_for": [1, 3], "amount": "200", }, ) json_to_import = [ { "date": "2017-01-01", "what": "refund", "amount": 13.33, "payer_name": "tata", "payer_weight": 1.0, "owers": ["fred"], }, { # This expense does not have to be present twice. 
"date": "2016-12-31", "what": "red wine", "amount": 200.0, "payer_name": "fred", "payer_weight": 1.0, "owers": ["alexis", "tata"], }, { "date": "2016-12-31", "what": "fromage a raclette", "amount": 10.0, "payer_name": "alexis", "payer_weight": 2.0, "owers": ["alexis", "fred", "tata", "pepe"], }, ] from ihatemoney.web import import_project file = io.StringIO() json.dump(json_to_import, file) file.seek(0) import_project(file, project) bills = project.get_pretty_bills() # Check if all bills has been add self.assertEqual(len(bills), len(json_to_import)) # Check if name of bills are ok b = [e["what"] for e in bills] b.sort() ref = [e["what"] for e in json_to_import] ref.sort() self.assertEqual(b, ref) # Check if other informations in bill are ok for i in json_to_import: for j in bills: if j["what"] == i["what"]: self.assertEqual(j["payer_name"], i["payer_name"]) self.assertEqual(j["amount"], i["amount"]) self.assertEqual(j["payer_weight"], i["payer_weight"]) self.assertEqual(j["date"], i["date"]) list_project = [ower for ower in j["owers"]] list_project.sort() list_json = [ower for ower in i["owers"]] list_json.sort() self.assertEqual(list_project, list_json) def test_import_wrong_json(self): self.post_project("raclette") self.login("raclette") project = models.Project.query.get("raclette") json_1 = [ { # wrong keys "checked": False, "dimensions": {"width": 5, "height": 10}, "id": 1, "name": "A green door", "price": 12.5, "tags": ["home", "green"], } ] json_2 = [ { # amount missing "date": "2017-01-01", "what": "refund", "payer_name": "tata", "payer_weight": 1.0, "owers": ["fred"], } ] from ihatemoney.web import import_project try: file = io.StringIO() json.dump(json_1, file) file.seek(0) import_project(file, project) except ValueError: self.assertTrue(True) except Exception: self.fail("unexpected exception raised") else: self.fail("ExpectedException not raised") try: file = io.StringIO() json.dump(json_2, file) file.seek(0) import_project(file, project) except 
ValueError: self.assertTrue(True) except Exception: self.fail("unexpected exception raised") else: self.fail("ExpectedException not raised") class APITestCase(IhatemoneyTestCase): """Tests the API""" def api_create(self, name, id=None, password=None, contact=None): id = id or name password = password or name contact = contact or "%s@notmyidea.org" % name return self.client.post( "/api/projects", data={ "name": name, "id": id, "password": password, "contact_email": contact, }, ) def api_add_member(self, project, name, weight=1): self.client.post( "/api/projects/%s/members" % project, data={"name": name, "weight": weight}, headers=self.get_auth(project), ) def get_auth(self, username, password=None): password = password or username base64string = ( base64.encodebytes(("%s:%s" % (username, password)).encode("utf-8")) .decode("utf-8") .replace("\n", "") ) return {"Authorization": "Basic %s" % base64string} def test_cors_requests(self): # Create a project and test that CORS headers are present if requested. resp = self.api_create("raclette") self.assertStatus(201, resp) # Try to do an OPTIONS requests and see if the headers are correct. 
resp = self.client.options( "/api/projects/raclette", headers=self.get_auth("raclette") ) self.assertEqual(resp.headers["Access-Control-Allow-Origin"], "*") def test_basic_auth(self): # create a project resp = self.api_create("raclette") self.assertStatus(201, resp) # try to do something on it being unauth should return a 401 resp = self.client.get("/api/projects/raclette") self.assertStatus(401, resp) # PUT / POST / DELETE / GET on the different resources # should also return a 401 for verb in ("post",): for resource in ("/raclette/members", "/raclette/bills"): url = "/api/projects" + resource self.assertStatus(401, getattr(self.client, verb)(url), verb + resource) for verb in ("get", "delete", "put"): for resource in ("/raclette", "/raclette/members/1", "/raclette/bills/1"): url = "/api/projects" + resource self.assertStatus(401, getattr(self.client, verb)(url), verb + resource) def test_project(self): # wrong email should return an error resp = self.client.post( "/api/projects", data={ "name": "raclette", "id": "raclette", "password": "raclette", "contact_email": "not-an-email", }, ) self.assertTrue(400, resp.status_code) self.assertEqual( '{"contact_email": ["Invalid email address."]}\n', resp.data.decode("utf-8") ) # create it resp = self.api_create("raclette") self.assertTrue(201, resp.status_code) # create it twice should return a 400 resp = self.api_create("raclette") self.assertTrue(400, resp.status_code) self.assertIn("id", json.loads(resp.data.decode("utf-8"))) # get information about it resp = self.client.get( "/api/projects/raclette", headers=self.get_auth("raclette") ) self.assertTrue(200, resp.status_code) expected = { "members": [], "name": "raclette", "contact_email": "raclette@notmyidea.org", "id": "raclette", } decoded_resp = json.loads(resp.data.decode("utf-8")) self.assertDictEqual(decoded_resp, expected) # edit should work resp = self.client.put( "/api/projects/raclette", data={ "contact_email": "yeah@notmyidea.org", "password": "raclette", 
"name": "The raclette party", }, headers=self.get_auth("raclette"), ) self.assertEqual(200, resp.status_code) resp = self.client.get( "/api/projects/raclette", headers=self.get_auth("raclette") ) self.assertEqual(200, resp.status_code) expected = { "name": "The raclette party", "contact_email": "yeah@notmyidea.org", "members": [], "id": "raclette", } decoded_resp = json.loads(resp.data.decode("utf-8")) self.assertDictEqual(decoded_resp, expected) # password change is possible via API resp = self.client.put( "/api/projects/raclette", data={ "contact_email": "yeah@notmyidea.org", "password": "tartiflette", "name": "The raclette party", }, headers=self.get_auth("raclette"), ) self.assertEqual(200, resp.status_code) resp = self.client.get( "/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette") ) self.assertEqual(200, resp.status_code) # delete should work resp = self.client.delete( "/api/projects/raclette", headers=self.get_auth("raclette", "tartiflette") ) # get should return a 401 on an unknown resource resp = self.client.get( "/api/projects/raclette", headers=self.get_auth("raclette") ) self.assertEqual(401, resp.status_code) def test_token_creation(self): """Test that token of project is generated """ # Create project resp = self.api_create("raclette") self.assertTrue(201, resp.status_code) # Get token resp = self.client.get( "/api/projects/raclette/token", headers=self.get_auth("raclette") ) self.assertEqual(200, resp.status_code) decoded_resp = json.loads(resp.data.decode("utf-8")) # Access with token resp = self.client.get( "/api/projects/raclette/token", headers={"Authorization": "Basic %s" % decoded_resp["token"]}, ) self.assertEqual(200, resp.status_code) def test_token_login(self): resp = self.api_create("raclette") # Get token resp = self.client.get( "/api/projects/raclette/token", headers=self.get_auth("raclette") ) decoded_resp = json.loads(resp.data.decode("utf-8")) resp = 
self.client.get("/authenticate?token={}".format(decoded_resp["token"])) # Test that we are redirected. self.assertEqual(302, resp.status_code) def test_member(self): # create a project self.api_create("raclette") # get the list of members (should be empty) req = self.client.get( "/api/projects/raclette/members", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual("[]\n", req.data.decode("utf-8")) # add a member req = self.client.post( "/api/projects/raclette/members", data={"name": "Alexis"}, headers=self.get_auth("raclette"), ) # the id of the new member should be returned self.assertStatus(201, req) self.assertEqual("1\n", req.data.decode("utf-8")) # the list of members should contain one member req = self.client.get( "/api/projects/raclette/members", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual(len(json.loads(req.data.decode("utf-8"))), 1) # Try to add another member with the same name. req = self.client.post( "/api/projects/raclette/members", data={"name": "Alexis"}, headers=self.get_auth("raclette"), ) self.assertStatus(400, req) # edit the member req = self.client.put( "/api/projects/raclette/members/1", data={"name": "Fred", "weight": 2}, headers=self.get_auth("raclette"), ) self.assertStatus(200, req) # get should return the new name req = self.client.get( "/api/projects/raclette/members/1", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual("Fred", json.loads(req.data.decode("utf-8"))["name"]) self.assertEqual(2, json.loads(req.data.decode("utf-8"))["weight"]) # edit this member with same information # (test PUT idemopotence) req = self.client.put( "/api/projects/raclette/members/1", data={"name": "Fred"}, headers=self.get_auth("raclette"), ) self.assertStatus(200, req) # de-activate the user req = self.client.put( "/api/projects/raclette/members/1", data={"name": "Fred", "activated": False}, headers=self.get_auth("raclette"), ) self.assertStatus(200, req) req = 
self.client.get( "/api/projects/raclette/members/1", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual(False, json.loads(req.data.decode("utf-8"))["activated"]) # re-activate the user req = self.client.put( "/api/projects/raclette/members/1", data={"name": "Fred", "activated": True}, headers=self.get_auth("raclette"), ) req = self.client.get( "/api/projects/raclette/members/1", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual(True, json.loads(req.data.decode("utf-8"))["activated"]) # delete a member req = self.client.delete( "/api/projects/raclette/members/1", headers=self.get_auth("raclette") ) self.assertStatus(200, req) # the list of members should be empty req = self.client.get( "/api/projects/raclette/members", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual("[]\n", req.data.decode("utf-8")) def test_bills(self): # create a project self.api_create("raclette") # add members self.api_add_member("raclette", "alexis") self.api_add_member("raclette", "fred") self.api_add_member("raclette", "arnaud") # get the list of bills (should be empty) req = self.client.get( "/api/projects/raclette/bills", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual("[]\n", req.data.decode("utf-8")) # add a bill req = self.client.post( "/api/projects/raclette/bills", data={ "date": "2011-08-10", "what": "fromage", "payer": "1", "payed_for": ["1", "2"], "amount": "25", "external_link": "https://raclette.fr", }, headers=self.get_auth("raclette"), ) # should return the id self.assertStatus(201, req) self.assertEqual(req.data.decode("utf-8"), "1\n") # get this bill details req = self.client.get( "/api/projects/raclette/bills/1", headers=self.get_auth("raclette") ) # compare with the added info self.assertStatus(200, req) expected = { "what": "fromage", "payer_id": 1, "owers": [ {"activated": True, "id": 1, "name": "alexis", "weight": 1}, {"activated": True, "id": 2, 
"name": "fred", "weight": 1}, ], "amount": 25.0, "date": "2011-08-10", "id": 1, "external_link": "https://raclette.fr", } got = json.loads(req.data.decode("utf-8")) self.assertEqual( datetime.date.today(), datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), ) del got["creation_date"] self.assertDictEqual(expected, got) # the list of bills should length 1 req = self.client.get( "/api/projects/raclette/bills", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual(1, len(json.loads(req.data.decode("utf-8")))) # edit with errors should return an error req = self.client.put( "/api/projects/raclette/bills/1", data={ "date": "201111111-08-10", # not a date "what": "fromage", "payer": "1", "payed_for": ["1", "2"], "amount": "25", "external_link": "https://raclette.fr", }, headers=self.get_auth("raclette"), ) self.assertStatus(400, req) self.assertEqual( '{"date": ["This field is required."]}\n', req.data.decode("utf-8") ) # edit a bill req = self.client.put( "/api/projects/raclette/bills/1", data={ "date": "2011-09-10", "what": "beer", "payer": "2", "payed_for": ["1", "2"], "amount": "25", "external_link": "https://raclette.fr", }, headers=self.get_auth("raclette"), ) # check its fields req = self.client.get( "/api/projects/raclette/bills/1", headers=self.get_auth("raclette") ) creation_date = datetime.datetime.strptime( json.loads(req.data.decode("utf-8"))["creation_date"], "%Y-%m-%d" ).date() expected = { "what": "beer", "payer_id": 2, "owers": [ {"activated": True, "id": 1, "name": "alexis", "weight": 1}, {"activated": True, "id": 2, "name": "fred", "weight": 1}, ], "amount": 25.0, "date": "2011-09-10", "external_link": "https://raclette.fr", "id": 1, } got = json.loads(req.data.decode("utf-8")) self.assertEqual( creation_date, datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), ) del got["creation_date"] self.assertDictEqual(expected, got) # delete a bill req = self.client.delete( 
"/api/projects/raclette/bills/1", headers=self.get_auth("raclette") ) self.assertStatus(200, req) # getting it should return a 404 req = self.client.get( "/api/projects/raclette/bills/1", headers=self.get_auth("raclette") ) self.assertStatus(404, req) def test_bills_with_calculation(self): # create a project self.api_create("raclette") # add members self.api_add_member("raclette", "alexis") self.api_add_member("raclette", "fred") # valid amounts input_expected = [ ("((100 + 200.25) * 2 - 100) / 2", 250.25), ("3/2", 1.5), ("2 + 1 * 5 - 2 / 1", 5), ] for i, pair in enumerate(input_expected): input_amount, expected_amount = pair id = i + 1 req = self.client.post( "/api/projects/raclette/bills", data={ "date": "2011-08-10", "what": "fromage", "payer": "1", "payed_for": ["1", "2"], "amount": input_amount, }, headers=self.get_auth("raclette"), ) # should return the id self.assertStatus(201, req) self.assertEqual(req.data.decode("utf-8"), "{}\n".format(id)) # get this bill's details req = self.client.get( "/api/projects/raclette/bills/{}".format(id), headers=self.get_auth("raclette"), ) # compare with the added info self.assertStatus(200, req) expected = { "what": "fromage", "payer_id": 1, "owers": [ {"activated": True, "id": 1, "name": "alexis", "weight": 1}, {"activated": True, "id": 2, "name": "fred", "weight": 1}, ], "amount": expected_amount, "date": "2011-08-10", "id": id, "external_link": "", } got = json.loads(req.data.decode("utf-8")) self.assertEqual( datetime.date.today(), datetime.datetime.strptime(got["creation_date"], "%Y-%m-%d").date(), ) del got["creation_date"] self.assertDictEqual(expected, got) # should raise errors erroneous_amounts = [ "lambda ", # letters "(20 + 2", # invalid expression "20/0", # invalid calc "9999**99999999999999999", # exponents "2" * 201, # greater than 200 chars, ] for amount in erroneous_amounts: req = self.client.post( "/api/projects/raclette/bills", data={ "date": "2011-08-10", "what": "fromage", "payer": "1", "payed_for": 
["1", "2"], "amount": amount, }, headers=self.get_auth("raclette"), ) self.assertStatus(400, req) def test_statistics(self): # create a project self.api_create("raclette") # add members self.api_add_member("raclette", "alexis") self.api_add_member("raclette", "fred") # add a bill req = self.client.post( "/api/projects/raclette/bills", data={ "date": "2011-08-10", "what": "fromage", "payer": "1", "payed_for": ["1", "2"], "amount": "25", }, headers=self.get_auth("raclette"), ) # get the list of bills (should be empty) req = self.client.get( "/api/projects/raclette/statistics", headers=self.get_auth("raclette") ) self.assertStatus(200, req) self.assertEqual( [ { "balance": 12.5, "member": { "activated": True, "id": 1, "name": "alexis", "weight": 1.0, }, "paid": 25.0, "spent": 12.5, }, { "balance": -12.5, "member": { "activated": True, "id": 2, "name": "fred", "weight": 1.0, }, "paid": 0, "spent": 12.5, }, ], json.loads(req.data.decode("utf-8")), ) def test_username_xss(self): # create a project # self.api_create("raclette") self.post_project("raclette") self.login("raclette") # add members self.api_add_member("raclette", "