pip3 install captcha pillow
from captcha.image import ImageCaptcha
from PIL import Image

# Render the text '1234' as a captcha and open it in the default image viewer.
text = '1234'
image = ImageCaptcha()
captcha = image.generate(text)  # rendered image data that PIL's Image.open can read
captcha_image = Image.open(captcha)
captcha_image.show()
# Characters a captcha may contain (digits only); list order defines one-hot indices.
VOCAB = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Number of characters in every generated captcha.
CAPTCHA_LENGTH = 4
# Width of the one-hot vector for a single character position.
VOCAB_LENGTH = len(VOCAB)
from PIL import Image
from captcha.image import ImageCaptcha
import numpy as np

# ImageCaptcha() takes no per-image arguments, so a single shared instance can
# serve every call instead of constructing a new one per image (per the captcha
# library's implementation, an instance also keeps the fonts it loads).
_CAPTCHA_GENERATOR = ImageCaptcha()


def generate_captcha(captcha_text):
    """
    Render captcha_text as a captcha image and return its pixel array.

    :param captcha_text: source text to draw
    :return: numpy array of the rendered image, shaped (height, width, channels)
    """
    captcha = _CAPTCHA_GENERATOR.generate(captcha_text)
    captcha_image = Image.open(captcha)
    return np.array(captcha_image)
# Render a sample captcha and inspect its pixel array and shape.
captcha = generate_captcha('1234')
print(captcha, captcha.shape)
# Output:
# [[[239 244 244]
#   [239 244 244]
#   [239 244 244]
#   ...,
#   ...,
#   [239 244 244]
#   [239 244 244]
#   [239 244 244]]] (60, 160, 3)
def text2vec(text):
    """
    Encode captcha text as a flat one-hot vector.

    Character i of text sets index i * VOCAB_LENGTH + VOCAB.index(char) to 1.
    Positions beyond len(text) stay all-zero.

    :param text: source text, at most CAPTCHA_LENGTH characters from VOCAB
    :return: float np array of length CAPTCHA_LENGTH * VOCAB_LENGTH,
        or False if text is longer than CAPTCHA_LENGTH
    :raises ValueError: if text contains a character that is not in VOCAB
    """
    if len(text) > CAPTCHA_LENGTH:
        return False
    vector = np.zeros(CAPTCHA_LENGTH * VOCAB_LENGTH)
    for position, char in enumerate(text):
        vector[position * VOCAB_LENGTH + VOCAB.index(char)] = 1
    return vector


def vec2text(vector):
    """
    Decode a flat per-position score vector back to captcha text.

    The vector is viewed as CAPTCHA_LENGTH rows of scores. Each row contributes
    the VOCAB character at its highest-scoring index, so any per-character
    scores decode, not only exact one-hot vectors.

    :param vector: array-like whose length is a multiple of CAPTCHA_LENGTH
    :return: decoded text, one character per row
    """
    rows = np.reshape(np.asarray(vector), [CAPTCHA_LENGTH, -1])
    # Build the string in one join instead of repeated `+=` concatenation.
    return ''.join(VOCAB[index] for index in np.argmax(rows, axis=1))
# Round-trip check: encode '1234' to a one-hot vector and decode it back.
vector = text2vec('1234')
text = vec2text(vector)
print(vector, text)
# Output:
# [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.
#   0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.
#   0.  0.  0.  0.] 1234
import random
from os.path import join, exists
import pickle
import numpy as np
from os import makedirs

# Number of (image, label) samples to generate.
DATA_LENGTH = 10000
# Directory the pickled dataset is written into.
DATA_PATH = 'data'


def get_random_text():
    """
    Return a random captcha string of CAPTCHA_LENGTH characters drawn from VOCAB.
    """
    return ''.join(random.choice(VOCAB) for _ in range(CAPTCHA_LENGTH))


