# Code from: https://github.com/SeitaroShinagawa/simple_beamsearchgit
"""Minimal beam-search walkthrough over a hand-written two-token model.

The "model" is the lookup table ``probs``: for every token prefix of length
0..2 it stores the distribution over the next token (index 0 is printed as
"a", index 1 as "b").  Running the file as a script performs three decoding
steps with beam width 3 and prints the surviving hypotheses after each step.
"""

# Each entry is [prefix, next_token_distribution].  Prefixes are lists of
# token indices; only prefixes up to length 2 are defined, so at most three
# decoding steps can be performed with this table.
probs = [
    [[], [0.3, 0.7]],
    [[0], [0.1, 0.9]],
    [[1], [0.4, 0.6]],
    [[0, 0], [0.3, 0.7]],
    [[0, 1], [0.8, 0.2]],
    [[1, 0], [0.6, 0.4]],
    [[1, 1], [0.5, 0.5]],
]


def prob_gen(lis):
    """Return the next-token distribution stored in ``probs`` for ``lis``.

    Args:
        lis: Token prefix as a list of ints.  It is compared with ``==``
            against the list prefixes in ``probs``, so a tuple never matches.

    Returns:
        The distribution list of the first matching entry (the stored
        object itself, not a copy).

    Raises:
        IndexError: If ``probs`` has no entry for ``lis``.
    """
    for prefix, distribution in probs:
        if prefix == lis:
            return distribution
    # Same exception type the original ``tmp[0]`` lookup produced, but with a
    # message that names the missing prefix.
    raise IndexError("no distribution defined for prefix {}".format(lis))


def list_print(in_):
    """Print every element of ``in_`` on its own line."""
    for i in in_:
        print(i)


def beam_step(hypotheses, beam, verbose=False):
    """Extend every hypothesis by one token and keep the ``beam`` best.

    Args:
        hypotheses: List of ``[prefix, score]`` pairs, where ``prefix`` is a
            list of token indices and ``score`` the probability of that
            prefix.  Use ``[[[], 1.0]]`` to start decoding.
        beam: Maximum number of hypotheses to keep; must be at least 1.
        verbose: When true, print the distribution used for each hypothesis.

    Returns:
        A new list of ``[prefix, score]`` pairs sorted by descending score
        and truncated to ``beam`` entries.  Ties keep expansion order
        because the sort is stable.

    Raises:
        ValueError: If ``beam`` is smaller than 1.
        IndexError: If a prefix has no entry in ``probs``.
    """
    if beam < 1:
        raise ValueError("beam must be >= 1, got {}".format(beam))

    candidates = []
    for prefix, score in hypotheses:
        prob = prob_gen(prefix)
        if verbose:
            if prefix:
                print("prob a:{},b:{}".format(prob[0], prob[1]),
                      "conditioned by p(", prefix, ")=", score)
            else:
                # The empty prefix has nothing to condition on.
                print("prob a:{},b:{}".format(prob[0], prob[1]))
        for token, p in enumerate(prob):
            candidates.append([prefix + [token], score * p])

    candidates.sort(key=lambda cand: cand[1], reverse=True)
    return candidates[:beam]


def main(beam=3, steps=3):
    """Run ``steps`` decoding steps with beam width ``beam``, printing each.

    ``steps`` must not exceed 3 with the built-in table, because ``probs``
    only defines distributions for prefixes of length 0, 1 and 2.
    """
    # Multiplying by the initial score 1.0 is exact, so the first step yields
    # the raw distribution values as scores.
    out_list = [[[], 1.0]]
    for step in range(1, steps + 1):
        print("step {}".format(step))
        out_list = beam_step(out_list, beam, verbose=True)
        list_print(out_list)
    return out_list


if __name__ == "__main__":
    main()