Как реализовать умножение матриц методом штрассена python?

Аватар пользователя Elena Gromova
Elena Gromova
23 сентября 2024

Умножение матриц методом Штрассена в Python можно реализовать следующим образом:

def split_matrix(matrix):
    size = len(matrix)
    half_size = size // 2
    upper_left = [row[:half_size] for row in matrix[:half_size]]
    upper_right = [row[half_size:] for row in matrix[:half_size]]
    lower_left = [row[:half_size] for row in matrix[half_size:]]
    lower_right = [row[half_size:] for row in matrix[half_size:]]
    return upper_left, upper_right, lower_left, lower_right

def add_matrix(matrix1, matrix2):
    return [[matrix1[i][j] + matrix2[i][j] for j in range(len(matrix1[0]))] for i in range(len(matrix1))]

def sub_matrix(matrix1, matrix2):
    return [[matrix1[i][j] - matrix2[i][j] for j in range(len(matrix1[0]))] for i in range(len(matrix1))]

def strassen(matrix1, matrix2):
    if len(matrix1) == 1:
        return [[matrix1[0][0] * matrix2[0][0]]]

    a, b, c, d = split_matrix(matrix1)
    e, f, g, h = split_matrix(matrix2)

    p1 = strassen(a, sub_matrix(f, h))
    p2 = strassen(add_matrix(a, b), h)
    p3 = strassen(add_matrix(c, d), e)
    p4 = strassen(d, sub_matrix(g, e))
    p5 = strassen(add_matrix(a, d), add_matrix(e, h))
    p6 = strassen(sub_matrix(b, d), add_matrix(g, h))
    p7 = strassen(sub_matrix(a, c), add_matrix(e, f))

    upper_left = add_matrix(sub_matrix(add_matrix(p5, p4), p2), p6)
    upper_right = add_matrix(p1, p2)
    lower_left = add_matrix(p3, p4)
    lower_right = sub_matrix(sub_matrix(add_matrix(p1, p5), p3), p7)

    result = [[0 for _ in range(len(matrix2[0]))] for _ in range(len(matrix1))]

    for i in range(len(upper_left)):
        result[i][:len(upper_left[0])] = upper_left[i] + upper_right[i]
    for i in range(len(lower_left)):
        result[i + len(upper_left)][:len(lower_left[0])] = lower_left[i] + lower_right[i]

    return result

Пример использования:

matrix1 = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
matrix2 = [[17, 18, 19, 20], [21, 22, 23, 24], [25, 26, 27, 28], [29, 30, 31, 32]]

result = strassen(matrix1, matrix2)
for row in result:
    print(row)
0 0
Познакомьтесь с основами Python бесплатно