add seed selection option
This commit is contained in:
parent
3259127324
commit
006463ef87
|
|
@ -181,6 +181,15 @@ def getFloormap(args):
|
|||
|
||||
def main(args):
|
||||
tx_locs = None
|
||||
if args.seed is None:
|
||||
if args.train:
|
||||
seed = RANDOM_SEED_TRAIN
|
||||
elif args.train_real:
|
||||
seed = RANDOM_SEED_TRAIN
|
||||
elif args.test:
|
||||
seed = RANDOM_SEED_TEST
|
||||
else:
|
||||
seed = args.seed
|
||||
|
||||
if args.single:
|
||||
if args.txloc is None:
|
||||
|
|
@ -189,13 +198,13 @@ def main(args):
|
|||
locations = [float(x) for x in args.txloc.split(' ')]
|
||||
tx_locs = np.array(locations).reshape(-1, 2)
|
||||
elif args.train:
|
||||
np.random.seed(RANDOM_SEED_TRAIN)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getLocs([0.2, 6.2], [0.2, 6.2], step_size=0.3)
|
||||
elif args.train_real:
|
||||
np.random.seed(RANDOM_SEED_TRAIN)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=0.0, offset_l=0.0)
|
||||
elif args.test:
|
||||
np.random.seed(RANDOM_SEED_TEST)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=0.0, offset_l=0.0)
|
||||
else:
|
||||
print("nothing specified")
|
||||
|
|
@ -282,6 +291,12 @@ if __name__ == '__main__':
|
|||
default=1,
|
||||
help='number of parallel process'
|
||||
)
|
||||
p.add_argument(
|
||||
'--seed',
|
||||
dest='seed',
|
||||
default=None,
|
||||
help='random seed'
|
||||
)
|
||||
try:
|
||||
args = p.parse_args()
|
||||
except BaseException as e:
|
||||
|
|
|
|||
|
|
@ -75,6 +75,15 @@ def generateData(tx_locs, args):
|
|||
|
||||
def main(args):
|
||||
tx_locs = None
|
||||
if args.seed is None:
|
||||
if args.train:
|
||||
seed = RANDOM_SEED_TRAIN
|
||||
elif args.train_real:
|
||||
seed = RANDOM_SEED_TRAIN
|
||||
elif args.test:
|
||||
seed = RANDOM_SEED_TEST
|
||||
else:
|
||||
seed = args.seed
|
||||
|
||||
if args.single:
|
||||
if args.txloc is None:
|
||||
|
|
@ -83,13 +92,13 @@ def main(args):
|
|||
locations = [float(x) for x in args.txloc.split(' ')]
|
||||
tx_locs = np.array(locations).reshape(-1, 2)
|
||||
elif args.train:
|
||||
np.random.seed(RANDOM_SEED_TRAIN)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getLocs([-3, 3], [-3, 3], step_size=0.3)
|
||||
elif args.train_real:
|
||||
np.random.seed(RANDOM_SEED_TRAIN)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=-3.2, offset_l=-3.2)
|
||||
elif args.test:
|
||||
np.random.seed(RANDOM_SEED_TEST)
|
||||
np.random.seed(seed)
|
||||
tx_locs = getRandomTXLocs(400, 6.4, 6.4, offset_w=-3.2, offset_l=-3.2)
|
||||
else:
|
||||
print("nothing specified")
|
||||
|
|
@ -152,6 +161,12 @@ if __name__ == '__main__':
|
|||
default=False,
|
||||
help='add 10dB noise to generated data'
|
||||
)
|
||||
p.add_argument(
|
||||
'--seed',
|
||||
dest='seed',
|
||||
default=None,
|
||||
help='random seed'
|
||||
)
|
||||
try:
|
||||
args = p.parse_args()
|
||||
except BaseException as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue