Explore multi-class with Rock Paper Scissors dataset (CNN/Week4) 3hrs
主要利用剪刀石頭布的data去訓練多分類的題目
關鍵在下面這段, 把binary 變成categorical, 還有loss也改成categorical_crossentropy, 最後一層改成softmax
train_datagen = ImageDataGenerator(rescale=1/255)
#'binary' 變成 'categorical'
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(300, 300),
batch_size=128,
class_mode='categorical')
#把output layer變成softmax
tf.keras.layers.Dense(3, activation='softmax')
#loss function改成'categorical_crossentropy'
model.compile(loss='categorical_crossentroy',
optimizer=RMSprop(lr=0.001),
metrics=['acc'])
但是如果真的以為只有這樣那就錯了, Quiz還是一樣不難, 但是exercise難度卻開始往上加
這邊把遇到的坑做一個筆記
#在get data要學習怎麼處理csv -> data
def get_data(filename):
images = []
labels = []
with open(filename) as training_file:
csvreader = csv.reader(training_file)
next(csvreader)
for row in csvreader:
images.append(np.array_split(row[1:],28))
temparr = np.zeros(10, int)
labels.append(int(row[0]))
# Your code starts here
# Your code ends here
images = np.asarray(images, dtype=float)
labels = np.asarray(labels, dtype=float)
return images, labels
#課堂學到的softmax, 多分類
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(16, (3,3), activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(2, 2),
# The second convolution
tf.keras.layers.Conv2D(32, (3,3), activation='relu'),
tf.keras.layers.MaxPooling2D(2,2),
# Flatten the results to feed into a DNN
tf.keras.layers.Flatten(),
# 512 neuron hidden layer
tf.keras.layers.Dense(512, activation='relu'),
# Only 1 output neuron. It will contain a value from 0-1 where 0 for 1 class ('horses') and 1 for the other ('humans')
tf.keras.layers.Dense(25, activation='softmax')
])
#採用sparse_categorical_crossentropy就可以不用做one-hot encode
model.compile(loss='sparse_categorical_crossentropy',
optimizer='adam',
metrics=['acc'])
#最後是ImageDataGenerator因為這次不是使用directory直接餵的方式, 所以直接在fit_generator裡面把
#datagen放進來, 並使用.flow function, 帶入images, labels, batch_size就大功告成
history = model.fit_generator(
train_datagen.flow(training_images, training_labels, batch_size=8),
steps_per_epoch=2000,
epochs=2
)
這樣子總算是把第二章上課證拿到了, 做到這裡會發現, 其實API不難處理
其實難處理的是data怎麼拿出來跟怎麼fit 這些API接口
這裡建議要多練習, 或是做好準備, 否到時候會debug半天