stMind

about Tech, Computer vision and Machine learning

ScenicのSegment AnythingをGoogle Colab実行

Segment Anythingは、facebook researchのpytorch実装がありますが、 ScenicにもJax/FlaxのSegment Anythingの実装が公開されています。

READMEにある通りですが、Google Colabでシンプルに実行してみます。

image_pathを変えれば、他の画像でも試せる。

READMEでは、モデル実行時のキーワード引数(input_imageではなくimage)が間違っていて、一部実行エラーになる箇所が あったので、細かい点ではあるがPRを送ってみた。

Scenic: A JAX Library for Computer Vision Research and Beyond

github.com

Scenicは、TransformerベースのモデルにフォーカスしたオープンソースのJAXライブラリ。 最近、Transformerを適用した動画認識モデルの論文(ViViT, MTV, ObjectViViT)を読んでいる中で見かけていました。

研究のコードであっても、構造化され、実験しやすいことが、色々なアイデアを素早く検証できるベースになることを実感していて、 Scenicが気になっていました。 そこで、arxivに公開されているScenicの論文を読んでみたので、ここで内容をメモしておきます。

arxiv.org

Abstract

Scenicの目的は、新しいビジョンアーキテクチャやモデルの素早い実験、プロトタイピング、リサーチを促進すること。 Scenicは、マルチホスト、マルチデバイスの大規模学習のためのGPU/TPUサポートとともに、多様なビジョンタスク(分類、セグメンテーション、検出など)に対応し、マルチモーダルな問題に対する作業を容易にする。 また、幅広いモダリティのSOTAなモデルの最適化実装も提供する。

Introduction

Scenicとは何かを一言で表すと以下の通りとなる。

  • ビジョン分野を中心として、大規模なモデルを学習する際に遭遇するタスクを解決するための軽量な共有ライブラリ
  • これらのライブラリを利用した固有の問題に対応する多数のプロジェクト

Scenicは、様々な抽象化レベルを提供するように設計されている。例えば、ハイパラの変更のみのプロジェクト、入力パイプラインから、モデルのアーキテクチャ、ロスやメトリクス、学習ループまでカスタマイズが必要なプロジェクトなど。

これを実現するために、Scenicは大きく二つのレベルに整理されている。

  • project-level code
    • 特定のプロジェクトやベースライン用にカスタマイズされたコード
  • library-level code
    • 多くのプロジェクトに共通する機能や一般的なパターンのコード

philosophy

Scenicは大規模モデルの素早いプロトタイピングを促進することを目的としている。コードを理解、拡張しやすくするために、複雑さを加えたり、抽象度を上げるよりも、フォークやコピーを好む。 複数のモデルやタスクに広く有用である場合のみ、library-levelに機能を加える。library-levelで固有のユースケースのサポートを最小限にすることで、複雑で理解しづらくなる一般化を避ける。一方で、project-levelでは、複雑さや抽象化を加えることができる。

Design

Library-level code

目標は、ライブラリレベルのコードを最小限かつ十分にテストされたものに保ち、マイナーなユースケースをサポートするために余分な抽象化を導入しないようにすること。 共有ライブラリはFigure 1にあるように4つに分割されている。

  • dataset_lib
    • 一般的なタスクやベンチマークのデータをロードし、前処理するためのIOパイプラインを実装。
  • model_lib
    • タスクに特化したロスとメトリクスを持つ、いくつかの抽象モデル・インターフェース(例:model_lib/base_modelsのClassificationModelやSegmentationModel)
    • attentionとtransformerの効率的な実装に焦点を当てたNN Layer(model_lib/layers)
    • アクセラレーターフレンドリーなbipartite matching algorithmの実装(model_lib/matchers)
  • train_lib
    • 学習ループを構築するためのツールを提供
  • common_lib
    • ロギングやデバッグモジュール、Raw データを処理するための機能などの共通ユーティリティ

Project-level code

Project-level codeは、"プロジェクト "という概念によって、個別タスクやデータのためにカスタマイズされたソリューションの開発をサポート。プロジェクトは、設定ファイルのみで共通のモデルや学習器などのlibrary-level codeを使うこともできるし、フォークして再定義することもできる。 ResNetやViT、DETRはprojects/baselinesに実装されている。

Scenic BaseModel

