Simple Beam Search Code for Python

Code from: https://github.com/SeitaroShinagawa/simple_beamsearch.git

# Toy next-token model over the two token ids 0 and 1.
# Each entry is [prefix, [p(0 | prefix), p(1 | prefix)]]; prob_gen() scans
# this list for the entry whose prefix matches. Only prefixes of length
# 0, 1 and 2 are listed, and every distribution sums to 1.
probs = [
  [[], [0.3, 0.7]],
  [[0], [0.1, 0.9]],
  [[1], [0.4, 0.6]],
  [[0, 0], [0.3, 0.7]],
  [[0, 1], [0.8, 0.2]],
  [[1, 0], [0.6, 0.4]],
  [[1, 1], [0.5, 0.5]],
]

def prob_gen(lis, table=None):
  """Look up the next-token distribution for the prefix *lis*.

  Args:
    lis: the token prefix to look up, e.g. ``[]`` or ``[0, 1]``. Any sequence
      is accepted; it is compared element-wise with the stored prefixes, so
      ``(0, 1)`` finds the same entry as ``[0, 1]``.
    table: sequence of ``[prefix, distribution]`` entries to search.
      Defaults to the module-level ``probs`` table.

  Returns:
    The distribution stored in the first entry whose prefix matches *lis*
    (the stored list object itself, not a copy).

  Raises:
    IndexError: if no entry matches *lis*. IndexError is kept because that is
      what the previous ``tmp[0]`` lookup raised.
  """
  if table is None:
    table = probs
  key = list(lis)
  # Return on the first match instead of collecting every match into a list.
  for entry in table:
    if list(entry[0]) == key:
      return entry[1]
  raise IndexError("no distribution for prefix {!r}".format(lis))

def list_print(in_):
  """Print every element of *in_* on its own line, in iteration order."""
  for entry in in_:
    print(entry)

if __name__ == "__main__":
  beam = 3
  print("step 1")
  out_list = []
  prob=prob_gen(out_list)

  print("prob a:{},b:{}".format(prob[0],prob[1]))
  candidate_list = [[[i],j] for i,j in enumerate(prob)]
  candidate_list = sorted(candidate_list,key=lambda x:x[1],reverse=True)
  out_list = candidate_list[:beam] #[ [[0],0.3] , [[1],0.7] ]
  list_print(out_list)

  print("step 2")

  candidate_list=[]
  for lis in out_list:
    prob=prob_gen(lis[0])
    print("prob a:{},b:{}".format(prob[0],prob[1]),"conditioned by p(",lis[0],")=",lis[1])
    for i,j in enumerate(prob):
      A = lis[0]+[i]
      B = lis[1]*j
      candidate_list.append([A,B])
  candidate_list = sorted(candidate_list,key=lambda x:x[1],reverse=True)
  out_list = candidate_list[:beam] #[ [[0],0.3] , [[1],0.7] ]
  list_print(out_list)

  print("step 3")

  candidate_list=[]
  for lis in out_list:
    prob=prob_gen(lis[0])
    print("prob a:{},b:{}".format(prob[0],prob[1]),"conditioned by p(",lis[0],")=",lis[1])
    for i,j in enumerate(prob):
      A = lis[0]+[i]
      B = lis[1]*j
      candidate_list.append([A,B])

  candidate_list = sorted(candidate_list,key=lambda x:x[1],reverse=True)
  out_list = candidate_list[:beam] #[ [[0],0.3] , [[1],0.7] ]
  list_print(out_list)

Related articles
Related tags / search