# The `train` function takes a batch of input images and labels.
@tf.function(input_signature=[
    tf.TensorSpec([None, IMG_SIZE, IMG_SIZE], tf.float32),
    tf.TensorSpec([None, 10], tf.float32),
])
def train(self, x, y):
    """Run one optimization step on a batch and report loss and gradients.

    Args:
      x: float32 tensor of shape [batch, IMG_SIZE, IMG_SIZE] (fixed by the
        input signature); the input images fed to `self.model`.
      y: float32 tensor of shape [batch, 10] (fixed by the input signature);
        the target labels. Presumably one-hot / per-class targets -- confirm
        against the configured `self._LOSS_FN`.

    Returns:
      A dict with key "loss" holding the value returned by `self._LOSS_FN`,
      plus one entry per non-None gradient, keyed by that gradient tensor's
      `.name` (a graph tensor name produced during tracing, not the name of
      the variable it belongs to).
    """
    # Read the variable list once; the same list is used for differentiation
    # and for applying the update, so gradients and variables stay aligned.
    variables = self.model.trainable_variables

    with tf.GradientTape() as tape:
        prediction = self.model(x)
        # Keras losses are called as loss(y_true, y_pred); passing the
        # prediction first would treat logits as targets for asymmetric
        # losses such as cross-entropy. Assumes `_LOSS_FN` follows that
        # Keras convention.
        loss = self._LOSS_FN(y, prediction)

    gradients = tape.gradient(loss, variables)
    self._OPTIM.apply_gradients(zip(gradients, variables))

    result = {"loss": loss}
    for grad in gradients:
        # tape.gradient yields None for variables the loss does not depend
        # on; such entries have no `.name` and nothing useful to report.
        if grad is not None:
            result[grad.name] = grad
    return result