ソリューションは通常、データやタスクのパイプライン、モデルアーキテクチャ、ロスとメトリクス、学習と評価などのパーツに分かれている。Scenicで行われる研究の多くが異なるアーキテクチャを試していることから、プラグイン/プラグアウトでの実験を容易にするための「model」という概念を導入。「model」は、ネットワークのアーキテクチャ、ロス、評価メトリクスとして定義され、BaseModelとして実装されている。

BaseModelは抽象クラスで、3つのメンバを持つ。

  • build_flax_model
  • loss_fn
  • get_metrics_fn

Scenicのモデルを定義する抽象クラスは、model_lib/base_modelsにあり、BaseModelの他に、BaseModelを継承したClassificationModel、MulitLabelClassificationModel、EncoderDecoderModel、SegmentationModelも含まれる。

これらのデザインパターンは推奨であり、様々なプロジェクトでうまく機能するが、強制ではなく、プロジェクト内でこの構造から逸脱しても問題はない。


一言

Library-levelのコードと、Project-levelのコードの二つに分けて、複数のモデルやタスクに広く有用である場合のみ、Library-levelに機能を追加するという考えは分かりやすいと思いました。bipartite matchingはattentionと同じ階層でLibrary-levelなんですね。OpenPoseで使われていたと記憶してますが、他にも幅広く使われているということなのか。

GPTを自作して学習済みパラメータでテキスト生成

2024年の最初のエントリーはGPTです。 GPTモデルを自作して、OpenAIが公開している学習済みのパラメータをロード、テキスト生成までの一連の処理を実行します。

モデル

正確にはGPT2のTransformerブロックを自作します。 アーキテクチャの大部分はGPTと同じですが、以下の変更(pre-norm)が行われています。

  • LayerNormはAttentionとMLPの前で適用
  • 追加のLayerNormをTransformerブロックの後で適用

Transformerブロックを除くText & Position埋め込みとNext Token生成は、 picoGPTのコードを利用します(解説ブログは GPT in 60 Lines of NumPy | Jay Mody)。

また、以下で紹介するコードはTensorflowを用いて実装しています(picoGPTの諸々のコードがTensorflowを利用していて、そのまま使いたかったため)。

Transformerブロック

GPT2の論文ではアーキテクチャの図がないので、下記はGPTのアーキテクチャ図ですが、上で書いたように、LayerNormはAttentionとMLPの前で適用します。

これをTransformerDecoderBlockクラスとして用意します(推論だけ行うのでDropoutは不要ですが)。

class TransformerDecoderBlock(tf.keras.Model):
    def __init__(self, h_dim, n_heads, drop_p):
        super().__init__()

        self.attn = MaskedMultiSelfAttention(h_dim, n_heads, drop_p)
        self.mlp = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(units=4 * h_dim, activation="gelu"),
                tf.keras.layers.Dense(units=h_dim),
                tf.keras.layers.Dropout(rate=drop_p),
            ]
        )
        self.ln1 = tf.keras.layers.LayerNormalization()
        self.ln2 = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = self.attn(self.ln1(x)) + x
        x = self.mlp(self.ln2(x)) + x
        return x

Masked Multi Self Attention

Attentionの計算は複雑なところはなく、以前に作った ゼロから作るVision Transformer (JAX/Flax) - stMindとも同じです。

  • 入力トークン列から、クエリq、キーk、バリューvを作成
  • 複数ヘッド毎のアテンション行列とアテンションの計算
  • ヘッド毎のアテンションを集約

ただし、このままだと現在のトークンが未来のトークンも参照してしまうことになるので、アテンション行列において現在のトークンと未来のトークンの関係性はなし(0)にする必要があります。未来のトークンは、アテンション行列の各行で列方向に並ぶので、下三角行列を作成してアテンション行列をマスク。この時、マスクした後でソフトマックスを計算すると正規化されなくなるので、未来のトークンは非常に小さい値にしておいて、そのあとでソフトマックスを適用します。

