#%% imports
import glob
import optuna
import pickle
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Add, Activation
import tensorflow_addons as tfa
import plotly.express as px    # not used directly, but if not available plots will throw an error
from IPython.display import display


#%% data: discontinuous staircase
# Regression target: the integer staircase y = floor(x) on [0, 10).
# 10k uniform samples each for train and test; the test inputs are sorted
# so model predictions can later be drawn as a curve over x.
x_train = 10 * np.random.rand(10000)
x_test = 10 * np.sort(np.random.rand(10000))
y_train = np.floor(x_train)
y_test = np.floor(x_test)


#%% ANN model: residual block with dense layers inside
def make_ann(neurons, activations, initializers):
    """Build and compile a residual MLP for 1-D regression.

    A stack of Dense layers is applied to the scalar input, projected back
    to one unit, and added to the raw input (skip connection), so the
    network only has to learn the residual y - x.

    Parameters
    ----------
    neurons : sequence of int
        Width of each hidden Dense layer.
    activations : sequence of str
        Activation per layer; 'snake' is mapped to tfa.activations.snake,
        anything else is passed to Keras by name.
    initializers : sequence of str
        Kernel initializer name per layer.

    Returns
    -------
    tf.keras.Model, compiled with Adam and MSE loss.
    """
    ann_in = Input(1)
    hidden = ann_in
    for width, act_name, kernel_init in zip(neurons, activations, initializers):
        # 'snake' is not a built-in Keras activation, so resolve it explicitly
        act_fn = tfa.activations.snake if act_name == 'snake' else act_name
        hidden = Dense(width, activation=act_fn, kernel_initializer=kernel_init)(hidden)

    residual = Dense(1, activation='linear')(hidden)
    ann_out = Add()([ann_in, residual])

    model = tf.keras.models.Model(ann_in, ann_out)
    model.compile(optimizer='adam', loss='mse')
    return model

def assess_ann(ann, x_tr=None, y_tr=None, x_te=None, y_te=None):
    """Train `ann` and return its mean squared error on the test set.

    Parameters
    ----------
    ann : tf.keras.Model
        A compiled model (as produced by `make_ann`).
    x_tr, y_tr, x_te, y_te : array-like, optional
        Train/test data; each defaults to the module-level staircase arrays,
        so existing callers (`assess_ann(ann)`) are unaffected.

    Returns
    -------
    float
        Test-set MSE (the model's only compiled metric, so `evaluate`
        returns a scalar).
    """
    x_tr = x_train if x_tr is None else x_tr
    y_tr = y_train if y_tr is None else y_tr
    x_te = x_test if x_te is None else x_te
    y_te = y_test if y_te is None else y_te

    # stop once training loss plateaus and roll back to the best weights seen
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=40,
                                                  restore_best_weights=True)
    # `fit` documents `callbacks` as a *list* of callbacks; the original
    # passed the bare callback and relied on undocumented leniency
    ann.fit(x_tr, y_tr, batch_size=100, epochs=400,
            callbacks=[early_stop], verbose=False)
    return ann.evaluate(x_te, y_te, verbose=False)


#%% load studies from files
# Reload any previously saved studies; filenames follow "study-<layers>.pkl".
studies = {}
for path in glob.glob("study-*.pkl"):
    # slice out the layer count between the "study-" prefix and ".pkl" suffix
    n_layers = int(path[len("study-"):-len(".pkl")])
    # NOTE: unpickling is only safe because these files were written by this
    # very script — never load pickles from untrusted sources
    with open(path, 'rb') as fh:
        studies[n_layers] = pickle.load(fh)


#%% Optuna
def objective(trial, layers=2, repeats=5):
    """Optuna objective: average test MSE of freshly trained ANNs.

    Parameters
    ----------
    trial : optuna.Trial
        Supplies per-layer width, activation and kernel initializer.
    layers : int, optional
        Number of hidden Dense layers to tune (default 2).
    repeats : int, optional
        Independent train/eval runs to average (default 5, matching the
        original hard-coded value); training is stochastic, so a single
        run would make the objective too noisy.

    Returns
    -------
    float
        Mean of the test-set MSEs across `repeats` runs (to minimize).
    """
    # sample the per-layer hyperparameters
    neurons, activations, initializers = [], [], []
    for l in range(layers):
        neurons.append(trial.suggest_int(f"neurons {l}", 10, 70))
        activations.append(trial.suggest_categorical(f"activation {l}", ['tanh', 'relu', 'snake']))
        initializers.append(trial.suggest_categorical(f"init {l}", ['he_uniform', 'he_normal', 'glorot_uniform', 'glorot_normal']))

    # average several runs so a lucky/unlucky initialization doesn't dominate
    scores = [assess_ann(make_ann(neurons, activations, initializers))
              for _ in range(repeats)]
    return np.mean(scores)

# optimize ANNs with 2, 3 and 4 layers
if not studies:  # nothing reloaded from disk: run a fresh search per depth
    for layers in range(2, 5):
        new_study = optuna.create_study(direction='minimize')
        studies[layers] = new_study
        # bind `layers` as a default argument so the intent is explicit even
        # though the lambda is consumed within this iteration anyway
        new_study.optimize(lambda trial, layers=layers: objective(trial, layers),
                           n_trials=50)

        # persist the finished study so later runs can skip optimization
        with open(f'study-{layers}.pkl', 'wb') as f:
            pickle.dump(new_study, f)


#%% plots
def list_best_trials(study):
    """Display the 15 best completed trials of `study` as a table.

    Trials are sorted ascending by objective value; bookkeeping columns
    and rows with missing values (e.g. unfinished trials) are dropped
    before display (requires an IPython frontend).
    """
    bookkeeping = ['number', 'datetime_start', 'datetime_complete', 'duration', 'state']
    frame = study.trials_dataframe().sort_values('value').dropna()
    frame = frame.drop(columns=bookkeeping)
    display(frame.head(15))

def plot_trials(study):
    """Show Optuna's standard diagnostic plots for `study`.

    Renders a parallel-coordinate plot, the parameter importances, and a
    contour over the two activation choices (requires plotly plus a
    Jupyter/VSCode frontend).
    """
    viz = optuna.visualization
    viz.plot_parallel_coordinate(study).show()
    viz.plot_param_importances(study).show()
    viz.plot_contour(study, params=["activation 0", "activation 1"]).show()

# Report every study: best-trial table plus the diagnostic plots.
for n_layers in range(2, 5):
    current_study = studies[n_layers]
    list_best_trials(current_study)
    plot_trials(current_study)

# extra contour over the neuron counts, for the 2-layer search only
optuna.visualization.plot_contour(studies[2], params=["neurons 0", "neurons 1"]).show()