Deep learning: how to take the dot product of two embeddings in TensorFlow when implementing the skip-gram model


I am trying to implement the skip-gram model in TensorFlow 2.3. Can someone confirm whether my implementation is correct, or whether there is a flaw in the model?

import tensorflow as tf

class word2vec_tf(tf.keras.Model):

    def __init__(self, embedding_size, vocab_size, noise_dist = None, negative_samples = 10):
        super(word2vec_tf, self).__init__()

        self.embeddings_input   = tf.keras.layers.Embedding(vocab_size, embedding_size, embeddings_initializer='uniform', mask_zero=False)
        self.embeddings_context = tf.keras.layers.Embedding(vocab_size, embedding_size, embeddings_initializer='uniform', mask_zero=False)
        
        self.vocab_size = vocab_size
        self.negative_samples = negative_samples
        self.noise_dist = noise_dist

    def call(self, input_word, context_word):
        
        ##### computing out loss #####
        emb_input = self.embeddings_input(input_word)     # bs, emb_dim
        emb_context = self.embeddings_context(context_word)  # bs, emb_dim
        
        # POSITIVE SAMPLES
        emb_product = tf.keras.layers.dot([emb_input, emb_context], axes=(1, 1))  # (bs, 1)
        out_loss = tf.squeeze(tf.math.log_sigmoid(emb_product), axis=1)           # (bs,)

        return tf.reduce_mean(tf.math.negative(out_loss))
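
To answer the question in the title directly: the row-wise dot product of two embedding tensors can also be taken without the layers API, for example with an element-wise multiply followed by a sum, or with tf.einsum, which avoids the extra squeeze. Below is a minimal sketch under the assumption that both tensors have shape (batch_size, embedding_dim), as in the call() above; the random tensors are stand-ins for the outputs of embeddings_input and embeddings_context.

import tensorflow as tf

batch_size, embedding_dim = 4, 8
emb_input = tf.random.uniform((batch_size, embedding_dim))    # stand-in for self.embeddings_input(input_word)
emb_context = tf.random.uniform((batch_size, embedding_dim))  # stand-in for self.embeddings_context(context_word)

# Variant 1: element-wise multiply, then sum over the embedding axis -> (batch_size,)
dot_v1 = tf.reduce_sum(emb_input * emb_context, axis=-1)

# Variant 2: einsum over the shared embedding axis -> (batch_size,)
dot_v2 = tf.einsum('be,be->b', emb_input, emb_context)

# Positive-pair loss as in the model above: -mean(log(sigmoid(dot)))
pos_loss = -tf.reduce_mean(tf.math.log_sigmoid(dot_v1))
print(dot_v1.shape, dot_v2.shape, pos_loss.numpy())

Both variants are equivalent to tf.keras.layers.dot([emb_input, emb_context], axes=(1, 1)) followed by the squeeze. Note also that the call() shown above only evaluates the positive pairs; negative_samples and noise_dist are stored in __init__ but not yet used in the loss.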