본문 바로가기
삽질/Python

[ Python 삽질 ] __call__() got an unexpected keyword argument 해결법

by SteadyForDeep 2020. 12. 29.
반응형

//서론

파이썬의 클래스는 콜 메소드(맴버 함수)를 가지고 있을 수 있다.

 

만약 누군가 만들어둔 라이브러리(keras라던가...)를 사용할 때 이 콜 메소드를 함부로 수정하게 되면

 

수정된 클래스를 단독으로 사용할 때는 문제가 크게 안되지만 다른 모듈과 함께 사용할 때는 문제가 될 수 있다.

 

나도 어제 오래된 논문의 코드를 하나 실행시키다가

 

저자가 loss function을 수정해 둔 것을 발견하지 못하고 삽질을 해 버렸다.

 

자유도가 낮은 코드일 수록 버전이 바뀌면 이런 자잘한 문제들이 발생하기 쉽다.

 

이 문제는 굉장히 간단한 문제이지만 keras 코드상에서 막상 직면하면

 

model.fit() 과 model.compile() 안에서 동작하므로

 

번뜩 떠오르지 않는 디버깅일 수 있다. 따라서 정리해 놓고자 한다.

 

 

 

 

//바쁘신 분들을 위해서

 

문제 : 클래스를 상속할 때 콜 메소드를 수정하면 다른 클래스와 연동이 힘들 수 있다.

 

해결방법 : 수정된 콜 메소드를 찾아가서 다른 클래스와 연동될 수 있게 dummy argument를 넣어준다.

 

 

 

 

 

//본문

간단한 예로

class weighted_squared_error():
    def __init__(self):
        return
    def __call__(self,a,b,W=1):
        return W*(a-b)**2

이런 클래스가 있다고 하자.

original_error = weighted_squared_error()
print(original_error(3,5)) # = 4

인스턴스를 만들고 콜 매소드를 이용해서 계산하면 4라는 결과를 얻을 수 있다.

print(original_error(3,5,W=3)) # = 12

이런 식으로 coefficient를 바꿔가면서 적용하는 것 역시 가능하다.

 

이제는 이 클래스를 상속하는 오차 함수를 새로 만들었다고 하자.

class My_error(weighted_squared_error):
    def __call__(self,A,B):
        sample_weight = A-B
        return super().__call__(A,B,sample_weight)

coefficient를 자동으로 설정하고자 하는 방식으로 함수를 수정했다. 여기서 주목할 점은 __call__ 메소드를 수정하면서 argument를 바꿔버렸다는 것인데 return에서 super() 즉 부모 클래스인 weighted_squared_error() 의 __call__을 호출해도 이 클래스의 인스턴스는 결국 새로 지정한 __call__을 사용하기 때문에 argument가 A, B 두개 뿐인 함수가 된다.

modified_error = My_error()
print(modified_error(3,5) == original_error(3,5)) # True

이렇게 3번째 argument를 수정하려고 하면 문제가 발생한다. 이럴 경우 해결책은 굉장히 간단하고 당연한데;;

class My_error(weighted_squared_error):
    def __call__(self,A,B,sample_weight):
        _sample_weight = A-B
        return super().__call__(A,B,_sample_weight)

이렇게 가짜argument를 하나 집어 넣어 주면 된다. 이때 주의할 것은 이 클래스를 감싸는 랩퍼에서 사용하는 keyword와 argument를 알아야 한다는 것이다. 예를 들어보자.

class model():
    def __init__(self,W,b):
        self.W = W
        self.b = b
        self.is_compiled = False
        return
    
    def __call__(self,x):
        W = self.W
        b = self.b
        return (W*x)+b
    
    def compiler(self,loss_fn):
        self.loss_fn = loss_fn
        self.is_compiled = True
        return
    
    def fit(self,x,y,_sample_weight=1):
        if self.is_compiled:
            pred = self.__call__(x)
            sw = _sample_weight
            loss = self.loss_fn(pred,y,sample_weight=sw)

내가 얻은 에러 메세지와 최대한 동일한 메세지를 얻기 위해서 모델의 뼈대 부분만 구현해보았다. 이렇게 구성된 모델에서

class My_error(weighted_squared_error):
    def __call__(self,A,B,sample_weight):
        _sample_weight = A-B
        return super().__call__(A,B,_sample_weight)


NN = model(3,4)
NN.compiler(loss_fn=My_error())
NN.fit(x=1,
       y=2,
       _sample_weight=3)

# No Error Message

위의 코드는 잘 돌아가지만

class My_error(weighted_squared_error):
    def __call__(self,A,B):
        sample_weight = A-B
        return super().__call__(A,B,sample_weight)


NN = model(3,4)
NN.compiler(loss_fn=My_error())
NN.fit(x=1,
       y=2,
       _sample_weight=3)

# TypeError: __call__() got an unexpected keyword argument 'sample_weight'

위의 코드를 돌리면

 아래와 같은 에러 코드가 나오게 된다.

 

 

 

 

주의할 점은 unexpected keyword argument 라고 나오는 저녀석으로 설정해주어야 한다는 것...

 

새벽까지 코드를 고치다보니 머리에 구리스가 떨어져서 요점은 캐치해 냈으면서도 디버깅을 못했는데

class My_error(weighted_squared_error):
    def __call__(self,A,B,dummy):
        sample_weight = A-B
        return super().__call__(A,B,sample_weight)


NN = model(3,4)
NN.compiler(loss_fn=My_error())
NN.fit(x=1,
       y=2,
       _sample_weight=3)

# TypeError: __call__() got an unexpected keyword argument 'sample_weight'

저런 식으로 수정해 버린것.. 반드시 어떤 함수에서 문제가 발생했는지 찾고 그에 맞는 keyword argument를 넣어주도록 하자.

 

 

 

반응형

댓글