ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • tensorflow custom gradient
    AI 2023. 4. 25. 17:02

     

     

     

     

    만약 다변수 벡터함수의 편미분을 모른다면 당신이 바보 처럼 느껴질 수 있습니다.

     


     

     

    tensorflow에서는 사용자가 직접 미분 가능 함수들을 제작할 수 있게 custom gradient라는 기능을 제공한다.

    다음과 같이 사용할 수 있다.

    import tensorflow as tf
    
    @tf.custom_gradient
    def custom_sin_plus_exp (x, y, z):
      def grad (upstream):
        return upstream * tf.cos(x), upstream * tf.exp(y), None
      return tf.sin(x) + tf.exp(y), grad

     

    위에 @tf.custom_gradient 데코레이터를 붙이고

    (리턴값, 미분 함수) 꼴로 리턴하면 된다.

     

    그리고 미분 함수인 grad의 경우 parameter 별로 미분 값을 리턴해야 한다.

    여기서는 parameter가 x, y, z이니 총 3개인데, z를 상수 취급하는 함수라 

    grad의 3번째 리턴은 None이다.

     

    grad함수의 parameter인 upstream은 custom_sin_plus_exp의 앞 노드 함수로 부터 날라오는

    미분 값이다.

    그래서 upstream의 dimension은 앞에 있는 노드 함수의 parameter의 차원과 같다.

     


     

    테스트 코드

    import tensorflow as tf
    
    x = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    y = tf.constant([[2, 2, 2], [3, 3, 3]], dtype=tf.float32)
    z = tf.constant(321, dtype=tf.float32)
    
    @tf.custom_gradient
    def custom_sin_plus_exp (x, y, z):
      def grad (upstream):
        return upstream * tf.cos(x), upstream * tf.exp(y), None
      return tf.sin(x) + tf.exp(y), grad
    
    with tf.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      tape.watch(z)
      w = custom_sin_plus_exp(x, y, z)
    
    grad = tape.gradient(w, (x, y, z))
    print(w)
    print(grad)
    
    print('--------------------')
    
    x = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.float32)
    y = tf.constant([[2, 2, 2], [3, 3, 3]], dtype=tf.float32)
    
    with tf.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = tf.sin(x) + tf.exp(y)
    
    grad = tape.gradient(z, (x, y))
    print(z)
    print(grad)

     

    'AI' 카테고리의 다른 글

    isaac sim ddpg solution of cartpole problem  (14) 2024.04.16
    NLP 입문자를 위한 text 전처리  (0) 2023.04.26
    AI할 때 수치 안정성  (0) 2023.04.25
    tensorflow용 RNN-T custom 함수 제작 후기  (0) 2023.04.25
    yolov8 전이 학습 시키기  (0) 2023.04.17

    댓글

Designed by Tistory.