R語言 Keras Training Flags

在須要常常進行調參的狀況下,能夠使用 Training Flags 來快速變換參數,比起直接修改模型參數來得快並且不易出錯。html

https://tensorflow.rstudio.com/tools/training_flags.html函數

使用 flags()

library(keras)

FLAGS <- flags(
  flag_integer("dense_units1", 128),
  flag_numeric("dropout1", 0.4),
  flag_integer("dense_units2", 128),
  flag_numeric("dropout2", 0.3),
  flag_integer("epochs", 30),
  flag_integer("batch_size", 128),
  flag_numeric("learning_rate", 0.001)
)
input <- layer_input(shape = c(784))
predictions <- input %>% 
  layer_dense(units = FLAGS$dense_units1, activation = 'relu') %>%
  layer_dropout(rate = FLAGS$dropout1) %>%
  layer_dense(units = FLAGS$dense_units2, activation = 'relu') %>%
  layer_dropout(rate = FLAGS$dropout2) %>%
  layer_dense(units = 10, activation = 'softmax')

model <- keras_model(input, predictions) %>% compile(
  loss = 'categorical_crossentropy',
  optimizer = optimizer_rmsprop(lr = FLAGS$learning_rate),
  metrics = c('accuracy')
)

history <- model %>% fit(
  x_train, y_train,
  batch_size = FLAGS$batch_size,
  epochs = FLAGS$epochs,
  verbose = 1,
  validation_split = 0.2
)

flags()是 keras 庫的函數,不是R語言自己的函數。code

使用YAML文件

flags()能夠搭配YAML文件使用。按照官方教程,覺得是把參數定義在YAML文件裏,而後使用flags(file="flags.yml")直接讀入。可是發現這樣行不通,flags(file="flags.yml")獲得的是一個空list。後來發現可能得這樣使用纔是正確的:htm

FLAGS <- flags(file = "flags.yml",
  flag_integer("dense_units1", 128,  "Dense units in first layer"),
  flag_numeric("dropout1",     0.4,  "Dropout after first layer"),
  flag_integer("epochs",        30,  "Number of epochs to train for")
)

flags.yml 中的參數優先,會覆蓋掉flags()裏的定義,也就是說,若是 flags.yml 裏面是這樣定義的:教程

dense_units1: 256
dropout1: 0.4
epochs: 30

那麼,dense_units1這個參數的值是 256,而不是 128。get

下面這種用法不正確,input

FLAGS <- flags(file = "flags.yml",
)

會獲得一個空list。能夠認爲,flags.yml實際上是用來覆蓋或者說修改flags()裏面已有的參數定義。it

相關文章
相關標籤/搜索