きり丸の技術日記

技術検証したり、資格等をここに残していきます。

Pythonでgroup_byしたいならdefaultdictを使う

始めに

Pythonでデータをグループ化する際、defaultdictを使用すると簡単かつ効率的に実装できます。この記事では、defaultdictを使ったgroup_byの実装方法と、itertools.groupbyとの違いについて解説します。

環境

  • Python 3.12.6

実装

defaultdictを使用すればシンプルに実装できます。

from collections import defaultdict

class TestGroupBy:
    class _Test:
        def __init__(self, user_id, group_id):
            self.user_id = user_id
            self.group_id = group_id

    @pytest.fixture
    def parameters(self):
        return [
            self._Test(1, 'A'),
            self._Test(2, 'A'),
            self._Test(3, 'B'),
            self._Test(4, 'A'),
            self._Test(5, 'B'),
        ]

    class TestDefaultDict:
        def test_group_by(self, parameters):
            # NOTE: defaultdictは dictと違い、Keyが存在しない場合にもKeyErrorを発生させません
            grouped_data = defaultdict(list)
            for user in parameters:
                grouped_data[user.group_id].append(user.user_id)
            expected = {'A': [1, 2, 4], 'B': [3, 5]}
            assert dict(grouped_data) == expected

itertools.groupbyは次のコードで実装できます。

    class TestItertools:
        def test_group_by(self, parameters):
            # NOTE: ソートがかかっていないと正しくgroup_byされない
            non_continuous_data = {k: [user.user_id for user in v] for k, v in groupby(parameters, key=attrgetter('group_id'))}
            expected = {'A': [4], 'B': [5]}
            assert non_continuous_data == expected

            sorted_users = sorted(parameters, key=attrgetter('group_id'))

            grouped_data = {k: [user.user_id for user in v] for k, v in groupby(sorted_users, key=attrgetter('group_id'))}
            expected = {'A': [1, 2, 4], 'B': [3, 5]}
            assert grouped_data == expected

差分

基本的にはdefaultdictで問題ありません。

itertools.groupbyの場合はコード内にも記載していますが、非連続なデータの場合は期待どおりにgroup_byされないパターンがあるので特別に採用したいユースケースはないです。大規模データを変換したいことはあるでしょうが、そのときはpandasとかpolars使っているでしょうし…。一応、メモリ効率に軍配があがるので、OOMが発生したらitertools.groupbyを使用することを考えてもよいと思います。

ソースコード

終わりに

groupbyという名前がついているのでitertools.groupbyを使用していたのですが、非連続なデータでは使用できないという点でハマってしまいました。

Pythonで自分のブログに来る人はいないかもしれませんが、ぜひハマらないように注意してください。