Optimisation de Backtracking

a marqué ce sujet comme résolu.

Bonjour,

Je suis actuellement entrain de travailler sur un puzzle codingame ayant recours au backtracking.

Voici mon code actuel :

import sys

square = []

n = int(input())
for i in range(n):
    row = input()
    square.append([int(i) for i in row])


SQUARE_LEN = n*n

n_solutions = 0


def debug(*m):
    print(*m,file = sys.stderr, flush = True)

def solve(p,n):
    global n_solutions

    if p >= SQUARE_LEN:
        if n_solutions%100 == 0:
            debug(n_solutions)
        n_solutions += 1
        return True

    y = p//n
    x = p%n
    if square[y][x] > 0:
        return solve(p+1,n)

    forbidden = square[y] + [l[x] for l in square]
    
    for i in range(1,n+1):
        if not i in forbidden:
            square[y][x] = i
            solve(p+1,n)
    square[y][x] = 0
    return False
        

solve(0,n)
    
print(n_solutions)

Ma solution fonction pour tous les cas de tests sauf pour les deux derniers, pour lesquels mon programme est trop lent… Je suppose qu’il y a doit y avoir moyen d’optimiser encore la recherche, mais je ne vois pas comment je pourrais m’y prendre…

Merci d’avance !

+0 -0

Salut,

Souvent dans ce genre de problème on peut trouver des optimisations à l’aide d’un cache.

Si on réduit le problème à une dimension avec comme entrée 10000, on sait qu’à un moment du backtracking tu vas évaluer les solutions 12300 et 13200. Mais on sait aussi que ces deux compositions sont semblables (elles donneront le même nombre de solutions) et donc qu’il n’est pas nécessaire de refaire le calcul si on se souvient du précédent.

Ramené à une grande grille, ça peut permettre de bien réduire le nombre de solutions à explorer.

Dans cette optique tu peux aussi chercher à modifier la manière dont tu structures tes données : as-tu vraiment besoin de les représenter comme un carré de nombres ? J’aurais tendance à mettre ça sous la forme d’une liste de positions (les zéros à remplir) accompagnée de deux listes représentant les chiffres éligibles pour chaque ligne et chaque colonne.


Une autre optimisation possible est de ne pas s’enfoncer trop loin dans le backtracking si tu peux savoir qu’une solution n’est pas viable.

Par exemple avec

1234
2000
3000
4012

Au moment de traiter le 0 en haut à gauche, tu peux tout de suite identifier qu’il ne sert à rien d’explorer le cas où il vaudrait 3, puisque ça rendrait la dernière ligne insolvable.

Merci beaucoup ! Je vais essayer de mettre ça en application.

Sinon, une autre idée m’est venue : serait-il judicieux de considérer le problème comme un enchainement de colonnes/lignes (qui sont des permutations de la liste des nombres entre 1 et n), et de tester avec le backtracking les combinaisons de colonnes/lignes satisfaisant les contraintes de départ ?

Ça ressemble à un Sudoku sans contraintes de petits carrés.

Permutations? Comment veux-tu le faire?

Si n=9 alors 9! vaut 362880

Il faut trouver un moyen plus rapide pour savoir si le symbole se trouve dans la ligne et la colonne.

J’ai déjà fait quelque chose du genre. Je vais voir si je peux retrouver cela.

Et fonctionner avec une simple liste comme l’a mentionné entwanne est souvent plus rapide.

+0 -0

Je n’ai pas encore retrouvé ce que j’avais fait. Je peux dire que le fait de passer d’un indice 1D à deux indices 2D oblige à faire des divisions et des modulo qui coùtent souvent cher.

Je te donne une idée de ce que je crois que je faisais (je l’avais fait en C++, pas en Python). Ce n’est pas forcément du code qui marche. C’est surtout une ébauche.

