# DeepLearning
# DeepLearning - keras.initializers 예제 (example)
# Hoon
# 2022. 11. 1. 14:32
from tensorflow import keras
from tensorflow.keras import initializers
import matplotlib.pyplot as plt
# The nine weight initializers compared by tit_practice(). They are referred
# to by their numbered global names, so these names must stay as they are.
(
    initializer1, initializer2, initializer3,
    initializer4, initializer5,
    initializer6, initializer7,
    initializer8, initializer9,
) = (
    # Plain random distributions.
    initializers.RandomNormal(mean=0, stddev=1.0),
    initializers.RandomUniform(minval=0, maxval=1.0),
    initializers.TruncatedNormal(mean=0, stddev=1.0),
    # Constant initializers.
    initializers.Zeros(),
    initializers.Ones(),
    # Xavier (Glorot) variance scaling.
    initializers.GlorotNormal(),
    initializers.GlorotUniform(),
    # He variance scaling.
    initializers.HeNormal(),
    initializers.HeUniform(),
)
def _fresh_copy(obj):
    """Return a new object rebuilt from ``obj``'s config, or ``obj`` itself if it has none.

    Keras initializers and optimizers expose ``get_config``/``from_config``;
    strings such as ``'adam'`` are passed through unchanged (Keras resolves
    them to a new object on every ``compile``).
    """
    if hasattr(obj, 'get_config'):
        return type(obj).from_config(obj.get_config())
    return obj


def _build_model(activation, initializer, optimizer):
    """Build and compile a 2-input -> Dense(3) -> Dense(1, linear) model trained with MSE."""
    # Many models are built in a loop; drop Keras' accumulated global state.
    keras.backend.clear_session()
    inputs = keras.Input(shape=(2,))
    # Optional experiment: insert keras.layers.BatchNormalization() and/or
    # keras.layers.Dropout(0.5) around the hidden layer to compare.
    # Each layer gets its own initializer copy: reusing one unseeded instance
    # can yield identical weights on every call in newer Keras versions.
    hidden = keras.layers.Dense(
        3, activation=activation, kernel_initializer=_fresh_copy(initializer)
    )(inputs)
    outputs = keras.layers.Dense(
        1, activation='linear', kernel_initializer=_fresh_copy(initializer)
    )(hidden)
    model = keras.Model(inputs=inputs, outputs=outputs)
    # A fresh optimizer per model: a single optimizer instance cannot be
    # shared across models whose variables differ (new-style optimizers).
    model.compile(loss='mse', optimizer=_fresh_copy(optimizer), metrics=['accuracy'])
    return model


def tit_practice(activation, epochs, optimizer):
    """Train a small dense network under each of nine initializers and plot every run.

    For each module-level initializer ``initializer1`` .. ``initializer9``,
    print it, open one 3x3 figure, and train nine independently built models.
    Each run's training loss (blue) and accuracy (red) go in their own subplot,
    together with a dashed reference line at 0.15.

    Args:
        activation: Activation of the hidden ``Dense(3)`` layer (anything
            ``keras.layers.Dense`` accepts, e.g. ``'relu'``).
        epochs: Number of training epochs per run.
        optimizer: Optimizer given to ``Model.compile``, either a string such
            as ``'adam'`` or an optimizer instance (copied for every run).

    Reads the module-level ``X_train`` and ``y_train``, which must be defined
    before calling. ``X_train`` needs 2 features per sample because the input
    layer is ``shape=(2,)``.
    """
    # Display names, in the same order as initializer1 .. initializer9.
    initializer_names = [
        'RandomNormal', 'RandomUniform', 'TruncatedNormal', 'Zeros', 'Ones',
        'GlorotNormal', 'GlorotUniform', 'HeNormal', 'HeUniform',
    ]
    initializer_list = [
        initializer1, initializer2, initializer3, initializer4, initializer5,
        initializer6, initializer7, initializer8, initializer9,
    ]

    for name, initializer in zip(initializer_names, initializer_list):
        print('initializer:', initializer)
        fig = plt.figure(figsize=(30, 15))
        # Nine independent training runs per initializer, one per subplot.
        for ax in fig.subplots(3, 3).ravel():
            model = _build_model(activation, initializer, optimizer)
            hist = model.fit(X_train, y_train, epochs=epochs, verbose=0)

            ax.plot(hist.history['loss'], label='train loss', c='blue')
            ax.plot(hist.history['accuracy'], label='train accuracy', c='red')
            ax.plot(range(epochs), [0.15] * epochs, linestyle='--')
            ax.set_title(f'{name} & {activation} & {optimizer}')
            ax.set_xlabel('epoch')
            ax.set_ylabel('loss')
            ax.legend(loc='upper left')
            ax.set_ylim(0, 1)
        plt.show()