混合精度トレーニング手法の進展
私たちの新しい方法は、精度を落とさずに機械学習のメモリ使用量を減らすよ。
― 1 分で読む
目次
機械学習の世界では、モデルがどんどん大きくて複雑になってる。こういう成長は、画像認識や言語処理などのタスクの正確性を向上させることが多いけど、大きな問題も出てくる。大きなモデルはトレーニングにたくさんのメモリと計算力を必要とするんだ。従来のトレーニング方法は、単精度浮動小数点演算という数学の一種に頼ってるから、高いメモリ使用量になってトレーニングが遅くなることがあるんだ。
混合精度トレーニング
この問題を解決するために、混合精度トレーニングという方法が導入された。このテクニックは、半精度と単精度の2種類の数学フォーマットを使う。ほとんどの計算に半精度を使って、必要な時だけ単精度に切り替えることで、正確性を犠牲にせずにメモリの必要量を減らせるんだ。
現在のアプローチ
混合精度トレーニングでは、各モデルパラメータの2つのコピーを保持するのが一般的。1つは半精度で、もう1つは単精度。ただ、この方法はメモリが倍になっちゃうのが残念だ。そこで、私たちはこの問題に対処する新しいアプローチを開発した。
両方のコピーを保持する代わりに、私たちの方法は半精度バージョンだけを維持して、必要な時に半精度と単精度の値の違いを計算する。これによって、メモリ使用量が大幅に削減できるんだ。また、トレーニング中に直接アップデートを適用することで、勾配値を保持する必要もなくなる。
メモリの課題
機械学習モデルが大きくなるにつれて、メモリの要求も増えていく。最近、LLaMAという大規模な言語モデルは500GB以上のメモリを必要とするけど、これは一般的なGPUが提供できるものよりかなり多い。これを解決するために、研究者たちはオプティマイザーをもっとメモリ効率的にすることに注力してる。
オプティマイザーのメモリ使用量
現代のトレーニング方法は、オプティマイザーが使うメモリに依存することが多い。従来のオプティマイザーは、各パラメータに対して最大8バイトのメモリを消費することがあるけど、Adafactorや8bit-Adamのような新しいオプティマイザーは、この必要量を4バイトや2バイトに減らすことに成功している。
私たちの解決策
私たちの研究は、モデルパラメータが消費するメモリを最小限に抑えることに焦点を当てている。完全な単精度値と半精度バージョンの違いだけを保持することで、メモリ使用量を削減できる。これにより、各パラメータは少なくとも6バイト少ないメモリを占有できるようになる。
さらに、私たちの方法はトレーニングプロセスを変更する必要がない。自然言語処理や画像分類など、たくさんのメモリを必要とする複雑なモデルでうまく機能するんだ。
浮動小数点の基本
機械学習におけるメモリ使用量を理解するには、浮動小数点数がメモリにどのように表現されるかを見てみる必要がある。浮動小数点数は、符号ビット(数が正か負かを示す)、指数(サイズを示す)、そして小数値を提供する分数の3つの主な部分から成る。この数の表現方法は広範囲の値を可能にするけど、メモリに関しては困難もある。
関連研究
混合精度トレーニングの標準フレームワークでは、各パラメータの完全精度コピーを保持することが一般的だった。これにより、特に重量の更新が小さいときにトレーニング中の安定性が確保される。中には、16ビットなどの小さいフォーマットを使って正確性を維持しながらメモリニーズを減らそうとする試みもあった。
量子化は、データを小さなチャンクに分けてメモリ使用量を減らすための一般的なテクニックだった。しかし、私たちの焦点はトレーニングフェーズ中の改善にあり、最適化プロセスが重要なんだ。
2つの新しい方法
私たちは、混合精度トレーニング中のメモリニーズを最小限に抑えるために、2つの主要な方法を開発した。1つ目は、パラメータの完全精度コピーを排除して、かなり少ないメモリでモデルをトレーニングできるようにする方法。2つ目は、勾配値を計算したと同時に削除することで、必要なメモリを削減する方法。
演算の融合
従来のトレーニングでは、勾配を計算してからモデルパラメータを別々に更新する必要があった。これにはすべての勾配をメモリに保存する必要があり、すぐに足りなくなる。私たちの新しいアプローチでは、これら2つのステップを1つにまとめて、勾配を計算した直後にパラメータを更新できるようにしてる。
このプロセスは、後で使うために勾配を保持する必要がなくなるので、メモリの要求が減る。私たちの方法は従来のトレーニングフレームワークともうまくフィットして、特に大きな変更なしに様々なオプティマイザーを効率的に扱える。
テスト結果
私たちのアプローチの効果を確認するために、いくつかのモデルで実験を行った。画像分類、テキスト変換、画像生成タスクが含まれている。その結果は期待を超えるものだった。
一般的な言語ベンチマークでモデルをファインチューニングした際、私たちのアプローチはメモリ使用量を11%削減し、計算にかかる時間はわずかに増加するだけだった。標準的な混合精度トレーニングと比較して、他のテストではピークメモリが54%も削減できた。
私たちのアプローチは、伝統的な方法と同様の精度を維持していた。多くのケースで、完全精度トレーニングと同じ結果をもたらしていて、大きくて複雑なモデルにとって実用的な選択肢になり得るんだ。
追加フォーマット
私たちの方法は16ビットフォーマットとうまく機能する。ただ、さらに小さいフォーマット、例えば8ビットにも可能性がある。今はあまり一般的ではないけど、トレーニングプロセスに統合できれば、追加の利点が得られるかもしれない。
効率的な向上
私たちの実験を通じて、方法の効率性は使う追加ビット数によって異なることがわかった。2ビット以上の追加ビットを使うことで、速度とメモリ使用量が顕著に改善できた。
融合オプティマイザーとの組み合わせでも、標準的なオプティマイザーと比較して全体的なトレーニング時間が少しだけ改善された。
制限
私たちの方法にはいくつかの制限がある。特に勾配蓄積に関して。勾配蓄積は通常、大きなバッチを小さなものに分けることでメモリニーズを下げる。けど、私たちの方法では勾配を計算後すぐに取り除くため、従来の勾配蓄積を直接行うことはできない。
もう一つの制約は、勾配クリッピングに関連するもので、勾配を調整して高い値を避けるために使われるテクニックなんだけど、融合演算ではこれが問題になる。だけど、パラメータフックのような代替手段で管理することができる。
最後に、いくつかのトレーニングセットアップでは「クロージャ」関数を利用していて、これが私たちの方法と互換性がないため、最適化プロセス中に予期しない問題が発生する可能性がある。
結論
混合精度トレーニングは、正確性を維持しつつ機械学習プロセスをスピードアップする素晴らしい方法を提供する。私たちの新しい方法は、パフォーマンスを損なうことなくメモリ使用量をさらに削減できることを示している。さらに小さいフォーマットを探求する可能性があることで、私たちのアプローチは効率的なモデルトレーニングの未来の進展につながるかもしれない。
徹底的なテストを通じて、私たちの方法はさまざまなモデルやタスクで効果的に機能することが示され、機械学習における増大するメモリニーズへの解決策を提供できた。今後もこれらの技術をさらに洗練させたり拡張する機会がたくさんある。
タイトル: Memory Efficient Mixed-Precision Optimizers
概要: Traditional optimization methods rely on the use of single-precision floating point arithmetic, which can be costly in terms of memory size and computing power. However, mixed precision optimization techniques leverage the use of both single and half-precision floating point arithmetic to reduce memory requirements while maintaining model accuracy. We provide here an algorithm to further reduce memory usage during the training of a model by getting rid of the floating point copy of the parameters, virtually keeping only half-precision numbers. We also explore the benefits of getting rid of the gradient's value by executing the optimizer step during the back-propagation. In practice, we achieve up to 25% lower peak memory use and 15% faster training while maintaining the same level of accuracy.
著者: Basile Lewandowski, Atli Kosson
最終更新: 2023-09-21 00:00:00
言語: English
ソースURL: https://arxiv.org/abs/2309.12381
ソースPDF: https://arxiv.org/pdf/2309.12381
ライセンス: https://creativecommons.org/licenses/by/4.0/
変更点: この要約はAIの助けを借りて作成されており、不正確な場合があります。正確な情報については、ここにリンクされている元のソース文書を参照してください。
オープンアクセスの相互運用性を利用させていただいた arxiv に感謝します。
参照リンク
- https://github.com/huggingface/pytorch-image-models
- https://github.com/huggingface/transformers/tree/main/examples/pytorch/text-classification
- https://github.com/facebookresearch/dlrm
- https://ailab.criteo.com/ressources/
- https://github.com/NVIDIA/apex/tree/master/examples/dcgan
- https://github.com/w86763777/pytorch-gan-metrics