def solve(p):
    while p < n*n and grille[p] > 0: p += 1
    if p >= n*n:
        c += 1
        return
    for d in range(1, n+1):
        # Je n'ai pas obtimisé L[p]+d et C[p]+d (il y a 3 références pour chacun)
        # L[p] et C[p] pourraient être évalués avant la boucle.
        if V[L[p]+d] and V[C[p]+d]:     # La ligne et la colonne correspondant à p sont libres pour d
            grille[p] = d
            V[L[p]+d] = V[C[p]+d] = 0
            solve(p+1)
            V[L[p]+d] = V[C[p]+d] = 1
    grille[p] = 0

V est un long tableau de booléens ou de (0, 1) que j’aurais mis en bytearray

Et L et C sont des tableaux donnant la position du 0 pour chaque position p.

Ce que veut dire

    V[L[p]+d]
    V[C[p]+d]

est-ce qu’on peut placer ddans la ligne (L) ou la collonne (C) correspondant à la position p?

Ce code n’altère pas les contraintes de départ.

La longueur de V devrait être n fois n+1 (pour 0 à n) fois 2 (pour les lignes et les colonnes).

Les tableaux L et C sont des tableaux d’indices qui sont évalués au début du programme quand on connait la valeur de n.

Voici une ébauche de ce que ça pourrait donner

    V = bytearray([1] * (n * (n+1) * 2))
    L = [p // n * (n+1) for p in range(n*n)]
    C = [n*(n+1) + p % n * (n+1) for p in range(n*n)]
    # On mettra à 0 les positions correspondantes aux contraintes de départ dans V
    # p = ligne * n + colonne   si ligne et colonne commencent à 0.
    V[L[p]+d] = V[C[p]+d] = 0    # Le d est obtenu sur la ligne d'entrée.

On n’a pas toujours besoin de variables globales. Le test suivant fonctionne:

def test(p):
    return F[p]
F = list(range(9, -1, -1))
print(test(3))

Je ne sais pas si c’est plus efficace que d’utiliser des variables globales.

+0 -0

Un autre point d’optimisation concerne aussi l’ordre dans lequel tu traites les cases vides, qui a une grande influence sur le temps d’exécution. Il y a probablement un ordre optimal à trouver, comme gérer les cas les plus simples (moins de possibilités) en premier.

Mais toutes les optimisations ne sont pas forcément compatibles entre-elles, celle-ci peut par exemple rendre inutile (voire contre-productif) le fait de vérifier en amont qu’une solution n’est pas une impasse.
De même qu’un système de cache a un coût non-négligeable qui n’est pas forcément « compétitif » face à d’autres optims.

Pour te donner un ordre d’idée, j’ai repris ta solution initiale et l’ai confrontée au problème le plus dur (9x9 hard) sur ma machine. Telle quelle, j’obtiens une réponse après 44 secondes d’exécution (peu importe que le debug soit activé ou non).


Je change alors la structure de données comme je l’avais évoqué ici : plutôt que de stocker un carré de nombres (liste de listes de nombres) je sépare cela en 3 structures différentes :

  • une liste de positions à remplir (les zéros du carré initial)
  • une liste d’ensembles de chiffres éligibles pour chaque ligne
  • une liste d’ensembles de chiffres éligibles pour chaque colonne

C’est-à-dire que pour un carré d’entrée

132
200
300

Je stocke

zeros = [(1, 1), (2, 1), (1, 2), (2, 2)]
eligible_by_row = [set(), {1, 3}, {1, 2}]
eligible_by_col = [set(), {1, 2}, {1, 3}]

Ça facilite beaucoup les opérations puisque :

  • il n’y a plus besoin d’itérer sur tout le carré, seulement sur zeros
  • pour une case (x, y) donnée, on connaît les chiffres éligibles en calculant l’intersection entre eligible_by_col[x] et eligible_by_row[y]

Au final ça me donne un résultat en 16 secondes (toujours avec ou sans debug), soit presque 3× meilleur que ta solution initiale.


En triant la liste des cases à traiter (zeros) comme j’en parlais ici (ordonner les cases en ayant celles avec le moins de chiffres éligibles en 1er histoire) je descends à 3 secondes, on est donc 15× plus rapide qu’initialement.

Il y a probablement d’autres points améliorables dans le code pour peut-être gratter encore un peu, mais la solution après ces optimisations passe déjà correctement le jeu de tests sur codingame, et je trouve qu’elle simplifie grandement le code.

@entwanne, que donne ton code avec une grille sans aucune contrainte? Par exemple 9x9?

C’est une faiblesse de mon code que de ttester toutes les possibilités.

Puisque tu tries par nombre de possibilités croissantes, mais que le scénario évolue, crois-tu qu’une file de priorité serait souhaitable (heapq) ?

+0 -0

Clairement c’est quelque chose de problématique, mais je pense qu’on touche aussi aux limites du backtracking : le nombre de solutions à explorer dans ce cas est vraiment gros et la complexité exponentielle.

Oui ça pourrait être avantageux de réordonner en cours de route, suivant ce que ça coûte de réordonner.

@entwanne, que donne ton code avec une grille sans aucune contrainte? Par exemple 9x9?

PierrotLeFou

Dans le cadre de ce puzzle, le seul cas de test "non contraint" est un 4x4, dont mon premier algorithme venait aisément à bout, je ne pense pas que ce soit un vrai problème dans ce cadre précis…

Merci beaucoup pour vos conseils, je vais essayer de mettre en pratique toutes, ces optimisations (du moins celles qui sont compatibles et qui donneront le meilleur résultat) !

J’ai finalement réussi à passer tous les tests en utilisant la technique de @entwanne et en triant ma liste de zéros par nombre de possibilités croissante :bounce: !

J’ai fait quelques tests de temps d’exécution : le tri a une importance cruciale ! Sur le dernier test, le programme avec tri met entre 10 et 11s, alors que le même programme sans le tri met entre 42 et 43s (j’ai utilisé un vieil ordi datant de 2010)…

Merci infiniment pour votre précieuse aide.

Voici mon code final :

import sys

zeros = {}
square = []

n = int(input())
for i in range(n):
    row = input()
    square.append([int(i) for i in row])

for y in range(n):
    for x in range(n):
        if square[y][x] == 0:
            zeros[(x,y)] = list(range(1,n+1))
            for a in square[y] + [l[x] for l in square]:
                if a in zeros[(x,y)]:
                    zeros[(x,y)].remove(a)

voisins = {}

for x,y in zeros.keys():
    voisins[(x,y)] = []
    for x2,y2 in zeros.keys():
        if (x == x2 and y != y2) or (x != x2 and y == y2):
            voisins[(x,y)].append((x2,y2))

while not all(map(lambda x : len(zeros[x]) > 1, zeros.keys())):
    removed_keys = []

    for key,poss in zeros.items():
        x,y = key
        if len(poss) ==1:
            square[y][x] = poss[0]
            removed_keys.append(key)
            for voisin in voisins[(x,y)]:
                if voisin in zeros:
                    if poss[0] in zeros[voisin]:
                        zeros[voisin].remove(poss[0])
                        voisins[voisin].remove(key)
    for key in removed_keys:
        del zeros[key]

zlen = len(zeros)

debug(len(zeros))

zeros_pos = sorted(zeros.keys(),key = lambda k : len(zeros[k]))

forbiddens = {}

fbx = {}
fby = {}

for i in range(n):
    fbx[i] = set()
    fby[i] = set()


n_solutions = 0


def solve(n,p):
    global n_solutions

    if p >= zlen:
        n_solutions += 1
        return True

    x,y=zeros_pos[p]
    for i in zeros[(x,y)]:
        if not (i in fbx[x] or i in fby[y]):
            square[y][x] = i
            fbx[x].add(i)
            fby[y].add(i)
            solve(n,p+1)
            fbx[x].remove(i)
            fby[y].remove(i)
    square[y][x] = 0   
        
    return False

solve(n,0)

print(n_solutions)
+2 -0
Connectez-vous pour pouvoir poster un message.
Connexion

Pas encore membre ?

Créez un compte en une minute pour profiter pleinement de toutes les fonctionnalités de Zeste de Savoir. Ici, tout est gratuit et sans publicité.
Créer un compte