aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xihatemoney/manage.py10
-rw-r--r--ihatemoney/tests/tests.py12
2 files changed, 20 insertions, 2 deletions
diff --git a/ihatemoney/manage.py b/ihatemoney/manage.py
index 3207b55..a9eca0f 100755
--- a/ihatemoney/manage.py
+++ b/ihatemoney/manage.py
@@ -10,7 +10,7 @@ from flask_migrate import Migrate, MigrateCommand
from werkzeug.security import generate_password_hash
from ihatemoney.run import create_app
-from ihatemoney.models import db
+from ihatemoney.models import db, Project
from ihatemoney.utils import create_jinja_env
@@ -57,6 +57,13 @@ class GenerateConfig(Command):
))
+class DeleteProject(Command):
+ def run(self, project_name):
+ demo_project = Project.query.get(project_name)
+ db.session.delete(demo_project)
+ db.session.commit()
+
+
def main():
QUIET_COMMANDS = ('generate_password_hash', 'generate-config')
@@ -76,6 +83,7 @@ def main():
manager.add_command('db', MigrateCommand)
manager.add_command('generate_password_hash', GeneratePasswordHash)
manager.add_command('generate-config', GenerateConfig)
+ manager.add_command('delete-project', DeleteProject)
manager.run()
diff --git a/ihatemoney/tests/tests.py b/ihatemoney/tests/tests.py
index fd72a8d..63a7394 100644
--- a/ihatemoney/tests/tests.py
+++ b/ihatemoney/tests/tests.py
@@ -20,7 +20,8 @@ from flask import session
from flask_testing import TestCase
from ihatemoney.run import create_app, db, load_configuration
-from ihatemoney.manage import GenerateConfig, GeneratePasswordHash
+from ihatemoney.manage import (
+ GenerateConfig, GeneratePasswordHash, DeleteProject)
from ihatemoney import models
from ihatemoney import utils
@@ -1472,6 +1473,15 @@ class CommandTestCase(BaseTestCase):
print(stdout.getvalue())
self.assertEqual(len(stdout.getvalue().strip()), 187)
+ def test_demo_project_deletion(self):
+ self.create_project('demo')
+ self.assertEquals(models.Project.query.get('demo').name, 'demo')
+
+ cmd = DeleteProject()
+ cmd.run('demo')
+
+ self.assertEqual(len(models.Project.query.all()), 0)
+
if __name__ == "__main__":
unittest.main()