jeudi 11 juillet 2019

Python: add to or fix the code so that it identifies the color

With the code below, set k = 2 and you will get the 2 dominant colors of an image.

  1. I only need one color: the dominant color other than white.
  2. Classify that color as black, blue, green, yellow, red, or orange, using a single array.

https://imgur.com/jXMcz9F black

https://imgur.com/83iIedV blue

https://imgur.com/ekjuP9e green

https://imgur.com/pETWVHN yellow

https://imgur.com/IJANMY7 red

https://imgur.com/Uf1UWTJ light red

The images are CAPTCHAs from https://mma.sinopac.com/SinoCard/Activity/Register?Code=TLDI

  1. Write one condition for each of those colors, for example:

    if <black condition>: ...
    if <blue condition>: ...

  2. Please simplify the code to reduce its running time, as long as it still satisfies the requirements above.

Code (adapted from https://github.com/shulisiyuan/mainColor/blob/master/mainColor2.py):

import random
import numpy as np
import warnings
from skimage import io

# Mute every warning category for the remainder of the script.
warnings.simplefilter('ignore')
# Input image to analyse and the number of colour clusters to extract.
img_file = r'40.png'
k = 2
def get_img(file):
    """Load the image stored at path `file` and return the array skimage produces."""
    from skimage.io import imread
    image = imread(file)
    return image
# Load the image and flatten it to one row per pixel: (height * width, channels).
img = get_img(img_file)
img_ori_shape = img.shape
img = img.reshape((img_ori_shape[0] * img_ori_shape[1], img_ori_shape[2]))
img_shape = img.shape
print('img size:', img_ori_shape)

n_pixels = img_shape[0]
n_channels = img_shape[1]

# Seed k centers from randomly chosen pixels. While the image still has unused
# distinct colours, a pixel whose colour is already a center is skipped. With
# purely random picks, a mostly-white captcha often yields two identical
# (white) centers. Every pixel then ties to center 0, the loop below stops on
# its first pass, and the clustering degenerates to a single colour. Once all
# distinct colours are used (k > number of colours), duplicates are allowed so
# exactly k centers always exist.
n_distinct_colors = len(np.unique(img, axis=0))
centers = []
while len(centers) < k:
    random_pixel = random.randint(0, n_pixels - 1)
    candidate = img[random_pixel]
    if len(centers) < n_distinct_colors and any(
            np.array_equal(candidate, center) for center in centers):
        continue
    centers.append(candidate)
# Store the centers as int rather than the image dtype: with an unsigned
# 8-bit image (typical for PNG input), `pixel - center` would wrap around
# (0 - 3 == 253), and the overflow warning is silenced by the global filter.
centers = np.array(centers, dtype=int)
print('init centers:\n', centers)


# Cluster label per pixel, and the upper bound on k-means iterations.
labels = np.zeros(n_pixels, dtype=int)
max_iter = 10
def get_euclidean_distance(_pixel, _center):
    """Return the Euclidean distance between two colour vectors.

    Both vectors are converted to float64 before subtracting. Pixel rows
    (and centers copied from them) may be unsigned 8-bit. In that dtype the
    difference wraps around (0 - 3 == 253) and squaring overflows, which
    silently produced wrong distances.

    Args:
        _pixel: sequence/array of channel values for one pixel.
        _center: sequence/array of channel values for one center, same length.

    Returns:
        The distance as a float64 scalar.
    """
    diff = np.asarray(_pixel, dtype=np.float64) - np.asarray(_center, dtype=np.float64)
    return np.sqrt(np.dot(diff, diff))

def get_nearest_center(_pixel):
    """Return the index of the center in the global `centers` closest to `_pixel`.

    Distances come from `get_euclidean_distance`, and only the first `k`
    centers are compared. `min` keeps the earliest of equal keys, so ties
    resolve to the lowest index. Center 0 is always measured, even if `k` < 1.
    """
    candidate_indices = range(max(k, 1))
    return min(candidate_indices,
               key=lambda idx: get_euclidean_distance(_pixel, centers[idx]))


def cal_new_center():
    """Recompute every cluster center as the truncated integer mean of its pixels.

    Reads the module-level `labels`, `img`, `k`, `n_channels` and `n_pixels`.
    A cluster with no assigned pixels is re-seeded with a randomly chosen
    pixel (and a warning is printed), so the result always has `k` rows.

    Returns:
        A (k, n_channels) int array holding the new centers.
    """
    center_counts = np.bincount(labels, minlength=k)
    # Per-cluster channel sums in one vectorized pass instead of a Python loop
    # over every pixel and channel. The int64 accumulator cannot overflow when
    # summing 8-bit channel values.
    channel_sums = np.zeros((k, n_channels), dtype=np.int64)
    np.add.at(channel_sums, labels, img)
    _centers = np.zeros((k, n_channels), dtype=int)
    for _center_index in range(k):
        if center_counts[_center_index] > 0:
            # Floor division reproduces the truncation of storing a float mean
            # into an int array (sums are non-negative).
            _centers[_center_index] = channel_sums[_center_index] // center_counts[_center_index]
        else:
            _centers[_center_index] = img[random.randint(0, n_pixels - 1)]
            print('WARNING: Center %d has no pixel, re-choose center randomly...' % _center_index)

    return _centers
print('start iter')

# Cast the pixels once, outside the loop. With an unsigned 8-bit image,
# differences would wrap around (0 - 3 == 253) and corrupt the distances.
_pixels_int = img.astype(np.int64)

for iter_index in range(max_iter):
    print('\niter %d...' % iter_index)

    # Assignment step, vectorized over all pixels. sq_dist[p, c] is the
    # squared distance from pixel p to center c. argmin over squared
    # distances picks the same center as over Euclidean distances (sqrt is
    # monotonic). Like a strict `<` scan, argmin keeps the lowest index on ties.
    centers_int = np.asarray(centers, dtype=np.int64)
    sq_dist = np.stack(
        [((_pixels_int - center) ** 2).sum(axis=1) for center in centers_int],
        axis=1,
    )
    new_labels = sq_dist.argmin(axis=1)
    changed_pixel = int(np.count_nonzero(new_labels != labels))
    # Update in place so the array bound to the global `labels` (read by
    # cal_new_center and the code after this loop) holds the new assignment.
    labels[:] = new_labels
    print('label', labels)

    # Converged: fewer than 1% of pixels switched cluster.
    if changed_pixel / n_pixels < 0.01:
        break

    centers = cal_new_center()
    print(centers)


print()
print('\n=========================\nIter finished!')
# iter_index is 0-based, so the number of iterations actually run is iter_index + 1.
print('Iter for %d iters' % (iter_index + 1))
print(centers)
print(labels)

from collections import Counter

# Pixels per cluster, then cluster indices from most to least frequent.
# most_common() is a stable descending sort, so ties keep first-seen order.
center_counts = Counter(labels.tolist())
centers_index_sorted = [center_index for center_index, _ in center_counts.most_common()]

result = []
result_width = 200
result_height_per_center = 80
# One solid result_width x result_height_per_center swatch per cluster that
# has pixels. The dtype is uint8 because center values come from pixel values
# or their integer means. Image writers may rescale or reject int64 arrays,
# which would change the saved colours.
for center_index in centers_index_sorted:
    result.append(np.full((result_width * result_height_per_center, n_channels),
                          centers[center_index], dtype=np.uint8))
result = np.array(result)
# Clusters that ended up with no pixels are absent from centers_index_sorted.
# The height therefore follows the number of swatches actually built, not `k`;
# reshaping by `k` raised ValueError whenever a cluster was empty.
result = result.reshape((result_height_per_center * len(centers_index_sorted),
                         result_width, n_channels))
def save_img(_ori_file_name, _result):
    """Save `_result` next to the input image as `<name>_result<ext>`.

    Only the final extension is split off:
    '40.png' -> '40_result.png', './imgs/40.png' -> './imgs/40_result.png'.
    The previous `str.replace('.', '_result.')` rewrote every dot, so relative
    paths and dotted names were mangled ('./imgs/40.png' ->
    '_result./imgs/40_result.png').

    Args:
        _ori_file_name: path of the source image.
        _result: image array to write.
    """
    import os
    from skimage import io
    root, ext = os.path.splitext(_ori_file_name)
    io.imsave(root + '_result' + ext, _result)


save_img(img_file, result)




Aucun commentaire:

Enregistrer un commentaire