Django MPTT: efficiently serializing relational data with DRF
I have a Category model which is an MPTT model. It has a many-to-many relationship with Group, and I need to serialize the tree with the appropriate counts. Imagine my category tree is:
Root (related to 1 group)
- Branch (related to 2 groups)
  - Leaf (related to 3 groups)
...
Thus, the serialized output will look like this:
{
    id: 1,
    name: 'root1',
    full_name: 'root1',
    group_count: 6,
    children: [
    {
        id: 2,
        name: 'branch1',
        full_name: 'root1 - branch1',
        group_count: 5,
        children: [
        {
            id: 3,
            name: 'leaf1',
            full_name: 'root1 - branch1 - leaf1',
            group_count: 3,
            children: []
        }]
    }]
}
This is my current, super inefficient implementation:
Model
class Category(MPTTModel):
    name = ...
    parent = ... (related_name='children')

    def get_full_name(self):
        names = self.get_ancestors(include_self=True).values('name')
        full_name = ' - '.join(map(lambda x: x['name'], names))
        return full_name

    def get_group_count(self):
        cats = self.get_descendants(include_self=True)
        return Group.objects.filter(categories__in=cats).count()
View
class CategoryViewSet(ModelViewSet):
    def list(self, request):
        tree = cache_tree_children(Category.objects.filter(level=0))
        serializer = CategorySerializer(tree, many=True)
        return Response(serializer.data)
Serializer
class RecursiveField(serializers.Serializer):
    def to_native(self, value):
        return self.parent.to_native(value)

class CategorySerializer(serializers.ModelSerializer):
    children = RecursiveField(many=True, required=False)
    full_name = serializers.Field(source='get_full_name')
    group_count = serializers.Field(source='get_group_count')

    class Meta:
        model = Category
        fields = ('id', 'name', 'children', 'full_name', 'group_count')
This works, but it hits the DB with an insane number of queries, and there are additional relationships to serialize, not just Group. Is there a way to make this efficient? How can I write my own serializer?
You are definitely running into an N + 1 query problem, which I covered in detail in another answer on Stack Overflow. I would recommend reading up on query optimization in Django, as this is a very common problem.
Django MPTT also has a few N + 1 issues of its own that you will need to handle. Both self.get_ancestors and self.get_descendants create a new queryset, which in your case happens for every object that you serialize. You may want to look for a way to avoid these calls; I have described possible improvements below.
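One way to avoid the per-object calls, sketched here in plain Python rather than Django: django-mptt maintains tree_id, lft and rght on every node, so once all nodes have been loaded with a single query, ancestors and descendants can be worked out in memory. The Node tuple and the two helper functions are hypothetical stand-ins, not part of django-mptt.

```python
from collections import namedtuple

# Hypothetical stand-in for a Category row loaded in one query.
# django-mptt keeps tree_id, lft and rght up to date on each node.
Node = namedtuple("Node", "id name tree_id lft rght")


def ancestors(node, nodes, include_self=False):
    """Return the ancestors of `node`, ordered from root down, without a query."""
    found = [
        n for n in nodes
        if n.tree_id == node.tree_id
        and n.lft <= node.lft
        and n.rght >= node.rght
        and (include_self or n.id != node.id)
    ]
    return sorted(found, key=lambda n: n.lft)


def descendants(node, nodes, include_self=False):
    """Return the descendants of `node` in tree order, without a query."""
    found = [
        n for n in nodes
        if n.tree_id == node.tree_id
        and node.lft <= n.lft
        and n.rght <= node.rght
        and (include_self or n.id != node.id)
    ]
    return sorted(found, key=lambda n: n.lft)


# The tree from the question: root1 > branch1 > leaf1.
nodes = [
    Node(1, "root1", 1, 1, 6),
    Node(2, "branch1", 1, 2, 5),
    Node(3, "leaf1", 1, 3, 4),
]
leaf = nodes[2]
full_name = " - ".join(n.name for n in ancestors(leaf, nodes, include_self=True))
print(full_name)  # root1 - branch1 - leaf1
```

Each lookup scans the whole list, so this is only reasonable for small trees; for larger ones an index keyed by tree_id would be needed.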
In your get_full_name method, you are calling self.get_ancestors to build the chain of names. Since the parent is always available while you generate the output, you could move this to a SerializerMethodField that reuses the parent to generate the name. Something like the following might work:
from rest_framework import serializers

# Note: to_native and serializers.Field belong to the DRF 2.x API.
class RecursiveField(serializers.Serializer):
    def to_native(self, value):
        # Hand the parent object and its serializer down to each child.
        context = {"parent": self.parent.object, "parent_serializer": self.parent}
        return CategorySerializer(value, context=context).data

class CategorySerializer(serializers.ModelSerializer):
    children = RecursiveField(many=True, required=False)
    full_name = serializers.SerializerMethodField("get_full_name")
    group_count = serializers.Field(source='get_group_count')

    class Meta:
        model = Category
        fields = ('id', 'name', 'children', 'full_name', 'group_count')

    def get_full_name(self, obj):
        name = obj.name
        if "parent" in self.context:
            # Reuse the parent's full name instead of querying for ancestors.
            parent = self.context["parent"]
            parent_name = self.context["parent_serializer"].get_full_name(parent)
            name = "%s - %s" % (parent_name, name, )
        return name
You may need to modify this code a little, but the general idea is that you don't always need to get ancestors, because you will already have a chain of ancestors.
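The same idea, stripped of DRF, can be sketched as a plain recursive function that passes the parent's full name down the tree. This is only an illustration with nested dicts; the serialize function and its parent_full_name parameter are hypothetical and not part of DRF or django-mptt.

```python
def serialize(node, parent_full_name=None):
    """Serialize a nested category dict, reusing the parent's full name."""
    if parent_full_name is None:
        # Root node: its full name is just its own name.
        full_name = node["name"]
    else:
        # The chain of ancestors is already known, so no lookup is needed.
        full_name = "%s - %s" % (parent_full_name, node["name"])
    return {
        "id": node["id"],
        "name": node["name"],
        "full_name": full_name,
        "children": [serialize(child, full_name) for child in node["children"]],
    }


tree = {
    "id": 1,
    "name": "root1",
    "children": [
        {
            "id": 2,
            "name": "branch1",
            "children": [{"id": 3, "name": "leaf1", "children": []}],
        }
    ],
}
result = serialize(tree)
print(result["children"][0]["children"][0]["full_name"])  # root1 - branch1 - leaf1
```

Every node is visited once and each full name is built from the one above it, so the cost grows with the size of the tree rather than with the number of ancestor lookups.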
This doesn't solve the Group queries, which you might not be able to optimize, but it should at least reduce the number of queries. Recursive queries are incredibly difficult to optimize, and it usually takes a lot of planning to figure out how best to get the data you need without falling back into N + 1 situations.
I found a solution for getting the counts. Thanks to django-mptt's get_cached_trees function, you can do the following:
from django.db.models import Count
from rest_framework import serializers
from rest_framework.viewsets import ModelViewSet

class CategorySerializer(serializers.ModelSerializer):
    group_count = serializers.SerializerMethodField()

    def get_group_count(self, obj):
        # Read the value annotated on the queryset in the view below.
        return obj.group_count

    class Meta:
        model = Category
        fields = [
            'name',
            'slug',
            'children',
            'group_count',
        ]

# Declared after the class so the serializer can nest itself.
CategorySerializer._declared_fields['children'] = CategorySerializer(
    many=True,
    source='get_children',
)

class CategoryViewSet(ModelViewSet):
    serializer_class = CategorySerializer

    def get_queryset(self, queryset=None):
        queryset = Category.tree.annotate(group_count=Count('group'))
        queryset = queryset.get_cached_trees()
        return queryset
Here tree is mptt's TreeManager, as used in django-categories, for which I wrote slightly more complex code in this PR: https://github.com/callowayproject/django-categories/pull/145/files
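Note that an annotation such as Count('group') counts only the groups attached directly to each category. If the cumulative counts from the question are needed, they could be rolled up in memory once the tree has been cached. A minimal sketch with plain dicts follows; direct_count is a hypothetical key standing in for the annotated value, and a group attached to several categories in the same subtree would be counted more than once.

```python
def add_cumulative_counts(node):
    """Set node['group_count'] to its own count plus that of all descendants."""
    total = node["direct_count"]
    for child in node["children"]:
        # Children are processed first, so their totals are final here.
        total += add_cumulative_counts(child)
    node["group_count"] = total
    return total


# The tree from the question: 1, 2 and 3 directly related groups.
leaf = {"name": "leaf1", "direct_count": 3, "children": []}
branch = {"name": "branch1", "direct_count": 2, "children": [leaf]}
root = {"name": "root1", "direct_count": 1, "children": [branch]}

add_cumulative_counts(root)
print(root["group_count"], branch["group_count"], leaf["group_count"])  # 6 5 3
```

With model instances, the same traversal could walk get_children() on the cached tree and store the total on each object before serializing.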