specs = {"args": copy.copy(inspect.currentframe().f_locals),
"function": inspect.currentframe().f_code.co_name}
if callback is not None:
if isinstance(callback, Callable):
callback = [callback]
elif not (isinstance(callback, list) and
all([isinstance(c, Callable) for c in callback])):
raise ValueError("callback should be either a callable or "
"a list of callables.")
# Check params
rng = check_random_state(random_state)
space = Space(dimensions)
# Default GP
if base_estimator is None:
base_estimator = GaussianProcessRegressor(
kernel=(ConstantKernel(1.0, (0.01, 1000.0)) *
Matern(length_scale=np.ones(space.transformed_n_dims),
length_scale_bounds=[(0.01, 100)] * space.transformed_n_dims,
nu=2.5)),
normalize_y=True, alpha=alpha, random_state=random_state)
# Initialize with provided points (x0 and y0) and/or random points
if x0 is None:
x0 = []
elif not isinstance(x0[0], list):
x0 = [x0]
if not isinstance(x0, list):
raise ValueError("`x0` should be a list, but got %s" % type(x0))
n_init_func_calls = len(x0) if y0 is None else 0
n_total_init_calls = n_random_starts + n_init_func_calls
if n_total_init_calls <= 0:
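# No initial function evaluations are scheduled: there are no points from
# x0 left to evaluate (x0 is empty, or y0 was supplied with it) and
# n_random_starts is 0, so ask the caller for n_random_starts > 0.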
raise ValueError(
"Expected `n_random_starts` > 0, got %d" % n_random_starts)
if n_calls < n_total_init_calls:
raise ValueError(
"Expected `n_calls` >= %d, got %d" % (n_total_init_calls, n_calls))
func_call_no = 0
if y0 is None and x0:
y0 = []
for i, x in enumerate(x0):
y0.append(verbose_func(
func, x, verbose=verbose, prev_ys=y0, x_info="provided",
func_call_no=func_call_no))
func_call_no += 1
if callback is not None:
curr_res = create_result(x0, y0, space, rng, specs)
for c in callback:
c(curr_res)
elif x0:
if isinstance(y0, Iterable):
y0 = list(y0)
elif isinstance(y0, numbers.Number):
y0 = [y0]
else:
raise ValueError(
"`y0` should be an iterable or a scalar, got %s" % type(y0))
if len(x0) != len(y0):
raise ValueError("`x0` and `y0` should have the same length")
if not all(map(np.isscalar, y0)):
raise ValueError(
"`y0` elements should be scalars")
else:
y0 = []
# Random function evaluations.
X_rand = space.rvs(n_samples=n_random_starts, random_state=rng)
Xi = x0 + X_rand
yi = y0
for i, x in enumerate(X_rand):
yi.append(verbose_func(
func, x, verbose=verbose, prev_ys=yi, x_info="random",
func_call_no=func_call_no))
func_call_no += 1
if callback is not None:
for c in callback:
curr_res = create_result(
x0 + X_rand[:i + 1], yi, space, rng, specs)
c(curr_res)
if np.ndim(y0) != 1:
raise ValueError("`func` should return a scalar")
if search == "auto":
if space.is_real:
search = "lbfgs"
else:
search = "sampling"
elif search not in ["lbfgs", "sampling"]:
raise ValueError(
"Expected search to be "lbfgs", "sampling" or "auto", "
"got %s" % search)
# Bayesian optimization loop
models = []
n_model_iter = n_calls - n_total_init_calls
for i in range(n_model_iter):
if verbose:
print("Fitting GaussianProcessRegressor no: %d" % (i + 1))
gp = clone(base_estimator)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
gp.fit(space.transform(Xi), yi)
models.append(gp)
if search == "sampling":
X = space.transform(space.rvs(n_samples=n_points,
random_state=rng))
values = _gaussian_acquisition(
X=X, model=gp, y_opt=np.min(yi), method=acq,
xi=xi, kappa=kappa)
next_x = X[np.argmin(values)]
elif search == "lbfgs":
best = np.inf
for j in range(n_restarts_optimizer):
x0 = space.transform(space.rvs(n_samples=1,
random_state=rng))[0]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x, a, _ = fmin_l_bfgs_b(
_acquisition, x0,
args=(gp, np.min(yi), acq, xi, kappa),
bounds=space.transformed_bounds,
approx_grad=True, maxiter=20)
if a < best:
next_x, best = x, a
next_x = space.inverse_transform(next_x.reshape((1, -1)))[0]
yi.append(verbose_func(
func, next_x, verbose=verbose, prev_ys=yi,
func_call_no=func_call_no))
func_call_no += 1
Xi.append(next_x)
if callback is not None:
curr_res = create_result(Xi, yi, space, rng, specs)for c in callback:
c(curr_res)
# Pack results
return create_result(Xi, yi, space, rng, specs, models)