aboutsummaryrefslogtreecommitdiff
path: root/ihatemoney/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'ihatemoney/models.py')
-rw-r--r--ihatemoney/models.py74
1 files changed, 72 insertions, 2 deletions
diff --git a/ihatemoney/models.py b/ihatemoney/models.py
index 4d32fd9..d765c93 100644
--- a/ihatemoney/models.py
+++ b/ihatemoney/models.py
@@ -1,6 +1,8 @@
from collections import defaultdict
from datetime import datetime
+
+import sqlalchemy
from flask_sqlalchemy import SQLAlchemy, BaseQuery
from flask import g, current_app
@@ -13,6 +15,40 @@ from itsdangerous import (
BadSignature,
SignatureExpired,
)
+from sqlalchemy_continuum import make_versioned
+from sqlalchemy_continuum.plugins import FlaskPlugin
+from sqlalchemy_continuum import version_class
+
+from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder
+from ihatemoney.versioning import (
+ LoggingMode,
+ ConditionalVersioningManager,
+ version_privacy_predicate,
+ get_ip_if_allowed,
+)
+
+
+make_versioned(
+ user_cls=None,
+ manager=ConditionalVersioningManager(
+ # Conditionally Disable the versioning based on each
+ # project's privacy preferences
+ tracking_predicate=version_privacy_predicate,
+ # Patch in a fix to a SQLAchemy-Continuum Bug.
+ # See patch_sqlalchemy_continuum.py
+ builder=PatchedBuilder(),
+ ),
+ plugins=[
+ FlaskPlugin(
+ # Redirect to our own function, which respects user preferences
+ # on IP address collection
+ remote_addr_factory=get_ip_if_allowed,
+ # Suppress the plugin's attempt to grab a user id,
+ # which imports the flask_login module (causing an error)
+ current_user_id_factory=lambda: None,
+ )
+ ],
+)
db = SQLAlchemy()
@@ -22,11 +58,20 @@ class Project(db.Model):
def get_by_name(self, name):
return Project.query.filter(Project.name == name).one()
+ # Direct SQLAlchemy-Continuum to track changes to this model
+ __versioned__ = {}
+
id = db.Column(db.String(64), primary_key=True)
name = db.Column(db.UnicodeText)
password = db.Column(db.String(128))
contact_email = db.Column(db.String(128))
+ logging_preference = db.Column(
+ db.Enum(LoggingMode),
+ default=LoggingMode.default(),
+ nullable=False,
+ server_default=LoggingMode.default().name,
+ )
members = db.relationship("Person", backref="project")
query_class = ProjectQuery
@@ -37,6 +82,7 @@ class Project(db.Model):
"id": self.id,
"name": self.name,
"contact_email": self.contact_email,
+ "logging_preference": self.logging_preference.value,
"members": [],
}
@@ -277,6 +323,9 @@ class Project(db.Model):
return None
return data["project_id"]
+ def __str__(self):
+ return self.name
+
def __repr__(self):
return "<Project %s>" % self.name
@@ -301,6 +350,11 @@ class Person(db.Model):
query_class = PersonQuery
+ # Direct SQLAlchemy-Continuum to track changes to this model
+ __versioned__ = {}
+
+ __table_args__ = {"sqlite_autoincrement": True}
+
id = db.Column(db.Integer, primary_key=True)
project_id = db.Column(db.String(64), db.ForeignKey("project.id"))
bills = db.relationship("Bill", backref="payer")
@@ -337,8 +391,9 @@ class Person(db.Model):
# We need to manually define a join table for m2m relations
billowers = db.Table(
"billowers",
- db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")),
- db.Column("person_id", db.Integer, db.ForeignKey("person.id")),
+ db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True),
+ db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True),
+ sqlite_autoincrement=True,
)
@@ -365,6 +420,11 @@ class Bill(db.Model):
query_class = BillQuery
+ # Direct SQLAlchemy-Continuum to track changes to this model
+ __versioned__ = {}
+
+ __table_args__ = {"sqlite_autoincrement": True}
+
id = db.Column(db.Integer, primary_key=True)
payer_id = db.Column(db.Integer, db.ForeignKey("person.id"))
@@ -403,6 +463,9 @@ class Bill(db.Model):
else:
return 0
+ def __str__(self):
+ return "%s for %s" % (self.amount, self.what)
+
def __repr__(self):
return "<Bill of %s from %s for %s>" % (
self.amount,
@@ -426,3 +489,10 @@ class Archive(db.Model):
def __repr__(self):
return "<Archive>"
+
+
+sqlalchemy.orm.configure_mappers()
+
+PersonVersion = version_class(Person)
+ProjectVersion = version_class(Project)
+BillVersion = version_class(Bill)