def generate_data():
    """
    Generate DATA_LENGTH random captchas and pickle them to DATA_PATH/data.pkl.

    The file holds two consecutive pickles. The first is the float32 image array x
    (one entry per captcha). The second is the float32 label array y (the text2vec
    one-hot vectors). Readers must call pickle.load twice, in that order.
    """
    print('Generating Data...')
    data_x, data_y = [], []
    for _ in range(DATA_LENGTH):
        text = get_random_text()
        data_x.append(generate_captcha(text))  # captcha pixel array
        data_y.append(text2vec(text))  # one-hot label vector
    # exist_ok avoids the check-then-create race of exists() followed by makedirs().
    makedirs(DATA_PATH, exist_ok=True)
    x = np.asarray(data_x, np.float32)
    y = np.asarray(data_y, np.float32)
    with open(join(DATA_PATH, 'data.pkl'), 'wb') as f:
        pickle.dump(x, f)
        pickle.dump(y, f)
# NOTE(review): this snippet was taken from inside a function whose `def` line is
# not shown (it ends in `return`). The page's line-number prefix and trailing
# marker also make it invalid Python as it stands.
# NOTE(review): generate_data writes join(DATA_PATH, 'data.pkl') (data/data.pkl),
# but this opens 'data.pkl' in the working directory. Confirm which path is intended.
# The two pickle.load calls read x then y, matching the order generate_data dumps them.
# `standardize` and `train_test_split` are not defined in this section. The latter
# is presumably sklearn.model_selection.train_test_split; TODO confirm.
# Split: 60% train; the remaining 40% is halved into 20% dev and 20% test.
1234567with open('data.pkl', 'rb') as f: data_x = pickle.load(f) data_y = pickle.load(f) return standardize(data_x), data_y train_x, test_x, train_y, test_y = train_test_split(data_x, data_y, test_size=0.4, random_state=40)dev_x, test_x, dev_y, test_y, = train_test_split(test_x, test_y, test_size=0.5, random_state=40)複製代碼
# train and dev dataset
# NOTE: from_tensor_slices on NumPy arrays embeds them in the graph as constants,
# and a TF graph is capped at 2GB. A much larger DATA_LENGTH would need
# placeholder-fed datasets instead.
# Training examples are shuffled through a 10000-element buffer.
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(10000)
train_dataset = train_dataset.batch(FLAGS.train_batch_size)

# Dev and test sets are batched in their stored order.
dev_dataset = tf.data.Dataset.from_tensor_slices((dev_x, dev_y))
dev_dataset = dev_dataset.batch(FLAGS.dev_batch_size)

test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
test_dataset = test_dataset.batch(FLAGS.test_batch_size)
# A reinitializable iterator shared by the train/dev/test datasets. Running one of
# the initializers below selects which dataset iterator.get_next() reads from.
iterator = tf.data.Iterator.from_structure(train_dataset.output_types, train_dataset.output_shapes)
train_initializer = iterator.make_initializer(train_dataset)
dev_initializer = iterator.make_initializer(dev_dataset)
test_initializer = iterator.make_initializer(test_dataset)
# input Layer
with tf.variable_scope('inputs'):
    # x.shape = [-1, 60, 160, 3]
    x, y_label = iterator.get_next()
# Probability of KEEPING a unit: feed the training value while training, 1 when evaluating.
keep_prob = tf.placeholder(tf.float32, [])
y = tf.cast(x, tf.float32)

# 3 CNN layers
for _ in range(3):
    y = tf.layers.conv2d(y, filters=32, kernel_size=3, padding='same', activation=tf.nn.relu)
    y = tf.layers.max_pooling2d(y, pool_size=2, strides=2, padding='same')

# 2 dense layers
y = tf.layers.flatten(y)
y = tf.layers.dense(y, 1024, activation=tf.nn.relu)
# tf.nn.dropout takes a keep probability, matching the placeholder's meaning.
# (tf.layers.dropout expects a DROP rate, and with its default training=False it
# does nothing at all.)
y = tf.nn.dropout(y, keep_prob=keep_prob)
# One logit per (character position, vocab entry) pair, so the output has the same
# CAPTCHA_LENGTH * VOCAB_LENGTH width as the text2vec labels.
y = tf.layers.dense(y, CAPTCHA_LENGTH * VOCAB_LENGTH)
# View logits and labels as one row per character position, so each row is an
# independent VOCAB_LENGTH-way classification. Both tensors must have the same
# width per sample for the rows to line up.
y_reshape = tf.reshape(y, [-1, VOCAB_LENGTH])
y_label_reshape = tf.reshape(y_label, [-1, VOCAB_LENGTH])

