From 75174575b55a51afc7acd574a98103e3900c19f6 Mon Sep 17 00:00:00 2001
From: Silas Dohm <silas@sdohm.xyz>
Date: Mon, 2 Aug 2021 23:48:14 +0200
Subject: [PATCH] decent cnn, not great though

---
 python/w2v_cnn_gen_hdf5.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/w2v_cnn_gen_hdf5.py b/python/w2v_cnn_gen_hdf5.py
index bf81056..20594e1 100644
--- a/python/w2v_cnn_gen_hdf5.py
+++ b/python/w2v_cnn_gen_hdf5.py
@@ -10,21 +10,26 @@ from tensorflow.keras.layers import Conv1D,MaxPooling1D,GlobalMaxPooling1D
 from tensorflow import keras
 
 modelNN = Sequential()
+#input_shape=((72, 100)))
 
-modelNN.add(Conv1D(32, 7, activation='relu',input_shape=((72, 100))))
-modelNN.add(Conv1D(32, 7, activation='relu'))
-modelNN.add(GlobalMaxPooling1D())
+modelNN.add(Conv1D(50,kernel_size=5, activation='relu',input_shape=((72, 100))))
+#modelNN.add(MaxPooling1D(pool_size=4))
+#modelNN.add(Conv1D(250,kernel_size=4, activation='relu'))
+modelNN.add(MaxPooling1D(pool_size=4))
+modelNN.add(Conv1D(25,kernel_size=3, activation='relu'))
+modelNN.add(MaxPooling1D(pool_size=4))
 modelNN.add(Flatten())
-modelNN.add(Dense(512,activation='relu'))
-modelNN.add(Dense(128,activation='relu'))
+modelNN.add(Dense(25,activation='relu'))
+#modelNN.add(Dense(50,activation='relu'))
 modelNN.add(Dense(10,activation='relu'))
 modelNN.add(Dense(3,activation='softmax'))
 modelNN.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=["sparse_categorical_accuracy"])
-
+modelNN.summary()
 #%%
 from hdf5 import hdf5Generator
 path = "G:\\ml\\"
 num_rows = 4.8E6 
+#num_rows = 1E5 
 batchSize = 2048
 steps = num_rows/batchSize
 #early stop
@@ -37,7 +42,7 @@ valData = hdf5Generator(path + "w2vCNN.hdf5", batchSize, "Val")
 #%%
 cW = {0:4.18,1:9.53,2:1.52}
 hist = modelNN.fit(trainData, validation_data=valData, epochs=100,class_weight=cW, steps_per_epoch=steps, validation_steps=steps,callbacks=cbList)
-modelNN.save("D:\\ml\\CNN-Classfication")
+modelNN.save("D:\\ml\\CNN-Classfication-5")
 #modelNN.fit(train,epochs=12,validation_data=val,batch_size=batchSize,steps_per_epoch= num_rows/batchSize,callbacks=cbList,validation_steps=num_rows/batchSize)
 # %%eval
 testData = hdf5Generator(path + "w2vCNN.hdf5", batchSize, "Test",loop=False)
-- 
GitLab