aboutsummaryrefslogtreecommitdiff
path: root/ihatemoney/tests/main_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'ihatemoney/tests/main_test.py')
-rw-r--r--ihatemoney/tests/main_test.py265
1 files changed, 265 insertions, 0 deletions
diff --git a/ihatemoney/tests/main_test.py b/ihatemoney/tests/main_test.py
new file mode 100644
index 0000000..06dbfac
--- /dev/null
+++ b/ihatemoney/tests/main_test.py
@@ -0,0 +1,265 @@
+import io
+import os
+import smtplib
+import socket
+import unittest
+from unittest.mock import MagicMock, patch
+
+from sqlalchemy import orm
+
+from ihatemoney import models
+from ihatemoney.currency_convertor import CurrencyConverter
+from ihatemoney.manage import DeleteProject, GenerateConfig, GeneratePasswordHash
+from ihatemoney.run import load_configuration
+from ihatemoney.tests.common.ihatemoney_testcase import BaseTestCase, IhatemoneyTestCase
+
+# Unset configuration file env var if previously set
+os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
+
+__HERE__ = os.path.dirname(os.path.abspath(__file__))
+
+
+class ConfigurationTestCase(BaseTestCase):
+ def test_default_configuration(self):
+ """Test that default settings are loaded when no other configuration file is specified"""
+ self.assertFalse(self.app.config["DEBUG"])
+ self.assertEqual(
+ self.app.config["SQLALCHEMY_DATABASE_URI"], "sqlite:////tmp/ihatemoney.db"
+ )
+ self.assertFalse(self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"])
+ self.assertEqual(
+ self.app.config["MAIL_DEFAULT_SENDER"],
+ ("Budget manager", "budget@notmyidea.org"),
+ )
+
+ def test_env_var_configuration_file(self):
+ """Test that settings are loaded from the specified configuration file"""
+ os.environ["IHATEMONEY_SETTINGS_FILE_PATH"] = os.path.join(
+ __HERE__, "ihatemoney_envvar.cfg"
+ )
+ load_configuration(self.app)
+ self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")
+
+ # Test that the specified configuration file is loaded
+ # even if the default configuration file ihatemoney.cfg exists
+ os.environ["IHATEMONEY_SETTINGS_FILE_PATH"] = os.path.join(
+ __HERE__, "ihatemoney_envvar.cfg"
+ )
+ self.app.config.root_path = __HERE__
+ load_configuration(self.app)
+ self.assertEqual(self.app.config["SECRET_KEY"], "lalatra")
+
+ os.environ.pop("IHATEMONEY_SETTINGS_FILE_PATH", None)
+
+ def test_default_configuration_file(self):
+ """Test that settings are loaded from the default configuration file"""
+ self.app.config.root_path = __HERE__
+ load_configuration(self.app)
+ self.assertEqual(self.app.config["SECRET_KEY"], "supersecret")
+
+
+class ServerTestCase(IhatemoneyTestCase):
+ def test_homepage(self):
+ # See https://github.com/spiral-project/ihatemoney/pull/358
+ self.app.config["APPLICATION_ROOT"] = "/"
+ req = self.client.get("/")
+ self.assertStatus(200, req)
+
+ def test_unprefixed(self):
+ self.app.config["APPLICATION_ROOT"] = "/"
+ req = self.client.get("/foo/")
+ self.assertStatus(303, req)
+
+ def test_prefixed(self):
+ self.app.config["APPLICATION_ROOT"] = "/foo"
+ req = self.client.get("/foo/")
+ self.assertStatus(200, req)
+
+
+class CommandTestCase(BaseTestCase):
+ def test_generate_config(self):
+ """Simply checks that all config file generation
+ - raise no exception
+ - produce something non-empty
+ """
+ cmd = GenerateConfig()
+ for config_file in cmd.get_options()[0].kwargs["choices"]:
+ with patch("sys.stdout", new=io.StringIO()) as stdout:
+ cmd.run(config_file)
+ print(stdout.getvalue())
+ self.assertNotEqual(len(stdout.getvalue().strip()), 0)
+
+ def test_generate_password_hash(self):
+ cmd = GeneratePasswordHash()
+ with patch("sys.stdout", new=io.StringIO()) as stdout, patch(
+ "getpass.getpass", new=lambda prompt: "secret"
+ ): # NOQA
+ cmd.run()
+ print(stdout.getvalue())
+ self.assertEqual(len(stdout.getvalue().strip()), 189)
+
+ def test_demo_project_deletion(self):
+ self.create_project("demo")
+ self.assertEquals(models.Project.query.get("demo").name, "demo")
+
+ cmd = DeleteProject()
+ cmd.run("demo")
+
+ self.assertEqual(len(models.Project.query.all()), 0)
+
+
+class ModelsTestCase(IhatemoneyTestCase):
+ def test_bill_pay_each(self):
+
+ self.post_project("raclette")
+
+ # add members
+ self.client.post("/raclette/members/add", data={"name": "zorglub", "weight": 2})
+ self.client.post("/raclette/members/add", data={"name": "fred"})
+ self.client.post("/raclette/members/add", data={"name": "tata"})
+ # Add a member with a balance=0 :
+ self.client.post("/raclette/members/add", data={"name": "pépé"})
+
+ # create bills
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2011-08-10",
+ "what": "fromage à raclette",
+ "payer": 1,
+ "payed_for": [1, 2, 3],
+ "amount": "10.0",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2011-08-10",
+ "what": "red wine",
+ "payer": 2,
+ "payed_for": [1],
+ "amount": "20",
+ },
+ )
+
+ self.client.post(
+ "/raclette/add",
+ data={
+ "date": "2011-08-10",
+ "what": "delicatessen",
+ "payer": 1,
+ "payed_for": [1, 2],
+ "amount": "10",
+ },
+ )
+
+ project = models.Project.query.get_by_name(name="raclette")
+ zorglub = models.Person.query.get_by_name(name="zorglub", project=project)
+ zorglub_bills = models.Bill.query.options(
+ orm.subqueryload(models.Bill.owers)
+ ).filter(models.Bill.owers.contains(zorglub))
+ for bill in zorglub_bills.all():
+ if bill.what == "red wine":
+ pay_each_expected = 20 / 2
+ self.assertEqual(bill.pay_each(), pay_each_expected)
+ if bill.what == "fromage à raclette":
+ pay_each_expected = 10 / 4
+ self.assertEqual(bill.pay_each(), pay_each_expected)
+ if bill.what == "delicatessen":
+ pay_each_expected = 10 / 3
+ self.assertEqual(bill.pay_each(), pay_each_expected)
+
+
+class EmailFailureTestCase(IhatemoneyTestCase):
+ def test_creation_email_failure_smtp(self):
+ self.login("raclette")
+ with patch.object(
+ self.app.mail, "send", MagicMock(side_effect=smtplib.SMTPException)
+ ):
+ resp = self.post_project("raclette")
+ # Check that an error message is displayed
+ self.assertIn(
+ "We tried to send you an reminder email, but there was an error",
+ resp.data.decode("utf-8"),
+ )
+ # Check that we were redirected to the home page anyway
+ self.assertIn(
+ 'You probably want to <a href="/raclette/members/add"',
+ resp.data.decode("utf-8"),
+ )
+
+ def test_creation_email_failure_socket(self):
+ self.login("raclette")
+ with patch.object(self.app.mail, "send", MagicMock(side_effect=socket.error)):
+ resp = self.post_project("raclette")
+ # Check that an error message is displayed
+ self.assertIn(
+ "We tried to send you an reminder email, but there was an error",
+ resp.data.decode("utf-8"),
+ )
+ # Check that we were redirected to the home page anyway
+ self.assertIn(
+ 'You probably want to <a href="/raclette/members/add"',
+ resp.data.decode("utf-8"),
+ )
+
+ def test_password_reset_email_failure(self):
+ self.create_project("raclette")
+ for exception in (smtplib.SMTPException, socket.error):
+ with patch.object(self.app.mail, "send", MagicMock(side_effect=exception)):
+ resp = self.client.post(
+ "/password-reminder", data={"id": "raclette"}, follow_redirects=True
+ )
+ # Check that an error message is displayed
+ self.assertIn(
+ "there was an error while sending you an email",
+ resp.data.decode("utf-8"),
+ )
+ # Check that we were not redirected to the success page
+ self.assertNotIn(
+ "A link to reset your password has been sent to you",
+ resp.data.decode("utf-8"),
+ )
+
+ def test_invitation_email_failure(self):
+ self.login("raclette")
+ self.post_project("raclette")
+ for exception in (smtplib.SMTPException, socket.error):
+ with patch.object(self.app.mail, "send", MagicMock(side_effect=exception)):
+ resp = self.client.post(
+ "/raclette/invite",
+ data={"emails": "toto@notmyidea.org"},
+ follow_redirects=True,
+ )
+ # Check that an error message is displayed
+ self.assertIn(
+ "there was an error while trying to send the invitation emails",
+ resp.data.decode("utf-8"),
+ )
+ # Check that we are still on the same page (no redirection)
+ self.assertIn(
+ "Invite people to join this project", resp.data.decode("utf-8")
+ )
+
+
+class TestCurrencyConverter(unittest.TestCase):
+ converter = CurrencyConverter()
+ mock_data = {"USD": 1, "EUR": 0.8115}
+ converter.get_rates = MagicMock(return_value=mock_data)
+
+ def test_only_one_instance(self):
+ one = id(CurrencyConverter())
+ two = id(CurrencyConverter())
+ self.assertEqual(one, two)
+
+ def test_get_currencies(self):
+ self.assertCountEqual(self.converter.get_currencies(), ["USD", "EUR"])
+
+ def test_exchange_currency(self):
+ result = self.converter.exchange_currency(100, "USD", "EUR")
+ self.assertEqual(result, 81.15)
+
+
+if __name__ == "__main__":
+ unittest.main()