# loss: softmax cross-entropy per character row, summed over the whole batch
cross_entropy = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=y_reshape, labels=y_label_reshape))

# accuracy: fraction of individual characters predicted correctly
# (per character, not per whole captcha)
max_index_predict = tf.argmax(y_reshape, axis=-1)
max_index_label = tf.argmax(y_label_reshape, axis=-1)
correct_predict = tf.equal(max_index_predict, max_index_label)
accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
# train
# NOTE(review): sess, global_step, train_steps, dev_steps and FLAGS are defined
# outside this section. Variable initialization presumably happens there too.
train_op = tf.train.RMSPropOptimizer(FLAGS.learning_rate).minimize(cross_entropy, global_step=global_step)
for epoch in range(FLAGS.epoch_num):
    # train
    sess.run(train_initializer)
    for step in range(int(train_steps)):
        loss, acc, gstep, _ = sess.run([cross_entropy, accuracy, global_step, train_op],
                                       feed_dict={keep_prob: FLAGS.keep_prob})
        # print log
        if step % FLAGS.steps_per_print == 0:
            print('Global Step', gstep, 'Step', step, 'Train Loss', loss, 'Accuracy', acc)

    if epoch % FLAGS.epochs_per_dev == 0:
        # dev: each sess.run pulls one batch from the iterator, so run every step
        # to cover the whole dev set, and only print on logging steps.
        sess.run(dev_initializer)
        dev_accs = []
        for step in range(int(dev_steps)):
            dev_acc = sess.run(accuracy, feed_dict={keep_prob: 1})
            dev_accs.append(dev_acc)
            if step % FLAGS.steps_per_print == 0:
                print('Dev Accuracy', dev_acc, 'Step', step)
        # Unweighted mean over batches (a smaller final batch counts the same).
        if dev_accs:
            print('Dev Mean Accuracy', np.mean(dev_accs))
# Sample training log:
# ...
# Dev Accuracy 0.9580078 Step 0
# Dev Accuracy 0.9472656 Step 2
# Dev Accuracy 0.9501953 Step 4
# Dev Accuracy 0.9658203 Step 6
# Global Step 3243 Step 0 Train Loss 1.1920928e-06 Accuracy 1.0
# Global Step 3245 Step 2 Train Loss 1.5497207e-06 Accuracy 1.0
# Global Step 3247 Step 4 Train Loss 1.1920928e-06 Accuracy 1.0
# Global Step 3249 Step 6 Train Loss 1.7881392e-06 Accuracy 1.0
# ...
# save model
# NOTE(review): epoch and gstep are the training loop's variables, so this check is
# presumably meant to run once per epoch inside that loop. saver and sess are
# defined outside this section.
# NOTE(review): saver.save treats its second argument as a checkpoint file PREFIX.
# The loader calls get_checkpoint_state('ckpt'), which expects a directory, so
# confirm that FLAGS.checkpoint_dir points inside 'ckpt/'.
if epoch % FLAGS.epochs_per_save == 0:
    saver.save(sess, FLAGS.checkpoint_dir, global_step=gstep)
# load model
# NOTE(review): the directory is hard-coded as 'ckpt' rather than derived from
# FLAGS.checkpoint_dir, which the save step uses. Confirm the two agree.
ckpt = tf.train.get_checkpoint_state('ckpt')
if ckpt:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print('Restore from', ckpt.model_checkpoint_path)
    sess.run(test_initializer)
    test_accs = []
    for step in range(int(test_steps)):
        # Each sess.run pulls one batch, so run every step to cover the whole test
        # set, and only print on logging steps.
        test_acc = sess.run(accuracy, feed_dict={keep_prob: 1})
        test_accs.append(test_acc)
        if step % FLAGS.steps_per_print == 0:
            print('Test Accuracy', test_acc, 'Step', step)
    # Unweighted mean over batches (a smaller final batch counts the same).
    if test_accs:
        print('Test Mean Accuracy', np.mean(test_accs))
else:
    print('No Model Found')
