diff options
Diffstat (limited to 'ihatemoney/models.py')
| -rw-r--r-- | ihatemoney/models.py | 74 |
1 files changed, 72 insertions, 2 deletions
diff --git a/ihatemoney/models.py b/ihatemoney/models.py index 4d32fd9..d765c93 100644 --- a/ihatemoney/models.py +++ b/ihatemoney/models.py @@ -1,6 +1,8 @@ from collections import defaultdict from datetime import datetime + +import sqlalchemy from flask_sqlalchemy import SQLAlchemy, BaseQuery from flask import g, current_app @@ -13,6 +15,40 @@ from itsdangerous import ( BadSignature, SignatureExpired, ) +from sqlalchemy_continuum import make_versioned +from sqlalchemy_continuum.plugins import FlaskPlugin +from sqlalchemy_continuum import version_class + +from ihatemoney.patch_sqlalchemy_continuum import PatchedBuilder +from ihatemoney.versioning import ( + LoggingMode, + ConditionalVersioningManager, + version_privacy_predicate, + get_ip_if_allowed, +) + + +make_versioned( + user_cls=None, + manager=ConditionalVersioningManager( + # Conditionally Disable the versioning based on each + # project's privacy preferences + tracking_predicate=version_privacy_predicate, + # Patch in a fix to a SQLAchemy-Continuum Bug. + # See patch_sqlalchemy_continuum.py + builder=PatchedBuilder(), + ), + plugins=[ + FlaskPlugin( + # Redirect to our own function, which respects user preferences + # on IP address collection + remote_addr_factory=get_ip_if_allowed, + # Suppress the plugin's attempt to grab a user id, + # which imports the flask_login module (causing an error) + current_user_id_factory=lambda: None, + ) + ], +) db = SQLAlchemy() @@ -22,11 +58,20 @@ class Project(db.Model): def get_by_name(self, name): return Project.query.filter(Project.name == name).one() + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + id = db.Column(db.String(64), primary_key=True) name = db.Column(db.UnicodeText) password = db.Column(db.String(128)) contact_email = db.Column(db.String(128)) + logging_preference = db.Column( + db.Enum(LoggingMode), + default=LoggingMode.default(), + nullable=False, + server_default=LoggingMode.default().name, + ) members = db.relationship("Person", backref="project") query_class = ProjectQuery @@ -37,6 +82,7 @@ class Project(db.Model): "id": self.id, "name": self.name, "contact_email": self.contact_email, + "logging_preference": self.logging_preference.value, "members": [], } @@ -277,6 +323,9 @@ class Project(db.Model): return None return data["project_id"] + def __str__(self): + return self.name + def __repr__(self): return "<Project %s>" % self.name @@ -301,6 +350,11 @@ class Person(db.Model): query_class = PersonQuery + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + + __table_args__ = {"sqlite_autoincrement": True} + id = db.Column(db.Integer, primary_key=True) project_id = db.Column(db.String(64), db.ForeignKey("project.id")) bills = db.relationship("Bill", backref="payer") @@ -337,8 +391,9 @@ class Person(db.Model): # We need to manually define a join table for m2m relations billowers = db.Table( "billowers", - db.Column("bill_id", db.Integer, db.ForeignKey("bill.id")), - db.Column("person_id", db.Integer, db.ForeignKey("person.id")), + db.Column("bill_id", db.Integer, db.ForeignKey("bill.id"), primary_key=True), + db.Column("person_id", db.Integer, db.ForeignKey("person.id"), primary_key=True), + sqlite_autoincrement=True, ) @@ -365,6 +420,11 @@ class Bill(db.Model): query_class = BillQuery + # Direct SQLAlchemy-Continuum to track changes to this model + __versioned__ = {} + + __table_args__ = {"sqlite_autoincrement": True} + id = db.Column(db.Integer, primary_key=True) payer_id = db.Column(db.Integer, db.ForeignKey("person.id")) @@ -403,6 +463,9 @@ class Bill(db.Model): else: return 0 + def __str__(self): + return "%s for %s" % (self.amount, self.what) + def __repr__(self): return "<Bill of %s from %s for %s>" % ( self.amount, @@ -426,3 +489,10 @@ class Archive(db.Model): def __repr__(self): return "<Archive>" + + +sqlalchemy.orm.configure_mappers() + +PersonVersion = version_class(Person) +ProjectVersion = version_class(Project) +BillVersion = version_class(Bill) |
