Компьютерная лингвистика

Новостная лента www.solarix.ru

Previous Entry Share Next Entry
Представления слов: исправления, улучшения, сравнение Python / Java / C++ реализаций нейросетей
kelijah
1. В Keras-реализации решения нашел data leak из  валидационных данных в тренировочный набор (как обычно по причине копипасты фрагментов из своих же исходников). После исправления стал получать намного более консистентные результаты: точность на одинаковых данных для сопоставимых архитектур нейросетей в Keras+Theano и в Tensorflow решениях стала примерно равна. И даже решения на Deeplearning4j и Apache.SINGA дают примерно такие же значения!

2. Реализован совершенно новый алгоритм для получения sparse distributed representations для слов. В отличие от non-negative matrix factorization с L1-регуляризацией, используется k-means для подмножеств компонентов w2v векторов. Алгоритм требует намного меньше памяти, а по скорости примерно соответствует nnmf+l1. Но чтобы использовать SDR векторы слов длиной хотя бы 1024, пришлось переделывать решение для Keras, так как сразу сгенерировать тренировочную матрицу хотя бы для 1 миллиона троек слов в 32 Гб памяти невозможно. Сделал отдельное решение, использующее методы fit_generator и evaluate_generator, принимающие генераторы порций данных. Исходный текст нового решения лежит тут: https://github.com/Koziev/WordRepresentations/blob/master/PyModels/wr_keras_sdr2.py

3. Сделал следующие улучшения в java-решении, основанном на фреймворке deepleraning4j:

3.1. Сделано сохранение лучшей модели в файл по ходу тренировки (model checkpoint). Для этого в конце каждой эпохи делается оценка качества по валидационному набору, и если полученное значение точности превышает предыдущее, то сохраняем модель в файл.

3.2. Сделал early stopping, чтобы экономить время. Если модель на протяжении 10 эпох не улучшает точность при валидации, то прекращаем обучение. После этого загружается лучшая версия весов, сохраненная в ходе model checkpoint, и для ее делается финальная оценка качества по holdout набору.

В ходе работы можно увидеть соответствующие сообщения в консоли:

==========================Scores========================================
 # of classes:    2
 Accuracy:        0.7707
 Precision:       0.7720
 Recall:          0.7707
 F1 Score:        0.7784
========================================================================
New best val_acc=0.7706575757575758
Model saved in /home/eek/polygon/WordRepresentations/data/deeplearning4j.model
Start iteration #6


4. Доработал решение на Apache.SINGA C++. В частности, добавил экспоненциальное уменьшение скорости обучения. Модель дает на финальной валидации сравнимую с остальными точность.

5. Текущие результаты по моделям, после исправления data leak'а:

wr_keras.py (Keras+Theano, Python)
==================================

NB_SAMPLES=1,000,000  NGRAM_ORDER=2  net=MLP  w2v      acc=0.7719
NB_SAMPLES=1,000,000  NGRAM_ORDER=2  net=CNN  w2v      acc=0.7741

NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v      acc=0.7709
NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=CNN  w2v      acc=0.7837

NB_SAMPLES=1,000,000  NGRAM_ORDER=4  net=MLP  w2v      acc=0.7468
NB_SAMPLES=1,000,000  NGRAM_ORDER=4  net=CNN  w2v      acc=0.7731



NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v_tags acc=0.7921
NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=CNN  w2v_tags acc=0.7999

NB_SAMPLES=1,000,000  NGRAM_ORDER=4  net=MLP  w2v_tags acc=0.7653
NB_SAMPLES=1,000,000  NGRAM_ORDER=4  net=CNN  w2v_tags acc=0.7775



wr_keras_sdr2.py (Keras+Theano, Python)
=======================================

NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  SDR2     acc=0.7596




wr_tensorflow3.py (Tensorflow, Python)
======================================

NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v      acc=0.7475


NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v_tags acc=0.7866



deeplearning4j (Java)
=====================

NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v      acc=0.7715


Apache.SINGA (C++)
==================

NB_SAMPLES=1,000,000  NGRAM_ORDER=3  net=MLP  w2v      acc=0.7692

?

Log in

No account? Create an account