class MaskedMultiSelfAttention(tf.keras.layers.Layer):
    def __init__(self, h_dim, n_heads, drop_p):
        super(MaskedMultiSelfAttention, self).__init__()
        self.n_heads = n_heads

        self.c_attn = tf.keras.layers.Dense(3 * h_dim)

        self.c_proj = tf.keras.layers.Dense(h_dim)

        self.attn_drop = tf.keras.layers.Dropout(drop_p)
        self.proj_drop = tf.keras.layers.Dropout(drop_p)

    def call(self, x):
        B, T, C = x.shape
        N, D = self.n_heads, C // self.n_heads

        # Create lower triangle mask
        mask = tf.linalg.band_part(tf.ones((T, T)), -1, 0)
        mask = tf.reshape(mask, (1, 1, T, T))

        qkv = self.c_attn(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        q = tf.reshape(q, (B, T, N, D))
        k = tf.reshape(k, (B, T, N, D))
        v = tf.reshape(v, (B, T, N, D))

        q = tf.transpose(q, perm=[0, 2, 1, 3])
        k = tf.transpose(k, perm=[0, 2, 1, 3])
        v = tf.transpose(v, perm=[0, 2, 1, 3])

        weights = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(
            tf.cast(D, dtype=tf.float32)
        )

        # Apply mask
        weights += (1 - mask) * -1e9

        normalized_weights = tf.nn.softmax(weights, axis=-1)
        attention = self.attn_drop(tf.matmul(normalized_weights, v))
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])
        attention = tf.reshape(attention, (B, T, C))

        out = self.proj_drop(self.c_proj(attention))
        return out

GPT2モデル全体

Transformerブロックができたので、Text & Position埋め込みと追加LayerNormを含めたGPT2全体を作ります。 GPT2全体コードは長くなるので、callメソッドだけ抜き出すと下のようになります。 埋め込みベクトルは、次で説明するパラメータを使って、input_idsに対する埋め込みを取り出して生成します。

def call(self, input_ids):
    # Text and Position Embedding
    input_ids = tf.cast(input_ids, tf.int32)
    x = tf.gather(self.wte, input_ids) + tf.gather(
        self.wpe, range(input_ids.shape[1])
    )
    # Transformer Block (Decoder only)
    for block in self.blocks:
        x = block(x)
    # Additional LayerNorm
    x = self.layer_norm(x)
    # Linear
    return tf.matmul(x, self.params["wte"].T)

トークン生成テスト

OpenAIの学習済みパラメータを使用して、トークンを生成してみます。

学習済みパラメータ

パラメータには、入力と位置埋め込み、Transformerの各ブロックのパラメータがあり、picoGPTで辞書形式に変換されているものを使用します。

  • blocks : Transformerブロックのパラメータ
  • ln_f : 追加のLayerNormのパラメータ
  • wpe : 位置埋め込みベクトル
  • wte : トークンの埋め込みベクトル

また、blocksは124Mのモデルの場合は12個の要素があり、それぞれが以下の項目を含んでいます。(768は次元数)

  • attn : アテンションブロック(以下のbはバイアス項、wは重み)
    • c_attn : {"b": [2304], "w": [768, 2304]}
    • c_proj : {"b": [768], "w": [768, 768]}
  • ln1 : {"b": [768], "g": [768]}, Attentionの前に適用するLayerNorm。bはbeta、gはgamma
  • ln2 : {"b": [768], "g": [768]}, MLPの前に適用するLayerNorm
  • mlp : MLPブロック
    • c_fc : {"b": [3072], "w": [768, 3072]}
    • c_proj : {"b": [768], "w": [3072, 768]}

Tensorflowにおけるパラメータの割り当て

tf.keras.layers.Layerのset_weigthsを使います。この関数は、numpy の配列からパラメータ値を設定します。 例えば、c_attnの場合だと、これはDense層なのでwとbの順番でset_weightsに指定します。

block.layers[0].c_attn.set_weights(
    [
        self.params["blocks"][layer_idx]["attn"]["c_attn"]["w"],
        self.params["blocks"][layer_idx]["attn"]["c_attn"]["b"],
    ]
)

ln_fはLayerNormなので、gammaとbetaでset_weightsに指定します。

self.layer_norm.set_weights(
    [self.params["ln_f"]["g"], self.params["ln_f"]["b"]]
)

生成結果

GPT2モデルの作成、重みパラメータの設定が出来たので、 picoGPTと同じプロンプトで実験してみます。

$ python tf/gpt_tf.py --prompt "Alan Turing theorized that computers would one day become" --n_tokens_to_generate 8
...
Input text:
 Alan Turing theorized that computers would one day become
Generated:
  the most powerful machines on the planet.

同じ結果が生成されました。生成には、M1 Macで2秒弱くらいかかりました。

別のプロンプトも試してみます。

python tf/gpt_tf.py --prompt "Imagination is more important" --n_tokens_to_generate 6
...
Input text:
 Imagination is more important
Generated:
  than any other skill.

文章としては問題ないものが生成されたように思います。

まとめ

