fix shopping list sharing

This commit is contained in:
Chris Scoggins
2022-01-29 11:59:06 -06:00
parent e2f8f29ec8
commit a7796cbf5c
3 changed files with 40 additions and 7 deletions

View File

@ -205,9 +205,9 @@ class CustomIsShared(permissions.BasePermission):
return request.user.is_authenticated return request.user.is_authenticated
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
# temporary hack to make old shopping list work with new shopping list # # temporary hack to make old shopping list work with new shopping list
if obj.__class__.__name__ in ['ShoppingList', 'ShoppingListEntry']: # if obj.__class__.__name__ in ['ShoppingList', 'ShoppingListEntry']:
return is_object_shared(request.user, obj) or obj.created_by in list(request.user.get_shopping_share()) # return is_object_shared(request.user, obj) or obj.created_by in list(request.user.get_shopping_share())
return is_object_shared(request.user, obj) return is_object_shared(request.user, obj)

View File

@ -609,7 +609,7 @@ class NutritionInformation(models.Model, PermissionModelMixin):
) )
proteins = models.DecimalField(default=0, decimal_places=16, max_digits=32) proteins = models.DecimalField(default=0, decimal_places=16, max_digits=32)
calories = models.DecimalField(default=0, decimal_places=16, max_digits=32) calories = models.DecimalField(default=0, decimal_places=16, max_digits=32)
source = models.CharField( max_length=512, default="", null=True, blank=True) source = models.CharField(max_length=512, default="", null=True, blank=True)
space = models.ForeignKey(Space, on_delete=models.CASCADE) space = models.ForeignKey(Space, on_delete=models.CASCADE)
objects = ScopedManager(space='space') objects = ScopedManager(space='space')
@ -852,11 +852,12 @@ class ShoppingListEntry(ExportModelOperationsMixin('shopping_list_entry'), model
def __str__(self): def __str__(self):
return f'Shopping list entry {self.id}' return f'Shopping list entry {self.id}'
# TODO deprecate
def get_shared(self): def get_shared(self):
return self.shoppinglist_set.first().shared.all() try:
return self.shoppinglist_set.first().shared.all()
except AttributeError:
return self.created_by.userpreference.shopping_share.all()
# TODO deprecate
def get_owner(self): def get_owner(self):
try: try:
return self.created_by or self.shoppinglist_set.first().created_by return self.created_by or self.shoppinglist_set.first().created_by
@ -881,6 +882,12 @@ class ShoppingList(ExportModelOperationsMixin('shopping_list'), models.Model, Pe
def __str__(self): def __str__(self):
return f'Shopping list {self.id}' return f'Shopping list {self.id}'
def get_shared(self):
try:
return self.shared.all() or self.created_by.userpreference.shopping_share.all()
except AttributeError:
return []
class ShareLink(ExportModelOperationsMixin('share_link'), models.Model, PermissionModelMixin): class ShareLink(ExportModelOperationsMixin('share_link'), models.Model, PermissionModelMixin):
recipe = models.ForeignKey(Recipe, on_delete=models.CASCADE) recipe = models.ForeignKey(Recipe, on_delete=models.CASCADE)

View File

@ -156,6 +156,32 @@ def test_sharing(request, shared, count, sle_2, sle, u1_s1):
# confirm shared user sees their list and the list that's shared with them # confirm shared user sees their list and the list that's shared with them
assert len(json.loads(r.content)) == count assert len(json.loads(r.content)) == count
# test shared user can mark complete
x = shared_client.patch(
reverse(DETAIL_URL, args={sle[0].id}),
{'checked': True},
content_type='application/json'
)
r = json.loads(shared_client.get(reverse(LIST_URL)).content)
assert len(r) == count
# count unchecked entries
if not x.status_code == 404:
count = count-1
assert [x['checked'] for x in r].count(False) == count
# test shared user can delete
x = shared_client.delete(
reverse(
DETAIL_URL,
args={sle[1].id}
)
)
r = json.loads(shared_client.get(reverse(LIST_URL)).content)
assert len(r) == count
# count unchecked entries
if not x.status_code == 404:
count = count-1
assert [x['checked'] for x in r].count(False) == count
def test_completed(sle, u1_s1): def test_completed(sle, u1_s1):
# check 1 entry # check 1 entry