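# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
# (Notebook shell commands, kept as comments so the file stays valid Python.
# The version specifier is quoted so the shell does not treat `>` as a redirect.)
# !pip install -q -U keras-nlp
# !pip install -q -U "keras>=3"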
# Import packages
import os

# The backend environment variables are set before `import keras` so that
# Keras picks them up at import time; keep this ordering.
os.environ["KERAS_BACKEND"] = "jax"  # you can also use tensorflow or torch
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # avoid memory fragmentation on JAX backend.
# NOTE(review): an empty value presumably leaves platform selection to JAX — confirm.
os.environ["JAX_PLATFORMS"] = ""

import keras
import keras_nlp

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

tqdm.pandas()  # progress bar for pandas

import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown
# Config
class Config:
    """Constants for the fine-tuning run, read as class attributes."""

    seed = 42
    dataset_path = "/kaggle/input/kaggle-docs/questions_answers"
    preset = "gemma_2b_en"  # name of pretrained Gemma
    sequence_length = 512  # max size of input sequence for training
    batch_size = 1  # size of the input batch in training, x 2 as two GPUs
    epochs = 15  # number of epochs to train
def colorize_text(text):
    """Bold the section labels of a sample so they stand out in Markdown.

    Every occurrence of "Category:", "Question:" or "Answer:" that directly
    follows two newline characters is wrapped in ``**...**``. Despite the
    function's name, only bold markup is emitted; no colour markup is added.

    Args:
        text: The string to decorate.

    Returns:
        A new string with the matching labels wrapped in bold markers.
    """
    # The two-newline prefix is part of the match, so a label at the very
    # start of the text (with nothing before it) is left untouched.
    for word in ("Category", "Question", "Answer"):
        text = text.replace(f"\n\n{word}:", f"\n\n**{word}:**")
    return text
# Compile the model with loss, optimizer, and metric
gemma_causal_lm.compile(
    # from_logits=True: the loss is told to expect raw logits, not probabilities.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(learning_rate=8e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
# Train model
gemma_causal_lm.fit(
    data,
    epochs=Config.epochs,
    batch_size=Config.batch_size,
)