Overview
I have recently been updating the new-customer model, which is still trained with XGBoost. Since the training set has grown to roughly 200,000 samples, I wanted to try a neural network and see how it performs. TensorFlow 2.0 integrates Keras and is very easy to use; Keras will no longer be updated as a standalone package going forward, but will live on as a module inside TensorFlow. In this post we use tf.keras from TensorFlow 2.0 to train a model on our structured data.
1. Import the feature list and data
import numpy as np
import pandas as pd
import sklearn
from sklearn.preprocessing import StandardScaler
import tensorflow as tf
from tensorflow.keras import models, layers, losses, metrics
# Feature list: feature names are stored one per line in the first column of feature_list.txt
name_list = pd.read_csv('feature_list.txt', header=None, index_col=0)
my_feature_names = list(name_list.transpose())
# Load the data
df_total = pd.read_csv('data_total.csv')
df_total.head()
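Before going further it is worth a quick look at the size of the data and the label distribution, since class imbalance will matter later. A minimal check, reusing the df_total and label column from above:

# Rough sanity check: overall size and positive rate of the label
print(df_total.shape)
print(df_total['label'].value_counts(normalize=True))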
2. Data preprocessing and dataset split
# Fill missing values with 0
df_total = df_total.fillna(0)
# Split into train / validation / test by application time
df_train = df_total[df_total.apply_time < '2020-01-21 00:00:00']
df_val = df_total[(df_total.apply_time >= '2020-01-21 00:00:00') & (df_total.apply_time < '2020-02-01 00:00:00')]
df_test = df_total[df_total.apply_time >= '2020-02-01 00:00:00']
# Select the feature columns and the label
train_x = df_train[my_feature_names]
train_y = df_train['label']
val_x = df_val[my_feature_names]
val_y = df_val['label']
test_x = df_test[my_feature_names]
test_y = df_test['label']
# Standardize the features: fit on the training set only, then apply to validation/test
scaler = StandardScaler()
train_x = scaler.fit_transform(train_x)
val_x = scaler.transform(val_x)
test_x = scaler.transform(test_x)
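If this model is ever deployed, the fitted scaler must be applied to new data in exactly the same way. A minimal sketch of persisting it with joblib (an extra dependency not used elsewhere in this post; the file name and new_df are just examples):

from joblib import dump, load

# Save the fitted scaler so the same standardization can be reproduced at scoring time
dump(scaler, 'scaler_new_user_model.joblib')

# ... later, in the scoring service (new_df is a hypothetical DataFrame of incoming applications)
scaler = load('scaler_new_user_model.joblib')
new_x = scaler.transform(new_df[my_feature_names])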
3. Model construction
tf.keras.backend.clear_session()
METRICS = [
tf.keras.metrics.AUC(name='auc'),
]
def make_model(metrics=METRICS, output_bias=None):
if output_bias is not None:
output_bias = tf.keras.initializers.Constant(output_bias)
model = tf.keras.Sequential([
layers.Dense(
64, activation='relu',
input_shape=(train_x.shape[-1],)),
layers.Dropout(0.2),
layers.Dense(
128, activation='relu'),
layers.Dropout(0.2),
layers.Dense(
32, activation='relu'),
layers.Dense(1, activation='sigmoid',
bias_initializer=output_bias),
])
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
loss=losses.BinaryCrossentropy(),
metrics=metrics)
return model
# Training settings and early stopping on validation AUC
EPOCHS = 100
BATCH_SIZE = 2000
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_auc',
verbose=1,
patience=20,
mode='max',
restore_best_weights=True)
# Handle class imbalance: weight each class inversely to its frequency,
# scaled so the overall loss magnitude stays roughly the same
neg = len(train_y) - sum(train_y)
pos = sum(train_y)
total = len(train_y)
weight_for_0 = (1 / neg) * total / 2.0
weight_for_1 = (1 / pos) * total / 2.0
class_weight = {0: weight_for_0, 1: weight_for_1}
# Build the model
model = make_model()
model.summary()
Output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 64) 16384
_________________________________________________________________
dropout (Dropout) (None, 64) 0
_________________________________________________________________
dense_1 (Dense) (None, 128) 8320
_________________________________________________________________
dropout_1 (Dropout) (None, 128) 0
_________________________________________________________________
dense_2 (Dense) (None, 32) 4128
_________________________________________________________________
dense_3 (Dense) (None, 1) 33
=================================================================
Total params: 28,865
Trainable params: 28,865
Non-trainable params: 0
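Note that make_model accepts an output_bias argument that is never used above. In the "Classification on imbalanced data" tutorial this is used to initialize the bias of the output layer to log(pos/neg), so the initial predictions already match the base rate and the first epochs converge faster. A sketch of how it could be plugged in here (not what was actually run in this post):

# Initialize the output-layer bias from the observed positive/negative ratio
initial_bias = np.log(pos / neg)
model = make_model(output_bias=initial_bias)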
4. Model training
weighted_history = model.fit(
train_x,
train_y,
batch_size=BATCH_SIZE,
epochs=EPOCHS,
callbacks = [early_stopping],
validation_data=(val_x, val_y),
# Apply the class weights computed above
class_weight=class_weight)
Output:
Train on 206917 samples, validate on 15830 samples
Epoch 1/100
206917/206917 [==============================] - 3s 12us/sample - loss: 0.6584 - auc: 0.6498 - val_loss: 0.6108 - val_auc: 0.6729
Epoch 2/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6305 - auc: 0.6974 - val_loss: 0.6042 - val_auc: 0.6840
Epoch 3/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6238 - auc: 0.7075 - val_loss: 0.6018 - val_auc: 0.6895
Epoch 4/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6190 - auc: 0.7142 - val_loss: 0.5987 - val_auc: 0.6940
Epoch 5/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6157 - auc: 0.7190 - val_loss: 0.5978 - val_auc: 0.6961
Epoch 6/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6126 - auc: 0.7230 - val_loss: 0.5957 - val_auc: 0.6989
Epoch 7/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6104 - auc: 0.7257 - val_loss: 0.5951 - val_auc: 0.7007
Epoch 8/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6082 - auc: 0.7284 - val_loss: 0.5947 - val_auc: 0.7019
Epoch 9/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6067 - auc: 0.7301 - val_loss: 0.5937 - val_auc: 0.7034
Epoch 10/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6043 - auc: 0.7335 - val_loss: 0.5937 - val_auc: 0.7038
Epoch 11/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6035 - auc: 0.7344 - val_loss: 0.5934 - val_auc: 0.7036
Epoch 12/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6016 - auc: 0.7365 - val_loss: 0.5924 - val_auc: 0.7046
Epoch 13/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.6013 - auc: 0.7367 - val_loss: 0.5930 - val_auc: 0.7041
Epoch 14/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5996 - auc: 0.7390 - val_loss: 0.5925 - val_auc: 0.7042
Epoch 15/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5984 - auc: 0.7403 - val_loss: 0.5930 - val_auc: 0.7045
Epoch 16/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5976 - auc: 0.7412 - val_loss: 0.5937 - val_auc: 0.7034
Epoch 17/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5961 - auc: 0.7430 - val_loss: 0.5942 - val_auc: 0.7034
Epoch 18/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5948 - auc: 0.7444 - val_loss: 0.5946 - val_auc: 0.7027
Epoch 19/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5938 - auc: 0.7455 - val_loss: 0.5949 - val_auc: 0.7023
Epoch 20/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5924 - auc: 0.7472 - val_loss: 0.5944 - val_auc: 0.7024
Epoch 21/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5925 - auc: 0.7471 - val_loss: 0.5953 - val_auc: 0.7028
Epoch 22/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5915 - auc: 0.7482 - val_loss: 0.5944 - val_auc: 0.7022
Epoch 23/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5906 - auc: 0.7488 - val_loss: 0.5964 - val_auc: 0.7008
Epoch 24/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5900 - auc: 0.7496 - val_loss: 0.5947 - val_auc: 0.7025
Epoch 25/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5894 - auc: 0.7503 - val_loss: 0.5956 - val_auc: 0.7031
Epoch 26/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5882 - auc: 0.7517 - val_loss: 0.5944 - val_auc: 0.7028
Epoch 27/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5870 - auc: 0.7532 - val_loss: 0.5975 - val_auc: 0.7001
Epoch 28/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5869 - auc: 0.7530 - val_loss: 0.5965 - val_auc: 0.7022
Epoch 29/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5861 - auc: 0.7537 - val_loss: 0.5970 - val_auc: 0.7011
Epoch 30/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5854 - auc: 0.7543 - val_loss: 0.5960 - val_auc: 0.7015
Epoch 31/100
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5844 - auc: 0.7559 - val_loss: 0.5994 - val_auc: 0.6989
Epoch 32/100
206000/206917 [============================>.] - ETA: 0s - loss: 0.5835 - auc: 0.7568Restoring model weights from the end of the best epoch.
206917/206917 [==============================] - 1s 4us/sample - loss: 0.5836 - auc: 0.7568 - val_loss: 0.5982 - val_auc: 0.6992
Epoch 00032: early stopping
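The post stops at the validation result, but since test_x / test_y were already prepared, scoring the out-of-time test period only takes a couple of lines (the numbers are not shown here because they are not part of the original run):

# Evaluate loss and AUC on the out-of-time test set
results = model.evaluate(test_x, test_y, batch_size=BATCH_SIZE, verbose=0)
for name, value in zip(model.metrics_names, results):
    print(name, ':', value)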
The best validation AUC is 0.7046, which still falls somewhat short of the XGBoost model; with some hyperparameter tuning the gap should narrow.
This post mainly follows two official TensorFlow tutorials: Classify structured data and Classification on imbalanced data.