以前のViTと今回のGPTでAttentionの自作をしましたが、処理自体はそれほど複雑ではないので、 Transformerブロックを実装するのは、慣れれば難しくはないように感じました。また、モデルを実装して学習するのはHW制約などもあって大変なことが多いですが、推論であれば公開されているパラメータを使うことで、比較的試してみやすいのではと思います。コアとなるTransformer、Attentionを自作することで、Transformer系の論文の数式やコードの読解力が上がったように感じるので、興味のある方は自作にトライしてみることをオススメします。

参考文献とコード

GPT in 60 Lines of NumPy | Jay Mody

GitHub - satojkovic/gpt-tf-pytorch-jax: GPT from scratch (tensorflow / pytorch / jax)

Pose Estimationとローパスフィルタ(One Euro Filter)の実験

フレーム単位の姿勢推定を動画に適用した場合に発生するジッターを解決するローパスフィルタを実験。

使用したのは、ここで紹介されていたコード。

towardsdatascience.com

アクション認識のデータセットUCF-101の動画を利用。実行したのはM1 Mac

フィルタなし

フィルタあり

処理時間としてはほとんど変化ないが、滑らかな追跡となっている。(代わりに追跡の遅れが気になる)

Lambda Labs GPU CloudでJAX/Flax

MacのMetalを使って、手持ちのM1 MacにもJAX/Flaxの実行環境を作ることは出来るのですが、 実際に学習をしようとしてもエラーで詰まってしまうことが多く、JAX/Flaxを実行できる環境を探していました。

Colabを使っても良いのですが、学習を実行するだけでなくて、JAXのビルド自体も試してみたいと思ったので、Lambda Labs GPU Cloudで実行してみることにしました。

実行したのは、FlaxのチュートリアルにあるMNISTの画像分類モデル(CNN)の学習です。

CPU実行するだけであれば何もする必要はなかったのですが、GPUを使う場合には少しだけ苦労しました。

GPU実行時のエラーとTF_FORCE_GPU_ALLOW_GROWTH

最初にGPUで実行したとき、次のようなエラーが出ました。

...
2023-09-18 05:06:37.640924: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
...
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

今回のチュートリアルでは、tensorflow_datasetsを使ってMNISTのデータセットをロードしていたけれど、tensorflowは何も設定しないと、初期化時にGPUメモリの大部分を割り当ててしまうので、TF_FORCE_GPU_ALLOW_GROWTH=trueとしておかないとエラーになる可能性がある様子。

stackoverflow.com

github.com

github.com

最後のIssueでは、PyTorchと一緒に使った場合にCUDNNのエラーになったようで、こちらはXLA_PYTHON_CLIENT_MEM_FRACTION=.88としてJAXのGPU割り当てを制限する方法で解決していた。

MNISTの学習実行

export TF_FORCE_GPU_ALLOW_GROWTH=trueとして実行すると、Lambda Labs GPU CloudでJAX/Flaxで書いたCNNモデルが学習できました。

...
2023-09-18 05:25:48.968484: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
...
train epoch: 10, loss: 0.007537598721683025, accuracy: 99.84666442871094
test epoch: 10, loss: 0.032926980406045914, accuracy: 99.0184326171875

論文紹介:Video-LLaMA

MetaがLLaMALLaMA2と公開したことで、どんどん加速しているように見えるLLMに関する研究ですが、LLMに関するOpen Challengesの一つとも考えられている様々なモダリティの取り込みに関する研究で、Video-LLaMAを読んでみました。

arxiv.org

github.com

Video-LLaMAで出来ること

GitHubのREADMEにある図を見ると分かりやすいです。

例えば、下の図では、Video-LLaMAに「何が聞こえるかを述べて」と尋ねると、「足音と家の中で犬が吠えている」と回答してます。また、「サングラスはかけているか?」に対してもYesと回答していて、動画の視覚的な内容と聴覚的な内容に対して、正しく回答を生成できています。

https://user-images.githubusercontent.com/18526640/244575825-7f7bddb2-5cf1-4cf4-bce3-3fa67974cbb3.gif

他にも、動画の時系列な認識が出来ている例として「船はどちらの方向に動いているのか」に対して、正しい方向を回答できています。

https://user-images.githubusercontent.com/18526640/244579143-7304ad6f-1009-46f1-aca4-7f861b636363.gif

Video-LLaMAのアーキテクチャ

Video-LLaMAでは、動画における視覚と聴覚の内容を認識できるようにするために、動画における時間変化を捉え、視覚と聴覚のデータを扱う二つのブランチを持つアーキテクチャを提案しています。

