Keras 3.0 설명

Why?

과거 keras는 Tensorflow, Theano, MXNet등 여러 deep learning backend framework가 있을 때 multi-backend 지원의 강점을 가지며 출시되었다. 하지만 Teano와 MXNet등 여러 framework들은 쇄퇴의 길을 걸었고, tensorflow만 살아남게 되었다. 하지만 당시 tensorflow도 문제점은 가지고 있었다. 그것은 model 선언이 비직관적이라는 것이다. 따라서 tensorflow는 keras를 공식 레포에 집어넣어 keras.layer로 모델을 만들고 tensorflow backend로 학습하는 구조로 발전했다.

하지만 현재는 여러 연구에서 pytorch를 사용하고, pytorch 기반의 huggingface가 등장하면서 keras입장에서는 pytorch가 매력적인 시장으로 보였다. 그리고 tensorflow의 사용자가 줄어가고, 윈도우 네이티브 업데이트 지원을 종료하면서 다시 multi-backend의 강점을 다시 살리기로 했다.

그래서 keras는 다음과 같은 기능을 강조하면서 keras3.0을 출시했다.

주요 기능

Keras 3.0을 출시하면서 다음과 같은 중요한 기능을 제시했다.

  • The full Keras API, available for TensorFlow, JAX, and PyTorch
  • A cross-framework low-level language for deep learning
  • Seamless integration with native workflows in JAX, PyTorch, and TensorFlow
  • Support for cross-framework data pipelines with all backends
  • A new distribution API for large-scale data parallelism and model parallelism
  • Pretrained models
  • Progressive disclosure of complexity
  • A new stateless API for layers, models, metrics, and optimizers

하나씩 살펴보자

The full Keras API, available for TensorFlow, JAX, and PyTorch

keras로 선언된 모델과 함수들은 tensorflow, jax, pytorch에서 모두 사용가능하다. 즉, 3개의 프레임워크에서 모두 keras 함수를 사용할 수 있다는 것이다. 여기서 재밌는 점은 기존에 tf.keras로 선언된 모델도 jax, pytorch에서 실행 가능하다.

A cross-framework low-level language for deep learning

딥러닝 모델을 구성하다보면 matmul, stack 등 기본적인 연산자가 필요할 때 있다. 이럴때는 keras.ops를 사용하여 기본적인 연산자를 구성하면 tensorflow, jax, pytorch에서 모두 사용가능하다. 이 떄 keras는 두 가지를 중심으로 구현했다.

  • Numpy에 관련한 연산자는 모두 구현한다. ex) ops.matmul, ops.sum, ops.stack, ops.einsum
  • Neural-specific function을 구현한다. ex) ops.softmax, ops.binary_crossentropy, ops.conv

Seamless integration with native workflows in JAX, PyTorch, and TensorFlow

Integration하다보면 기존의 training loop 등 workflow를 그대로 유지해야할 경우가 있다. 물론 keras3.0은 이 경우도 지원한다.

  • Write a low-level JAX training loop to train a Keras model using an optax optimizer, jax.grad, jax.jit, jax.pmap.
  • Write a low-level TensorFlow training loop to train a Keras model using tf.GradientTape and tf.distribute.
  • Write a low-level PyTorch training loop to train a Keras model using a torch.optim optimizer, a torch loss function, and the torch.nn.parallel.DistributedDataParallel wrapper.
  • Use Keras layers in a PyTorch Module (because they are Module instances too!)
  • Use any PyTorch Module in a Keras model as if it were a Keras layer.
  • etc.

A new distribution API for large-scale data parallelism and model parallelism

keras에서는 여러 data parallelism을 제공한다. 단 두줄 만으로 분산학습이 된다는게 신기하긴 하다.

Support for cross-framework data pipelines with all backends

각 framework별로 다른 dataset 객체를 사용한다. Keras3.0은 이를 모두 지원한다.

  • tf.data.Dataset pipelines: the reference for scalable production ML.
  • torch.utils.data.DataLoader objects.
  • NumPy arrays and Pandas dataframes.
  • Keras’s own keras.utils.PyDataset objects.

Pretrained models

Keras3.0은 다음과 같은 pretrained model을 지원한다.

  • BERT
  • OPT
  • Whisper
  • T5
  • StableDiffusion
  • YOLOv8
  • SegmentAnything
  • etc.

Progressive disclosure of complexity

개발하다보면 pytorch lightening, pytorch ignite, tensorflow orbit등 disclosure를 위한 툴을 쓴 경험이 있을 것이다. Keras는 이것이 keras api의 핵심 디자인으로 삼았으며 이를 지원한다 한다. 그냥 다른거 쓸 것 같긴한데…

A new stateless API for layers, models, metrics, and optimizers.

함수형 프로그래밍을 좋아하는 사람을 위해 stateless한 함수들을 만들었다.

  • All layers and models have a stateless_call() method which mirrors call().
  • All optimizers have a stateless_apply() method which mirrors apply().
  • All metrics have a stateless_update_state() method which mirrors update_state() and a stateless_result() method which mirrors result().

Example

Tensorflow는 기존의 방법과 동일해서 설명을 생략하겠다.

MNIST with keras vgg19 (Pytorch Beckend)

이렇게 하면 기존 tensorflow나 keras vgg를 weight를 포함하여 사용할 수 있다. 여기서 주의할 점은 dataset augmentation 부분에서 CHW를 HWC로 바꿔줘야한다는 것이다.

Declare Pytorch Model Using Keras Application

재밌는 것은 keras.layer가 torch.nn.Module과 호환이되어 다음과 같이 모델을 선언할 수 있다.

맺으며

너무 많은 담기 그래서 이쯤으로 마치고, 더 많은 예제는 다음에 다루기로 하겠다.




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • What Makes Multi-modal Learning Better than Single (Provably)
  • TinyViT
  • Integral Neural Network
  • Proper Reuse of Image Classification Features Improves Object Detection
  • Meta Pseudo Labels