模型调参(AutoML)— optuna

1147-柳同学

发表文章数:589

热门标签

,
首页 » 算法 » 正文

引言

Optuna 是一个特别为机器学习设计的自动超参数优化软件框架。它具有命令式的,define-by-run 风格的 API。由于这种 API 的存在,用 Optuna 编写的代码模块化程度很高,Optuna 的用户因此也可以动态地构造超参数的搜索空间。
更多功能可参见optuna

基本介绍

简单介绍一下optuna里最重要的几个term。

在optuna里最重要的三个term:
(1)Trial:目标函数的单次执行过程
(2)Study:基于目标函数的优化过程, 一个优化超参的session,由一系列的trials组成;
(3)Parameter:需要优化的超参;

study

在optuna里,study对象用来管理对超参的优化,optuna.create_study()返回一个study对象。
study又有很多有用的 property:
(1)study.best_params:搜出来的最优超参;
(2)study.best_value:最优超参下,objective函数返回的值 (如最高的Acc,最低的Error rate等);
(3)study.best_trial:最优超参对应的trial,有一些时间、超参、trial编号等信息;
(4)study.optimize(objective, n_trials):对objective函数里定义的超参进行搜索;

搜索方式

optuna支持很多种搜索方式:
(1)trial.suggest_categorical(‘optimizer’, [‘MomentumSGD’, ‘Adam’]):表示从SGD和adam里选一个使用;
(2)trial.suggest_int(‘num_layers’, 1, 3):从1~3范围内的int里选;
(3)trial.suggest_uniform(‘dropout_rate’, 0.0, 1.0):从0~1内的uniform分布里选;
(4)trial.suggest_loguniform(‘learning_rate’, 1e-5, 1e-2):从1e-5~1e-2的log uniform分布里选;
(5)trial.suggest_discrete_uniform(‘drop_path_rate’, 0.0, 1.0, 0.1):从0~1且step为0.1的离散uniform分布里选;

一个 study 的目的是通过多次 trial (例如 n_trials=100 ) 来找出最佳的超参数值集,而 Optuna 旨在加速和自动化此类 study 优化过程。

github上示例

XGBoostPruningCallback

"""
Optuna example that demonstrates a pruner for XGBoost.
In this example, we optimize the validation accuracy of cancer detection using XGBoost.——
We optimize both the choice of booster model and their hyperparameters. Throughout
training of models, a pruner observes intermediate results and stop unpromising trials.
You can run this example as follows:
    $ python xgboost_integration.py
"""

import numpy as np
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split
import xgboost as xgb

import optuna


# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
def objective(trial):
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.25)
    dtrain = xgb.DMatrix(train_x, label=train_y)
    dvalid = xgb.DMatrix(valid_x, label=valid_y)

    param = {
        "verbosity": 0,
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "booster": trial.suggest_categorical("booster", ["gbtree", "gblinear", "dart"]),
        "lambda": trial.suggest_float("lambda", 1e-8, 1.0, log=True),
        "alpha": trial.suggest_float("alpha", 1e-8, 1.0, log=True),
    }

    if param["booster"] == "gbtree" or param["booster"] == "dart":
        param["max_depth"] = trial.suggest_int("max_depth", 1, 9)
        param["eta"] = trial.suggest_float("eta", 1e-8, 1.0, log=True)
        param["gamma"] = trial.suggest_float("gamma", 1e-8, 1.0, log=True)
        param["grow_policy"] = trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"])
    if param["booster"] == "dart":
        param["sample_type"] = trial.suggest_categorical("sample_type", ["uniform", "weighted"])
        param["normalize_type"] = trial.suggest_categorical("normalize_type", ["tree", "forest"])
        param["rate_drop"] = trial.suggest_float("rate_drop", 1e-8, 1.0, log=True)
        param["skip_drop"] = trial.suggest_float("skip_drop", 1e-8, 1.0, log=True)

    # Add a callback for pruning.
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, "validation-auc")
    bst = xgb.train(param, dtrain, evals=[(dvalid, "validation")], callbacks=[pruning_callback])
    preds = bst.predict(dvalid)
    pred_labels = np.rint(preds)
    accuracy = sklearn.metrics.accuracy_score(valid_y, pred_labels)
    return accuracy


if __name__ == "__main__":
    study = optuna.create_study(
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=5), direction="maximize"
    )
    study.optimize(objective, n_trials=100)
    print(study.best_trial)

未经允许不得转载:作者:1147-柳同学, 转载或复制请以 超链接形式 并注明出处 拜师资源博客
原文地址:《模型调参(AutoML)— optuna》 发布于2021-01-29

分享到:
赞(0) 打赏

评论 抢沙发

评论前必须登录!

  注册



长按图片转发给朋友

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

微信扫一扫打赏

Vieu3.3主题
专业打造轻量级个人企业风格博客主题!专注于前端开发,全站响应式布局自适应模板。

登录

忘记密码 ?

您也可以使用第三方帐号快捷登录

Q Q 登 录
微 博 登 录