きり丸の技術日記

技術検証したり、資格等をここに残していきます。

Pythonでflatmapをする

始めに

小ネタ。PythonでNestした配列に対して、flatなデータにしたいと思った時の処理を残します。

環境

  • Python 3.13

実装

構造が2階層だけならfrom itertools import chainlist(chain.from_iterable(input_))を使用する。

from itertools import chain

async def test_01(self):
    input_ = ["1", ["2", "3"]]

    actual = list(chain.from_iterable(input_))

    assert ["1", "2", "3"] == actual

それ以上に複雑な構造をflatにしたい場合は自作関数を使用して再起処理をするのがよい。

def flatten(lst):
    for el in lst:
        if isinstance(el, list):
            yield from flatten(el)
        else:
            yield el

async def test_04(self):
    input_ = [1, [2, [3, [4, 5]]]]

    actual = list(flatten(input_))

    assert [1, 2, 3, 4, 5] == actual

ソースコード

終わりに

Pythonは配列操作が微妙に面倒ですね。pandasとかpolarsとかを使ったほうが楽かもしれませんが、アプリケーションPythonだとわざわざ使うこともないので、さくっとID項目を抽出したいときはchainを使うのがオススメです。