Music Composition AI with RNN Attention: Inference [2022]

新月一 | 2022-04-26 | Deep Learning / Machine Learning / Generative AI

TIP

I am studying generative machine learning with Generative Deep Learning. Chapter 7 contains an interesting project that composes music with AI, so I ran both training and inference. The previous post covered training; this post describes inference. Training and inference were both run on an on-premises GPU machine.

Prerequisites

  • The Python library music21 is installed.
  • MuseScore3 is installed.
  • The .music21rc settings file is configured (see the sketch after this list).
  • CUDA and related GPU support are set up and usable.
  • A Jupyter Notebook environment is available.
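
For reference, a minimal sketch of the music21 configuration. The MuseScore paths are assumptions; adjust them to your installation (the settings are persisted to the ~/.music21rc file):

from music21 import environment

us = environment.UserSettings()
us['musicxmlPath'] = '/usr/bin/musescore3'            # assumed path; used by .show()
us['musescoreDirectPNGPath'] = '/usr/bin/musescore3'  # assumed path; used for PNG rendering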

Importing libraries

Import pickle for reading the lookup tables, the RNNAttention module for inference, matplotlib for analysis, and so on.

import pickle as pkl
import time
import os
import numpy as np
import sys
from music21 import instrument, note, stream, chord, duration
from models.RNNAttention import create_network, sample_with_temp
import matplotlib.pyplot as plt

Setting parameters and folders

Set the run folder and the model parameters. embed_size, rnn_units, and use_attention must match the values used during training, since the saved weights are loaded into the same architecture.

# run params
section = 'compose'
run_id = '0006'
music_name = 'cello'
run_folder = 'run/{}/'.format(section)
run_folder += '_'.join([run_id, music_name])

# model params
embed_size = 100
rnn_units = 256
use_attention = True

Loading the lookup tables

Load the distincts and lookups data that were saved during training.

store_folder = os.path.join(run_folder, 'store')

with open(os.path.join(store_folder, 'distincts'), 'rb') as filepath:
    distincts = pkl.load(filepath)
    note_names, n_notes, duration_names, n_durations = distincts

with open(os.path.join(store_folder, 'lookups'), 'rb') as filepath:
    lookups = pkl.load(filepath)
    note_to_int, int_to_note, duration_to_int, int_to_duration = lookups
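
The tables map between symbols and integer IDs in both directions. An optional sanity check (illustrative only):

# Illustrative sanity check: the vocabulary sizes should match the
# pitch/duration Dense output layers of the trained model (461 and 19 in this run).
print(n_notes, n_durations)
print(list(note_to_int.items())[:5])
print(list(duration_to_int.items())[:5])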

Building the model

Build the model and load the trained weights into it.

weights_folder = os.path.join(run_folder, 'weights')
weights_file = 'weights.h5'

model, att_model = create_network(n_notes, n_durations, embed_size, rnn_units, use_attention)

# Load the weights to each node
weight_source = os.path.join(weights_folder, weights_file)
model.load_weights(weight_source)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, None)]       0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, None, 100)    46100       input_1[0][0]                    
__________________________________________________________________________________________________
embedding_1 (Embedding)         (None, None, 100)    1900        input_2[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, None, 200)    0           embedding[0][0]                  
                                                                 embedding_1[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     (None, None, 256)    467968      concatenate[0][0]                
__________________________________________________________________________________________________
lstm_1 (LSTM)                   (None, None, 256)    525312      lstm[0][0]                       
__________________________________________________________________________________________________
dense (Dense)                   (None, None, 1)      257         lstm_1[0][0]                     
__________________________________________________________________________________________________
reshape (Reshape)               (None, None)         0           dense[0][0]                      
__________________________________________________________________________________________________
activation (Activation)         (None, None)         0           reshape[0][0]                    
__________________________________________________________________________________________________
repeat_vector (RepeatVector)    (None, 256, None)    0           activation[0][0]                 
__________________________________________________________________________________________________
permute (Permute)               (None, None, 256)    0           repeat_vector[0][0]              
__________________________________________________________________________________________________
multiply (Multiply)             (None, None, 256)    0           lstm_1[0][0]                     
                                                                 permute[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 256)          0           multiply[0][0]                   
__________________________________________________________________________________________________
pitch (Dense)                   (None, 461)          118477      lambda[0][0]                     
__________________________________________________________________________________________________
duration (Dense)                (None, 19)           4883        lambda[0][0]                     
==================================================================================================
Total params: 1,164,897
Trainable params: 1,164,897
Non-trainable params: 0
__________________________________________________________________________________________________
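
create_network returns two models: model for prediction and att_model for reading out the attention weights. Illustratively, such a readout could also be derived from the main model like this (a sketch, assuming TensorFlow Keras and that the attention weights are the output of the 'activation' layer in the summary above):

# Sketch only: att_model is already returned by create_network.
from tensorflow.keras.models import Model

att_readout = Model(inputs=model.inputs,
                    outputs=model.get_layer('activation').output)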

Specifying the phrase that starts inference

Specify the initial phrase from which inference starts. You can also specify nothing; the seed is then padded entirely with 'START' tokens.

# prediction params
notes_temp=0.5
duration_temp = 0.5
max_extra_notes = 50
max_seq_len = 32
seq_len = 32

#notes = ['START', 'D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3','D3', 'D3', 'E3', 'D3', 'G3', 'F#3']
#durations = [0, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2]


#notes = ['START', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3','F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3', 'F#3', 'G#3', 'F#3', 'E3']
#durations = [0, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2, 0.75, 0.25, 1, 1, 1, 2]


notes = ['START', 'C3', 'C3', 'G3', 'G3', 'A3', 'A3', 'G3']
durations = [0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1]

if seq_len is not None:
    notes = ['START'] * (seq_len - len(notes)) + notes
    durations = [0] * (seq_len - len(durations)) + durations

sequence_length = len(notes)
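
notes_temp and duration_temp control how adventurous the sampling is: a low temperature concentrates probability on the most likely symbols, while a high one flattens the distribution. A minimal sketch of the temperature sampling that sample_with_temp performs (assuming it mirrors the implementation in the book's repository):

def sample_with_temp_sketch(preds, temperature):
    # Greedy choice when temperature is zero.
    if temperature == 0:
        return np.argmax(preds)
    # Rescale the log-probabilities by the temperature, then renormalise.
    scaled = np.log(preds) / temperature
    probs = np.exp(scaled) / np.sum(np.exp(scaled))
    return np.random.choice(len(probs), p=probs)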

Generation

Generate new notes from the neural network, conditioned on the running sequence of notes. The loop samples up to max_extra_notes new notes and slides the input window once it grows beyond max_seq_len.

prediction_output = []
notes_input_sequence = []
durations_input_sequence = []

overall_preds = []

for n, d in zip(notes,durations):
    note_int = note_to_int[n]
    duration_int = duration_to_int[d]
    
    notes_input_sequence.append(note_int)
    durations_input_sequence.append(duration_int)
    
    prediction_output.append([n, d])
    
    if n != 'START':
        midi_note = note.Note(n)

        new_note = np.zeros(128)
        new_note[midi_note.pitch.midi] = 1
        overall_preds.append(new_note)


att_matrix = np.zeros(shape = (max_extra_notes+sequence_length, max_extra_notes))

for note_index in range(max_extra_notes):

    prediction_input = [
        np.array([notes_input_sequence])
        , np.array([durations_input_sequence])
       ]

    notes_prediction, durations_prediction = model.predict(prediction_input, verbose=0)
    if use_attention:
        att_prediction = att_model.predict(prediction_input, verbose=0)[0]
        att_matrix[(note_index-len(att_prediction)+sequence_length):(note_index+sequence_length), note_index] = att_prediction
    
    new_note = np.zeros(128)
    
    # Record the prediction probability for every token that parses as a
    # single pitch; tokens that do not (chords, 'rest', 'START') are skipped.
    for idx, n_i in enumerate(notes_prediction[0]):
        try:
            note_name = int_to_note[idx]
            midi_note = note.Note(note_name)
            new_note[midi_note.pitch.midi] = n_i
        except Exception:
            pass
        
    overall_preds.append(new_note)
            
    
    i1 = sample_with_temp(notes_prediction[0], notes_temp)
    i2 = sample_with_temp(durations_prediction[0], duration_temp)
    

    note_result = int_to_note[i1]
    duration_result = int_to_duration[i2]
    
    prediction_output.append([note_result, duration_result])

    notes_input_sequence.append(i1)
    durations_input_sequence.append(i2)
    
    if len(notes_input_sequence) > max_seq_len:
        notes_input_sequence = notes_input_sequence[1:]
        durations_input_sequence = durations_input_sequence[1:]
        
#     print(note_result)
#     print(duration_result)
        
    if note_result == 'START':
        break

overall_preds = np.transpose(np.array(overall_preds)) 
print('Generated sequence of {} notes'.format(len(prediction_output)))
Generated sequence of 82 notes
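
The 82 entries are the 32-token seed (mostly 'START' padding) plus the 50 newly sampled notes. To inspect the sampled (note, duration) pairs:

# Show the first few sampled pairs after the 32-token seed.
print(prediction_output[seq_len:seq_len + 8])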

Plotting the confidence

Plot the model's confidence for each note as a heatmap (rows are MIDI pitches 35-70, columns are positions in the generated sequence).

fig, ax = plt.subplots(figsize=(15,15))
ax.set_yticks([int(j) for j in range(35,70)])

plt.imshow(overall_preds[35:70,:], origin="lower", cmap='coolwarm', vmin = -0.5, vmax = 0.5, extent=[0, max_extra_notes, 35,70])

(Figure: heatmap of the model's confidence for each note across the generated sequence)

Generating and playing the MIDI file

Convert the model's predicted output into notes and build a MIDI file from them.
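
Each entry of prediction_output is a [note_pattern, duration] pair; a '.'-joined pattern encodes a chord, 'rest' encodes a rest, and anything else is a single pitch. For example (values hypothetical):

# Hypothetical examples of the three pattern kinds handled below:
example_patterns = [
    ['C3', 0.5],        # single note
    ['E3.G3.C4', 1.0],  # chord: pitch names joined with '.'
    ['rest', 0.5],      # rest
]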

output_folder = os.path.join(run_folder, 'output')

midi_stream = stream.Stream()

# create note and chord objects based on the values generated by the model
for pattern in prediction_output:
    note_pattern, duration_pattern = pattern
    # pattern is a chord
    if ('.' in note_pattern):
        notes_in_chord = note_pattern.split('.')
        chord_notes = []
        for current_note in notes_in_chord:
            new_note = note.Note(current_note)
            new_note.duration = duration.Duration(duration_pattern)
            new_note.storedInstrument = instrument.Violoncello()
            chord_notes.append(new_note)
        new_chord = chord.Chord(chord_notes)
        midi_stream.append(new_chord)
    elif note_pattern == 'rest':
    # pattern is a rest
        new_note = note.Rest()
        new_note.duration = duration.Duration(duration_pattern)
        new_note.storedInstrument = instrument.Violoncello()
        midi_stream.append(new_note)
    elif note_pattern != 'START':
    # pattern is a note
        new_note = note.Note(note_pattern)
        new_note.duration = duration.Duration(duration_pattern)
        new_note.storedInstrument = instrument.Violoncello()
        midi_stream.append(new_note)



midi_stream = midi_stream.chordify()
timestr = time.strftime("%Y%m%d-%H%M%S")
midi_stream.write('midi', fp=os.path.join(output_folder, 'output-' + timestr + '.mid'))
'run/compose/0006_cello/output/output-20220426-225539.mid'
midi_stream.show('midi')
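
To render the MIDI file to audio outside the notebook, one option is MuseScore's command-line export (a sketch; assumes a musescore3 binary on PATH whose -o option exports to the format implied by the output extension):

import subprocess

midi_path = os.path.join(output_folder, 'output-' + timestr + '.mid')
# MuseScore infers the output format from the extension (here WAV).
subprocess.run(['musescore3', '-o', midi_path.replace('.mid', '.wav'), midi_path],
               check=True)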

AI-generated melody link (SoundCloud)

Checking the attention over the generated notes

Plot the attention weights for each inferred note, showing which earlier notes the model attended to when predicting it.

## attention plot
if use_attention:
    fig, ax = plt.subplots(figsize=(20,20))

    im = ax.imshow(att_matrix[(seq_len-2):,], cmap='coolwarm', interpolation='nearest')

    # Minor ticks
    ax.set_xticks(np.arange(-.5, len(prediction_output)- seq_len, 1), minor=True);
    ax.set_yticks(np.arange(-.5, len(prediction_output)- seq_len, 1), minor=True);

    # Gridlines based on minor ticks
    ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
        
    # We want to show all ticks...
    ax.set_xticks(np.arange(len(prediction_output) - seq_len))
    ax.set_yticks(np.arange(len(prediction_output)- seq_len+2))
    # ... and label them with the respective list entries
    ax.set_xticklabels([n[0] for n in prediction_output[(seq_len):]])
    ax.set_yticklabels([n[0] for n in prediction_output[(seq_len - 2):]])

    # ax.grid(color='black', linestyle='-', linewidth=1)

    ax.xaxis.tick_top()
    
    plt.setp(ax.get_xticklabels(), rotation=90, ha="left", va = "center",
             rotation_mode="anchor")

    plt.show()

(Figure: attention matrix heatmap; columns are the generated notes, rows are the notes attended to)
