propagation_gan/libs/fitting.py

92 lines
2.8 KiB
Python

#!/usr/bin/python
import numpy as np
from scipy.optimize import curve_fit
from libs.models import log_gamma_loc
def modelfit_log_gamma_func(input_data, *target_data):
    """Model function handed to ``scipy.optimize.curve_fit``.

    Evaluates ``log_gamma_loc`` at the receiver locations for one candidate
    parameter vector.

    Args:
        input_data: receiver locations, forwarded unchanged to
            ``log_gamma_loc`` as ``rx_locs``.
        *target_data: fit parameters in the order
            ``(loc_x, loc_y, pwr, gamma[, loc_z])``. Values after the fifth
            are ignored.

    Returns:
        The result of ``log_gamma_loc(rx_locs, tx_loc, tx_power, env_gamma)``.
        ``tx_loc`` is ``[loc_x, loc_y]``, or ``[loc_x, loc_y, loc_z]`` when a
        non-None fifth parameter is supplied.

    Raises:
        TypeError: if fewer than four parameters are given.
    """
    if len(target_data) < 4:
        # Fail loudly: returning None here only surfaces later inside
        # curve_fit as an unrelated arithmetic error on NoneType.
        raise TypeError(
            "missing arguments, should have loc_x, loc_y, pwr, gamma at least"
        )
    rx_locs = input_data
    tx_loc_x, tx_loc_y, tx_power, env_gamma = target_data[:4]
    # Optional fifth parameter switches the transmitter location to 3-D.
    tx_loc_z = target_data[4] if len(target_data) > 4 else None
    if tx_loc_z is None:
        tx_loc = np.array([tx_loc_x, tx_loc_y])
    else:
        tx_loc = np.array([tx_loc_x, tx_loc_y, tx_loc_z])
    return log_gamma_loc(rx_locs, tx_loc, tx_power, env_gamma)
def modelfit_log_gamma(
    rx_locs,
    rx_rsses,
    bounds_pwr=(-60, 0),
    bounds_gamma=(2, 6),
    bounds_loc_x=(0, 6.2),
    bounds_loc_y=(0, 6.2),
    bounds_loc_z=None,
    monte_carlo_sampling=False,
    monte_carlo_sampling_rate=0.8
):
    """Fit the ``log_gamma_loc`` model to measured RSS values.

    Estimates transmitter location, transmit power and the exponent ``gamma``
    with bounded non-linear least squares (``scipy.optimize.curve_fit`` on
    ``modelfit_log_gamma_func``). The initial guess is drawn uniformly inside
    the bounds using NumPy's global random state.

    Args:
        rx_locs: receiver locations, one row per measurement (indexed along
            the first axis).
        rx_rsses: measured RSS values, one entry per row of ``rx_locs``.
        bounds_pwr: (low, high) bounds for the transmit power.
        bounds_gamma: (low, high) bounds for ``gamma``.
        bounds_loc_x: (low, high) bounds for the transmitter x coordinate.
        bounds_loc_y: (low, high) bounds for the transmitter y coordinate.
        bounds_loc_z: optional (low, high) bounds for a z coordinate. When
            given, a 3-D transmitter location is fitted.
        monte_carlo_sampling: if True and there are more than 10
            measurements, fit on a random subset drawn without replacement.
        monte_carlo_sampling_rate: fraction of measurements kept when
            subsampling.

    Returns:
        Tuple ``(pmse, est_tx_loc, est_tx_pwr, est_env_gamma, est_rsses)``.
        ``pmse`` is the NaN-ignoring sum of squared residuals divided by the
        number of (possibly subsampled) measurements. ``est_rsses`` are the
        model predictions at the measurements that were used for the fit.
    """
    # Random initial guess inside the bounds. The draw order (pwr, gamma,
    # x, y, then z) is kept so runs with a fixed global seed are unchanged.
    seed_pwr = np.random.uniform(bounds_pwr[0], bounds_pwr[1])
    seed_gamma = np.random.uniform(bounds_gamma[0], bounds_gamma[1])
    seed_loc_x = np.random.uniform(bounds_loc_x[0], bounds_loc_x[1])
    seed_loc_y = np.random.uniform(bounds_loc_y[0], bounds_loc_y[1])

    # Parameter order must match how modelfit_log_gamma_func unpacks its
    # *target_data: (x, y, power, gamma[, z]).
    seeds = [seed_loc_x, seed_loc_y, seed_pwr, seed_gamma]
    param_bounds = [bounds_loc_x, bounds_loc_y, bounds_pwr, bounds_gamma]
    if bounds_loc_z is not None:
        seeds.append(np.random.uniform(bounds_loc_z[0], bounds_loc_z[1]))
        param_bounds.append(bounds_loc_z)
    # curve_fit expects bounds as (lower_bounds, upper_bounds).
    bounds = list(zip(*param_bounds))

    n_meas = rx_rsses.shape[0]
    if monte_carlo_sampling and n_meas > 10:
        keep = np.random.choice(
            np.arange(n_meas),
            size=int(monte_carlo_sampling_rate * n_meas),
            replace=False
        )
        # Select along the first axis only. ``x[keep, :]`` raised IndexError
        # for a 1-D RSS vector; ``x[keep]`` is identical for 2-D arrays.
        rx_locs = rx_locs[keep]
        rx_rsses = rx_rsses[keep]

    popt, _ = curve_fit(
        modelfit_log_gamma_func, rx_locs, rx_rsses,
        p0=seeds, bounds=bounds
    )

    if bounds_loc_z is None:
        est_loc_x, est_loc_y, est_tx_pwr, est_env_gamma = popt
        est_tx_loc = np.array([est_loc_x, est_loc_y])
    else:
        est_loc_x, est_loc_y, est_tx_pwr, est_env_gamma, est_loc_z = popt
        est_tx_loc = np.array([est_loc_x, est_loc_y, est_loc_z])

    est_rsses = log_gamma_loc(rx_locs, est_tx_loc, est_tx_pwr, est_env_gamma)
    # NOTE(review): if log_gamma_loc's output shape differs from rx_rsses
    # (e.g. (N,) vs (N, 1)) this subtraction broadcasts -- confirm shapes.
    est_errors = est_rsses - rx_rsses
    # NaN residuals are skipped in the sum but still counted in the divisor.
    pmse = 1.0 * np.nansum(est_errors * est_errors) / est_errors.shape[0]
    return pmse, est_tx_loc, est_tx_pwr, est_env_gamma, est_rsses