diff --git a/syngen_floorplan_multipath.py b/syngen_floorplan_multipath.py index 9a01cd1..4581ff2 100644 --- a/syngen_floorplan_multipath.py +++ b/syngen_floorplan_multipath.py @@ -181,6 +181,15 @@ def getFloormap(args): def main(args): tx_locs = None + if args.seed is None: + if args.train: + seed = RANDOM_SEED_TRAIN + elif args.train_real: + seed = RANDOM_SEED_TRAIN + elif args.test: + seed = RANDOM_SEED_TEST + else: + seed = args.seed if args.single: if args.txloc is None: @@ -189,13 +198,13 @@ def main(args): locations = [float(x) for x in args.txloc.split(' ')] tx_locs = np.array(locations).reshape(-1, 2) elif args.train: - np.random.seed(RANDOM_SEED_TRAIN) + np.random.seed(seed) tx_locs = getLocs([0.2, 6.2], [0.2, 6.2], step_size=0.3) elif args.train_real: - np.random.seed(RANDOM_SEED_TRAIN) + np.random.seed(seed) tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=0.0, offset_l=0.0) elif args.test: - np.random.seed(RANDOM_SEED_TEST) + np.random.seed(seed) tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=0.0, offset_l=0.0) else: print("nothing specified") @@ -282,6 +291,12 @@ if __name__ == '__main__': default=1, help='number of parallel process' ) + p.add_argument( + '--seed', + dest='seed', + default=None, + help='random seed' + ) try: args = p.parse_args() except BaseException as e: diff --git a/syngen_log_gamma_model.py b/syngen_log_gamma_model.py index a1156ba..4c85611 100644 --- a/syngen_log_gamma_model.py +++ b/syngen_log_gamma_model.py @@ -75,6 +75,15 @@ def generateData(tx_locs, args): def main(args): tx_locs = None + if args.seed is None: + if args.train: + seed = RANDOM_SEED_TRAIN + elif args.train_real: + seed = RANDOM_SEED_TRAIN + elif args.test: + seed = RANDOM_SEED_TEST + else: + seed = args.seed if args.single: if args.txloc is None: @@ -83,13 +92,13 @@ def main(args): locations = [float(x) for x in args.txloc.split(' ')] tx_locs = np.array(locations).reshape(-1, 2) elif args.train: - np.random.seed(RANDOM_SEED_TRAIN) + np.random.seed(seed) tx_locs = getLocs([-3, 3], [-3, 3], step_size=0.3) elif args.train_real: - np.random.seed(RANDOM_SEED_TRAIN) + np.random.seed(seed) tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=-3.2, offset_l=-3.2) elif args.test: - np.random.seed(RANDOM_SEED_TEST) + np.random.seed(seed) tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=-3.2, offset_l=-3.2) else: print("nothing specified") @@ -152,6 +161,12 @@ if __name__ == '__main__': default=False, help='add 10dB noise to generated data' ) + p.add_argument( + '--seed', + dest='seed', + default=None, + help='random seed' + ) try: args = p.parse_args() except BaseException as e: