作者热门文章
- html - 出于某种原因,IE8 对我的 Sass 文件中继承的 html5 CSS 不友好?
- JMeter 在响应断言中使用 span 标签的问题
- html - 在 :hover and :active? 上具有不同效果的 CSS 动画
- html - 相对于居中的 html 内容固定的 CSS 重复背景?
我有以下代码。数据集可以从 here 或 here 下载。数据集包含分类为 cat 或 dog 的图像。
这段代码的任务是用猫狗图像数据训练一个模型,使其在给定一张图片时能判断图中是猫还是狗。
代码的思路来自这个 page。下面是成功运行的代码:
library(keras)
library(tidyverse)
# Organize dataset --------------------------------------------------------
# Input: directory holding the original Kaggle images, named
# "cat.<id>.jpg" / "dog.<id>.jpg".
original_dataset_dir <- "data/kaggle_cats_dogs/original/"

# Create new organized dataset directory ----------------------------------
# Layout: <base_dir>/{train,validation,test}/{cats,dogs} plus <base_dir>/model/.
base_dir <- "data/kaggle_cats_dogs_small/"
model_dir <- paste0(base_dir, "model/")
train_dir <- file.path(base_dir, "train")
validation_dir <- file.path(base_dir, "validation")
test_dir <- file.path(base_dir, "test")
train_cats_dir <- file.path(train_dir, "cats")
train_dogs_dir <- file.path(train_dir, "dogs")
validation_cats_dir <- file.path(validation_dir, "cats")
validation_dogs_dir <- file.path(validation_dir, "dogs")
test_cats_dir <- file.path(test_dir, "cats")
test_dogs_dir <- file.path(test_dir, "dogs")

# showWarnings = FALSE (instead of a global options(warn = -1)) keeps re-runs
# quiet when the directories already exist without hiding unrelated warnings;
# recursive = TRUE also creates missing parent directories such as "data/".
for (dir_path in c(
  base_dir, model_dir,
  train_dir, validation_dir, test_dir,
  train_cats_dir, train_dogs_dir,
  validation_cats_dir, validation_dogs_dir,
  test_cats_dir, test_dogs_dir
)) {
  dir.create(dir_path, showWarnings = FALSE, recursive = TRUE)
}

# Copy "<animal>.<id>.jpg" for every id from `src_dir` into `dest_dir`.
# Missing source images produce a warning instead of being ignored silently.
# Returns (invisibly) the logical vector from file.copy() for the files found;
# files already present in `dest_dir` are not overwritten.
copy_images <- function(animal, ids, dest_dir, src_dir = original_dataset_dir) {
  src <- file.path(src_dir, paste0(animal, ".", ids, ".jpg"))
  missing <- !file.exists(src)
  if (any(missing)) {
    warning(
      sum(missing), " source image(s) not found for '", animal,
      "' (e.g. ", src[missing][1], ")",
      call. = FALSE
    )
  }
  invisible(file.copy(src[!missing], dest_dir))
}

# Copying files from original dataset to newly created directory:
# ids 1-1000 -> train, 1001-1500 -> validation, 1501-2000 -> test.
copy_images("cat", 1:1000, train_cats_dir)
copy_images("cat", 1001:1500, validation_cats_dir)
copy_images("cat", 1501:2000, test_cats_dir)
copy_images("dog", 1:1000, train_dogs_dir)
copy_images("dog", 1001:1500, validation_dogs_dir)
copy_images("dog", 1501:2000, test_dogs_dir)
# Making model ------------------------------------------------------------
# Convolutional base: VGG16 with ImageNet weights, without its own dense
# classifier (include_top = FALSE), for 150 x 150 RGB inputs.
conv_base <- application_vgg16(
weights = "imagenet",
include_top = FALSE,
input_shape = c(150, 150, 3)
)
# Stack the VGG16 base and a new binary-classification head:
# flatten -> 256-unit ReLU dense layer -> single sigmoid unit.
model <- keras_model_sequential() %>%
conv_base() %>%
layer_flatten() %>%
layer_dense(units = 256, activation = "relu") %>%
layer_dense(units = 1, activation = "sigmoid")
summary(model)
# Print the number of trainable weight tensors before and after
# freeze_weights(conv_base), to check that freezing the base took effect.
length(model$trainable_weights)
freeze_weights(conv_base)
length(model$trainable_weights)
# Train model -------------------------------------------------------------
# Training images are rescaled to [0, 1] and randomly augmented (rotation,
# shifts, shear, zoom, horizontal flip).
train_datagen <- image_data_generator(
  rescale = 1 / 255,
  rotation_range = 40,
  width_shift_range = 0.2,
  height_shift_range = 0.2,
  shear_range = 0.2,
  zoom_range = 0.2,
  horizontal_flip = TRUE,
  fill_mode = "nearest"
)
# Validation and test data must NOT be augmented: they are only rescaled.
test_datagen <- image_data_generator(rescale = 1 / 255)

# Stream 150 x 150 images from `directory` (one sub-directory per class) in
# batches of 20 with binary labels, matching the binary_crossentropy loss.
flow_binary_images <- function(directory, generator, shuffle = TRUE) {
  flow_images_from_directory(
    directory,
    generator,
    target_size = c(150, 150), # Resizes all images to 150 x 150
    batch_size = 20,
    class_mode = "binary",
    shuffle = shuffle
  )
}

train_generator <- flow_binary_images(train_dir, train_datagen)
# Test images use the non-augmenting test_datagen and are not shuffled, so
# evaluation is deterministic and predictions come back in the same order as
# test_generator$filenames.
test_generator <- flow_binary_images(test_dir, test_datagen, shuffle = FALSE)
validation_generator <- flow_binary_images(validation_dir, test_datagen)
# Fine tuning -------------------------------------------------------------
# Make the VGG16 layers from "block3_conv1" onward trainable again; the
# earlier blocks keep the frozen ImageNet weights.
unfreeze_weights(conv_base, from = "block3_conv1")

# Compile model -----------------------------------------------------------
# Binary cross-entropy pairs with the single sigmoid output unit; RMSprop
# uses a very small learning rate (2e-5) for fine-tuning.
compile(
  model,
  optimizer = optimizer_rmsprop(lr = 2e-5),
  loss = "binary_crossentropy",
  metrics = "accuracy"
)
# Evaluate by epochs ---------------------------------------------------------
# Fit for 50 epochs (slow); `history` keeps per-epoch training and validation
# loss/accuracy for plotting. Step counts are derived from the generators so
# every epoch covers each training and validation image once (100 and 50
# steps for 2000 / 1000 images at batch size 20).
history <- model %>% fit_generator(
  train_generator,
  steps_per_epoch = ceiling(train_generator$n / train_generator$batch_size),
  epochs = 50,
  validation_data = validation_generator,
  validation_steps = ceiling(
    validation_generator$n / validation_generator$batch_size
  )
)
# Plot --------------------------------------------------------------------
# plot(history)

# Check classes of data --------------------------------------------------
# Shows which class (sub-directory) name maps to which numeric label.
train_generator$class_indices

# Evaluate ----------------------------------------------------------------
# One full pass over the test set (50 steps for 1000 images at batch size 20).
model %>% evaluate_generator(
  test_generator,
  steps = ceiling(test_generator$n / test_generator$batch_size)
)
# Example output from the original run:
#$loss
#[1] 0.3161949
#$acc
#[1] 0.932
# Predict -------------------------------------------------------------------
# evaluate_generator() has already drawn batches from test_generator, so rewind
# it first; predictions then start from the first test file. `steps` covers
# every test image exactly once, giving one probability per file.
test_generator$reset()
predict <- model %>%
  predict_generator(
    test_generator,
    steps = ceiling(test_generator$n / test_generator$batch_size),
    verbose = 1
  )
`predict` 的输出是:
# Preview the first predictions. `predict` is a one-column numeric matrix;
# build the tibble from an explicitly named numeric column (as.tibble() is
# deprecated and would have to invent a column name for the unnamed matrix).
tibble(predict_proba = as.numeric(predict)) %>%
  mutate(
    label = as.integer(predict_proba > 0.5),
    label_name = if_else(label == 0L, "cat", "dog")
  ) %>%
  head(n = 5)
# A tibble: 5 x 3
# predict_proba label label_name
# <dbl> <int> <chr>
#1 1.000000e+00 1 dog
#2 4.278725e-02 0 cat
#3 4.198529e-15 0 cat
#4 8.683033e-06 0 cat
#5 1.000000e+00 1 dog
对于 `predict` 中给出的概率,我想知道它们分别对应 `test_dir` 中的哪个文件。
`base_dir` 的目录结构如下:
.
|-- model
|-- test
| |-- cats
| `-- dogs
|-- train
| |-- cats
| `-- dogs
`-- validation
|-- cats
`-- dogs
# Map each prediction back to its test file.
# Build the tibble from typed columns: cbind()-ing the numeric `predict`
# matrix with the character filenames would coerce the probabilities to
# character, turning `predict_proba > 0.5` into a string comparison
# (e.g. "2.4e-09" > "0.5" is TRUE). tibble() also errors if the number of
# predictions and filenames differ instead of silently recycling.
# NOTE: row i matches test_generator$filenames[i] only if the generator was
# created with shuffle = FALSE and predictions started at its first batch.
stat_df <- tibble(
  predict_proba = as.numeric(predict),
  filename = test_generator$filenames
) %>%
  mutate(
    predicted_label = as.integer(predict_proba > 0.5),
    predicted_label_name = if_else(predicted_label == 0L, "cats", "dogs")
  ) %>%
  # The sub-directory name is the true label; split on "/" or "\" so Windows
  # paths work too.
  separate(filename, into = c("true_label", "fname"), sep = "[/\\\\]")
stat_df
> stat_df
# A tibble: 1,000 x 5
predict_proba true_label fname predicted_label predicted_label_name
* <chr> <chr> <chr> <int> <chr>
1 2.45413622756985e-09 cats cat.1501.jpg 1 dogs
2 4.18112916275648e-20 cats cat.1502.jpg 1 dogs
3 1.25922511529097e-07 cats cat.1503.jpg 1 dogs
4 3.76460201987477e-14 cats cat.1504.jpg 1 dogs
5 6.77461059694906e-07 cats cat.1505.jpg 1 dogs
6 0.000256105791777372 cats cat.1506.jpg 0 cats
7 0.959224164485931 cats cat.1507.jpg 1 dogs
8 0.000318235805025324 cats cat.1508.jpg 0 cats
9 9.03555774129927e-05 cats cat.1509.jpg 1 dogs
10 2.40483113884693e-05 cats cat.1510.jpg 1 dogs
> stat_df %>% group_by(predicted_label_name) %>% summarise(n=n())
# A tibble: 2 x 2
predicted_label_name n
<chr> <int>
1 cats 191
2 dogs 809
> stat_df %>% filter(true_label == predicted_label_name & true_label == "dogs") %>% dim()
[1] 439 5
> stat_df %>% filter(true_label == predicted_label_name & true_label == "cats") %>% dim()
[1] 130 5
`evaluate_generator()` 给出的准确率约为 93%。
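最佳答案
`test_generator$filenames` 会给出文件名列表。
注意:只有当测试生成器以 `shuffle = FALSE` 创建,且预测从第一个批次开始(例如先调用 `test_generator$reset()`)时,`predict` 的第 i 行才对应 `filenames` 的第 i 个文件;另外,用 `cbind()` 把数值预测和文件名拼在一起会把概率变成字符串,使 `predict_proba > 0.5` 变成字符串比较,应保持概率为数值类型(`as.numeric()`)。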
关于r - 如何从 R keras 中的 predict_generator() 输出中检查相应的文件,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/48465085/
我是一名优秀的程序员,十分优秀!