Vision Language Branchは、4つの要素で構成。動画フレームを入力として、フレーム毎の特徴を抽出するためのVisual Encoder、位置情報を追加するPosition Embedding Layer、フレーム毎の特徴を統合するVideo Q-former、LLMのテキスト埋め込みと同じ次元のベクトルを出力するためのLinear Layerです。 Visual Encoderは事前学習済みのViT-G/14とQ-formerを利用、Video Q-formerはBLIP2のQ-formerを実装している。

Audio Language Branchも同じように、4つの要素で構成。動画を短いクリップ(2秒間)にして、クリップ毎の音声データから特徴を抽出するAudio Encoder、位置情報を追加するPosition Embedding Layer、クリップ毎の特徴を統合するAudio Q-former、出力を生成するLinear Layerです。 Audio Encoderは事前学習済みのImageBind、Audio Q-formerもVideo Q-formerと同様にBLIP2のQ-formerを実装している。

Multi Branch Cross Modal Training

学習は2ステージで行い、最初のステージでは大規模なキャプションデータセットを使って学習、次のステージではinstruction following データセットでFine-tuningを行う。 二つのブランチは別々に学習を行うが、Audioとテキストが含まれるデータは希少なので、Audio BranchでもVision Branchと同じ動画キャプション(WebVid2M)、画像キャプション(CC595K)データセットを用いて学習する。Audio Branchは、Audio EncoderとしてImageBindをつかったことで、音声データで学習していないにも関わらず、推論時には音声を理解するように学習されたと述べている。 二つ目のステージでは、MiniGPT-4のimage-detail-description dataset、LLaVAのimage-instruction dataset、Video-Chatのvideo-instruction datasetを統合して使用する。

まとめ

以上が、Video-LLaMAのざっくりの内容です。Vision Language BranchとAudio Language Branchを用意して、動画フレーム/音声データに基づいた回答を行えるようにしているのはシンプルな仕組みだと思います。Audio Language Branchと言いながら、学習には音声データが使われていないのは不思議な感じですが。

JAXとcomposable program transformations

https://github.com/google/jaxのAboutは、次のように記述されています。

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Composable transformationsはどういうことなのか? NeurIPS2020: JAX Ecosystem Meetupの動画で、DeepMindのエンジニアの方が解説されていました。

NeurIPS 2020: JAX Ecosystem Meetup - YouTube

例として、次の関数を考えます。

def fn(x, y):
  return x**2 + y

fn(1., 2.) # (1**2 + 2) = 3

これに対して、gradientはどう書けるか?

df_dx = grad(fn)
df_dx(1., 2.) # df_dx = 2*x = 2*1 = 2

ここで、gradは関数を返す関数で、df_dxも関数になる。そして、通常の関数呼び出しで使用することができる。

さらに、second order gradientはどう書けるか?

df2_dx = grad(grad(fn))
df2_dx(1., 2.) # df2_dx = d(2*x)_dx = 2

gradはcomposableなため、gradをもう一つ追加するだけで良い。

composableなのはgradだけでなく、他の変換も使用することができる。compiled second-order gradientは以下のように実行できる。

df2_dx = jit(grad(grad(fn)))
df2_dx(1., 2.) # 2, ここでcompileされる
df2_dx(1., 2.) # 2, XLA pre compileのコードを実行、一回目よりも早い実行ができる

さらに、バッチ計算もcomposableに付け加えることができる。(batched compiled second-order gradient)

df2_dx = vmap(jit(grad(grad(fn))))
xs = jnp.ones((batch_size,))
df2_dx(xs, 2 * xs) # [2, 2] if batch_size=2

複数のアクセラレータ(GPUなど)で実行する場合も、composableに付け加えることができる。(multi-gpu batched compiled second-order gradient)

df2_dx = pmap(vmap(jit(grad(grad(fn)))))
xs = jnp.ones((num_gpus, batch_size,))
df2_dx(xs, 2 * xs) # [[2, 2], [2, 2]] if batch_size=2 and num_gpus=2

まとめ

以上が約5分のプレゼンで解説されていた内容ですが、分かりやすくて、 変換の組み合わせってそういうことか!と感動しました。 また動画には、HaikuやOptaxといったEcosystemの話や、他にもGANsなど様々なJAX実装の例があり、勉強になりました。 前回、JAX/FlaxでViTを実装してみましたが、今年はJAXをもっと使っていこうと思います。